Skip to content

Commit

Permalink
[SPIR-V] Handle vectors passed to asuint (#6953)
Browse files Browse the repository at this point in the history
`asuint` should be able to take vectors in addition to scalar values.
Previously, it would be lowered as a bitcast from the input value to a
vector of uints with a width of 2, which is not large enough if the
input value is larger than a scalar value. In order to handle, for
example, an input value that is a `double4`, we instead perform a
component-wise bitcast.

Fixes #6735
  • Loading branch information
cassiebeckley authored Oct 29, 2024
1 parent 6a7eae1 commit 5704c47
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 18 deletions.
120 changes: 103 additions & 17 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11419,6 +11419,78 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
return nullptr;
}

void SpirvEmitter::splitDouble(SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits) {
const QualType uintType = astContext.UnsignedIntTy;
const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2);

SpirvInstruction *uints = spvBuilder.createUnaryOp(
spv::Op::OpBitcast, uintVec2Type, value, loc, range);

lowbits = spvBuilder.createCompositeExtract(uintType, uints, {0}, loc, range);
highbits =
spvBuilder.createCompositeExtract(uintType, uints, {1}, loc, range);
}

void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
QualType outputType,
SpirvInstruction *value,
SourceLocation loc, SourceRange range,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits) {
llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

for (uint32_t i = 0; i < count; ++i) {
SpirvInstruction *elem =
spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDouble(elem, loc, range, lowbitsResult, highbitsResult);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value,
SourceLocation loc, SourceRange range,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits) {

llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

QualType colType = astContext.getExtVectorType(elemType, colCount);

const QualType uintType = astContext.UnsignedIntTy;
const QualType outputColType =
astContext.getExtVectorType(uintType, colCount);

for (uint32_t i = 0; i < rowCount; ++i) {
SpirvInstruction *column =
spvBuilder.createCompositeExtract(colType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDoubleVector(elemType, colCount, outputColType, column, loc, range,
lowbitsResult, highbitsResult);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
// This function handles the following intrinsics:
Expand Down Expand Up @@ -11523,23 +11595,37 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
}
case 3: {
// Handling Method 6.
auto *value = doExpr(arg0);
auto *lowbits = doExpr(callExpr->getArg(1));
auto *highbits = doExpr(callExpr->getArg(2));
const auto uintType = astContext.UnsignedIntTy;
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type,
value, loc, range);
spvBuilder.createStore(
lowbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {0},
arg0->getLocStart(), range),
loc, range);
spvBuilder.createStore(
highbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {1},
arg0->getLocStart(), range),
loc, range);
const Expr *arg1 = callExpr->getArg(1);
const Expr *arg2 = callExpr->getArg(2);

SpirvInstruction *value = doExpr(arg0);
SpirvInstruction *lowbits = doExpr(arg1);
SpirvInstruction *highbits = doExpr(arg2);

QualType elemType = QualType();
uint32_t rowCount = 0;
uint32_t colCount = 0;

SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;

if (isScalarType(argType)) {
splitDouble(value, loc, range, lowbitsResult, highbitsResult);
} else if (isVectorType(argType, &elemType, &rowCount)) {
splitDoubleVector(elemType, rowCount, arg1->getType(), value, loc, range,
lowbitsResult, highbitsResult);
} else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value,
loc, range, lowbitsResult, highbitsResult);
} else {
llvm_unreachable(
"unexpected argument type is not scalar, vector, or matrix");
return nullptr;
}

spvBuilder.createStore(lowbits, lowbitsResult, loc, range);
spvBuilder.createStore(highbits, highbitsResult, loc, range);

