Skip to content

Commit

Permalink
Reorder output parameters and add llvm_unreachable assert
Browse files Browse the repository at this point in the history
  • Loading branch information
cassiebeckley committed Oct 29, 2024
1 parent 4668b5d commit b82dec1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
33 changes: 18 additions & 15 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11415,10 +11415,9 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
return nullptr;
}

void SpirvEmitter::splitDouble(SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range) {
void SpirvEmitter::splitDouble(SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits) {
const QualType uintType = astContext.UnsignedIntTy;
const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2);

Expand All @@ -11433,9 +11432,9 @@ void SpirvEmitter::splitDouble(SpirvInstruction *value,
void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
QualType outputType,
SpirvInstruction *value,
SourceLocation loc, SourceRange range,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {
SpirvInstruction *&highbits) {
llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

Expand All @@ -11444,7 +11443,7 @@ void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDouble(elem, lowbitsResult, highbitsResult, loc, range);
splitDouble(elem, loc, range, lowbitsResult, highbitsResult);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}
Expand All @@ -11458,9 +11457,9 @@ void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value,
SourceLocation loc, SourceRange range,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {
SpirvInstruction *&highbits) {

llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;
Expand All @@ -11476,8 +11475,8 @@ void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
spvBuilder.createCompositeExtract(colType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult,
highbitsResult, loc, range);
splitDoubleVector(elemType, colCount, outputColType, column, loc, range,
lowbitsResult, highbitsResult);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}
Expand Down Expand Up @@ -11607,13 +11606,17 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
SpirvInstruction *highbitsResult = nullptr;

if (isScalarType(argType)) {
splitDouble(value, lowbitsResult, highbitsResult, loc, range);
splitDouble(value, loc, range, lowbitsResult, highbitsResult);
} else if (isVectorType(argType, &elemType, &rowCount)) {
splitDoubleVector(elemType, rowCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
splitDoubleVector(elemType, rowCount, arg1->getType(), value, loc, range,
lowbitsResult, highbitsResult);
} else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
loc, range, lowbitsResult, highbitsResult);
} else {
llvm_unreachable(
"unexpected argument type is not scalar, vector, or matrix");
return nullptr;
}

spvBuilder.createStore(lowbits, lowbitsResult, loc, range);
Expand Down
18 changes: 9 additions & 9 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1314,19 +1314,19 @@ class SpirvEmitter : public ASTConsumer {

// Splits the `value`, which must be a 64-bit scalar, into two 32-bit wide
// uints, stored in `lowbits` and `highbits`.
void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);
void splitDouble(SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

// Splits the value, which must be a vector with element type `elemType` and
// size `count`, into two composite values of size `count` and type
// `outputType`. The elements are split component-wise: the vector
// {0x0123456789abcdef, 0x0123456789abcdef} is split into `lowbits`
// {0x89abcdef, 0x89abcdef} and and `highbits` {0x01234567, 0x01234567}.
void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);
SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

// Splits the value, which must be a matrix with element type `elemType` and
// dimensions `rowCount` and `colCount`, into two composite values of
Expand All @@ -1347,9 +1347,9 @@ class SpirvEmitter : public ASTConsumer {
// 0x012345678, 0x012345678 }.
void splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);
SpirvInstruction *value, SourceLocation loc,
SourceRange range, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits);

public:
/// \brief Wrapper method to create a fatal error message and report it
Expand Down

0 comments on commit b82dec1

Please sign in to comment.