Skip to content

Commit c87b3ad

Browse files
authored
BLAS: fix blas fptype for complex (#2167)
1 parent 0a81fa1 commit c87b3ad

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

Diff for: enzyme/Enzyme/Utils.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,18 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
579579
call->setDebugLoc(loc);
580580
}
581581

582-
Type *BlasInfo::fpType(LLVMContext &ctx) const {
582+
Type *BlasInfo::fpType(LLVMContext &ctx, bool to_scalar) const {
583583
if (floatType == "d" || floatType == "D") {
584584
return Type::getDoubleTy(ctx);
585585
} else if (floatType == "s" || floatType == "S") {
586586
return Type::getFloatTy(ctx);
587587
} else if (floatType == "c" || floatType == "C") {
588+
if (to_scalar)
589+
return Type::getFloatTy(ctx);
588590
return VectorType::get(Type::getFloatTy(ctx), 2, false);
589591
} else if (floatType == "z" || floatType == "Z") {
592+
if (to_scalar)
593+
return Type::getDoubleTy(ctx);
590594
return VectorType::get(Type::getDoubleTy(ctx), 2, false);
591595
} else {
592596
assert(false && "Unreachable");

Diff for: enzyme/Enzyme/Utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ struct BlasInfo {
678678
std::string function;
679679
bool is64;
680680

681-
llvm::Type *fpType(llvm::LLVMContext &ctx) const;
681+
llvm::Type *fpType(llvm::LLVMContext &ctx, bool to_scalar = false) const;
682682
llvm::IntegerType *intType(llvm::LLVMContext &ctx) const;
683683
};
684684

Diff for: enzyme/tools/enzyme-tblgen/blasTAUpdater.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ inline void emit_BLASTypes(raw_ostream &os) {
1616
"\"cublas\" && StringRef(blas.suffix).contains(\"v2\");\n";
1717

1818
os << "TypeTree ttFloat;\n"
19-
<< "llvm::Type *floatType = blas.fpType(call.getContext()); \n"
19+
<< "llvm::Type *floatType = blas.fpType(call.getContext(), true); \n"
2020
<< "if (byRefFloat) {\n"
2121
<< " ttFloat.insert({-1},BaseType::Pointer);\n"
2222
<< " ttFloat.insert({-1,0},floatType);\n"

0 commit comments

Comments
 (0)