return nullptr;
}
default:
Expand Down
39 changes: 39 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,45 @@ class SpirvEmitter : public ASTConsumer {
/// the Vulkan memory model capability has been added to the module.
bool UpgradeToVulkanMemoryModelIfNeeded(std::vector<uint32_t> *module);

// Splits the `value`, which must be a 64-bit scalar, into two 32-bit wide
// uints, stored in `lowbits` and `highbits`.
void splitDouble(SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

// Splits the value, which must be a vector with element type `elemType` and
// size `count`, into two composite values of size `count` and type
// `outputType`. The elements are split component-wise: the vector
// {0x0123456789abcdef, 0x0123456789abcdef} is split into `lowbits`
// {0x89abcdef, 0x89abcdef} and and `highbits` {0x01234567, 0x01234567}.
void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType,
SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

// Splits the value, which must be a matrix with element type `elemType` and
// dimensions `rowCount` and `colCount`, into two composite values of
// dimensions `rowCount` and `colCount`. The elements are split
// component-wise: the matrix
//
// { 0x0123456789abcdef, 0x0123456789abcdef,
// 0x0123456789abcdef, 0x0123456789abcdef }
//
// is split into `lowbits`
//
// { 0x89abcdef, 0x89abcdef,
// 0x89abcdef, 0x89abcdef }
//
// and `highbits`
//
// { 0x012345678, 0x012345678,
// 0x012345678, 0x012345678 }.
void splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

public:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
Expand Down
135 changes: 134 additions & 1 deletion tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,141 @@ void main() {
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value
// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(value, lowbits, highbits);

double3 value3;
uint3 lowbits3;
uint3 highbits3;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v3double %value3
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[value]] 2
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v3uint [[low0]] [[low1]] [[low2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v3uint [[high0]] [[high1]] [[high2]]
// CHECK-NEXT: OpStore %lowbits3 [[low]]
// CHECK-NEXT: OpStore %highbits3 [[high]]
asuint(value3, lowbits3, highbits3);

double2x2 value2x2;
uint2x2 lowbits2x2;
uint2x2 highbits2x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat2v2double %value2x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[lowRow0]] [[lowRow1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[highRow0]] [[highRow1]]
// CHECK-NEXT: OpStore %lowbits2x2 [[low]]
// CHECK-NEXT: OpStore %highbits2x2 [[high]]
asuint(value2x2, lowbits2x2, highbits2x2);

double3x2 value3x2;
uint3x2 lowbits3x2;
uint3x2 highbits3x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat3v2double %value3x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[row2:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 2
// CHECK-NEXT: [[value4:%[0-9]+]] = OpCompositeExtract %double [[row2]] 0
// CHECK-NEXT: [[resultVec4:%[0-9]+]] = OpBitcast %v2uint [[value4]]
// CHECK-NEXT: [[low4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 0
// CHECK-NEXT: [[high4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 1
// CHECK-NEXT: [[value5:%[0-9]+]] = OpCompositeExtract %double [[row2]] 1
// CHECK-NEXT: [[resultVec5:%[0-9]+]] = OpBitcast %v2uint [[value5]]
// CHECK-NEXT: [[low5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 0
// CHECK-NEXT: [[high5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 1
// CHECK-NEXT: [[lowRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[low4]] [[low5]]
// CHECK-NEXT: [[highRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[high4]] [[high5]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[lowRow0]] [[lowRow1]] [[lowRow2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[highRow0]] [[highRow1]] [[highRow2]]
// CHECK-NEXT: OpStore %lowbits3x2 [[low]]
// CHECK-NEXT: OpStore %highbits3x2 [[high]]
asuint(value3x2, lowbits3x2, highbits3x2);

double2x1 value2x1;
uint2x1 lowbits2x1;
uint2x1 highbits2x1;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value2x1
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: OpStore %lowbits2x1 [[low]]
// CHECK-NEXT: OpStore %highbits2x1 [[high]]
asuint(value2x1, lowbits2x1, highbits2x1);

double1x2 value1x2;
uint1x2 lowbits1x2;
uint1x2 highbits1x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value1x2
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: OpStore %lowbits1x2 [[low]]
// CHECK-NEXT: OpStore %highbits1x2 [[high]]
asuint(value1x2, lowbits1x2, highbits1x2);
}

0 comments on commit 5704c47

Please sign in to comment.