Skip to content

Commit d18f50c

Browse files
authored
Cublas byref fixup (#1877)
1 parent 2851f39 commit d18f50c

File tree

7 files changed

+93
-65
lines changed

7 files changed

+93
-65
lines changed

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,8 +1117,10 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {
11171117
}
11181118
if (auto I = dyn_cast<Instruction>(Val)) {
11191119
EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
1120+
exit(1);
11201121
} else if (auto I = dyn_cast_or_null<Instruction>(Origin)) {
11211122
EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
1123+
exit(1);
11221124
} else {
11231125
llvm::errs() << ss.str() << "\n";
11241126
}

enzyme/test/Integration/ReverseMode/cublas.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ void my_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
2525
double alpha, double *__restrict__ A, int lda,
2626
double *__restrict__ X, int incx, double beta,
2727
double *__restrict__ Y, int incy) {
28-
cublasDgemv(handle, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
28+
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
2929
inDerivative = true;
3030
}
3131

3232
void ow_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
3333
double alpha, double *A, int lda, double *X, int incx,
3434
double beta, double *Y, int incy) {
35-
cublasDgemv(handle, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
35+
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
3636
inDerivative = true;
3737
}
3838

@@ -55,8 +55,8 @@ void my_dgemm(cublasHandle_t *handle, cublasOperation_t transA,
5555
cublasOperation_t transB, int M, int N, int K, double alpha,
5656
double *__restrict__ A, int lda, double *__restrict__ B, int ldb,
5757
double beta, double *__restrict__ C, int ldc) {
58-
cublasDgemm(handle, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
59-
ldc);
58+
cublasDgemm(handle, transA, transB, M, N, K, &alpha, A, lda, B, ldb, &beta, C,
59+
ldc);
6060
inDerivative = true;
6161
}
6262

@@ -212,10 +212,10 @@ static void gemvTests() {
212212

213213
inDerivative = true;
214214
// dC = alpha * X * transpose(Y) + A
215-
cublasDger(handle, M, N, alpha, trans ? B : dC, trans ? incB : incC,
216-
trans ? dC : B, trans ? incC : incB, dA, lda);
215+
cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
216+
trans ? dC : B, trans ? incC : incB, dA, lda);
217217
// dY = beta * dY
218-
cublasDscal(handle, trans ? N : M, beta, dC, incC);
218+
cublasDscal(handle, trans ? N : M, &beta, dC, incC);
219219

220220
checkTest(Test);
221221

@@ -241,15 +241,16 @@ static void gemvTests() {
241241

242242
inDerivative = true;
243243
// dC = alpha * X * transpose(Y) + A
244-
cublasDger(handle, M, N, alpha, trans ? B : dC, trans ? incB : incC,
245-
trans ? dC : B, trans ? incC : incB, dA, lda);
244+
cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
245+
trans ? dC : B, trans ? incC : incB, dA, lda);
246246

247247
// dB = alpha * trans(A) * dC + dB
248-
cublasDgemv(handle, transpose(transA), M, N, alpha, A, lda, dC, incC,
249-
1.0, dB, incB);
248+
double c1 = 1.0;
249+
cublasDgemv(handle, transpose(transA), M, N, &alpha, A, lda, dC, incC,
250+
&c1, dB, incB);
250251

251252
// dY = beta * dY
252-
cublasDscal(handle, trans ? N : M, beta, dC, incC);
253+
cublasDscal(handle, trans ? N : M, &beta, dC, incC);
253254

254255
checkTest(Test);
255256

@@ -391,7 +392,9 @@ static void gemmTests() {
391392
transB_bool ? A : dC, transB_bool ? lda : incC, 1.0, dB, incB);
392393

393394
// TODO we are currently faking support here, this needs to be actually implemented
394-
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, 1.0, beta, M, N, dC, incC, 0);
395+
double c10 = 1.0;
396+
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, &c10, &beta, M, N,
397+
dC, incC, 0);
395398

396399
checkTest(Test);
397400

enzyme/test/Integration/blasinfra.h

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -931,14 +931,12 @@ __attribute__((noinline)) void dlacpy(char *uplo_p, int *M_p, int *N_p, double *
931931

932932
__attribute__((noinline)) cublasStatus_t
933933
cublasDlascl(cublasHandle_t *handle, cublasOperation_t type, int KL, int KU,
934-
double cfrom, double cto, int M, int N, double *A, int lda, int info) {
935-
calls.push_back((BlasCall){ABIType::CUBLAS,handle,
936-
inDerivative, CallType::LASCL,
937-
A, UNUSED_POINTER, UNUSED_POINTER,
938-
cfrom, cto,
939-
CUBLAS_LAYOUT,
940-
(char)type, UNUSED_TRANS,
941-
M, N, UNUSED_INT, lda, KL, KU});
934+
double *cfrom, double *cto, int M, int N, double *A, int lda,
935+
int info) {
936+
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
937+
CallType::LASCL, A, UNUSED_POINTER, UNUSED_POINTER,
938+
*cfrom, *cto, CUBLAS_LAYOUT, (char)type,
939+
UNUSED_TRANS, M, N, UNUSED_INT, lda, KL, KU});
942940
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
943941
}
944942
__attribute__((noinline)) cublasStatus_t cublasDlacpy(cublasHandle_t *handle, char uplo, int M,
@@ -1054,47 +1052,57 @@ __attribute__((noinline)) cublasStatus_t cublasDaxpy(cublasHandle_t *handle,
10541052
}
10551053
__attribute__((noinline)) cublasStatus_t
10561054
cublasDgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
1057-
double alpha, double *A, int lda, double *X, int incx, double beta,
1058-
double *Y, int incy) {
1059-
BlasCall call = {ABIType::CUBLAS,handle,
1060-
inDerivative, CallType::GEMV, Y, A, X, alpha, beta, CUBLAS_LAYOUT,
1061-
(char)trans, UNUSED_TRANS, M, N, UNUSED_INT, lda, incx, incy};
1055+
double *alpha, double *A, int lda, double *X, int incx,
1056+
double *beta, double *Y, int incy) {
1057+
BlasCall call = {ABIType::CUBLAS,
1058+
handle,
1059+
inDerivative,
1060+
CallType::GEMV,
1061+
Y,
1062+
A,
1063+
X,
1064+
*alpha,
1065+
*beta,
1066+
CUBLAS_LAYOUT,
1067+
(char)trans,
1068+
UNUSED_TRANS,
1069+
M,
1070+
N,
1071+
UNUSED_INT,
1072+
lda,
1073+
incx,
1074+
incy};
10621075
calls.push_back(call);
10631076
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10641077
}
10651078
__attribute__((noinline)) cublasStatus_t
10661079
cublasDgemm(cublasHandle_t *handle, cublasOperation_t transA,
1067-
cublasOperation_t transB, int M, int N, int K, double alpha,
1068-
double *A, int lda, double *B, int ldb, double beta, double *C,
1069-
int ldc) {
1070-
calls.push_back((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GEMM, C, A, B, alpha,
1071-
beta,
1072-
CUBLAS_LAYOUT,
1073-
(char)transA, (char)transB, M, N, K, lda,
1074-
ldb, ldc});
1080+
cublasOperation_t transB, int M, int N, int K, double *alpha,
1081+
double *A, int lda, double *B, int ldb, double *beta, double *C,
1082+
int ldc) {
1083+
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
1084+
CallType::GEMM, C, A, B, *alpha, *beta,
1085+
CUBLAS_LAYOUT, (char)transA, (char)transB, M, N, K,
1086+
lda, ldb, ldc});
10751087
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10761088
}
10771089
__attribute__((noinline)) cublasStatus_t
1078-
cublasDscal(cublasHandle_t *handle, int N, double alpha, double *X, int incX) {
1090+
cublasDscal(cublasHandle_t *handle, int N, double *alpha, double *X, int incX) {
10791091
calls.push_back((BlasCall){
1080-
ABIType::CUBLAS,handle,inDerivative, CallType::SCAL, X, UNUSED_POINTER, UNUSED_POINTER, alpha,
1081-
UNUSED_DOUBLE,
1082-
CUBLAS_LAYOUT,
1083-
UNUSED_TRANS, UNUSED_TRANS, N, UNUSED_INT,
1084-
UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
1092+
ABIType::CUBLAS, handle, inDerivative, CallType::SCAL, X, UNUSED_POINTER,
1093+
UNUSED_POINTER, *alpha, UNUSED_DOUBLE, CUBLAS_LAYOUT, UNUSED_TRANS,
1094+
UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
10851095
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10861096
}
10871097

10881098
// A = alpha * X * transpose(Y) + A
10891099
__attribute__((noinline)) cublasStatus_t
1090-
cublasDger(cublasHandle_t *handle, int M, int N, double alpha, double *X,
1091-
int incX, double *Y, int incY, double *A, int lda) {
1092-
calls.push_back((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GER, A, X, Y, alpha,
1093-
UNUSED_DOUBLE,
1094-
CUBLAS_LAYOUT,
1095-
UNUSED_TRANS,
1096-
UNUSED_TRANS, M, N, UNUSED_INT, incX, incY,
1097-
lda});
1100+
cublasDger(cublasHandle_t *handle, int M, int N, double *alpha, double *X,
1101+
int incX, double *Y, int incY, double *A, int lda) {
1102+
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
1103+
CallType::GER, A, X, Y, *alpha, UNUSED_DOUBLE,
1104+
CUBLAS_LAYOUT, UNUSED_TRANS, UNUSED_TRANS, M, N,
1105+
UNUSED_INT, incX, incY, lda});
10981106
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10991107
}
11001108

enzyme/tools/enzyme-tblgen/blas-tblgen.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
267267

268268
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
269269
"\"cublas_\";\n";
270+
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
271+
os << "(void)byRefFloat;\n";
270272
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
271273
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
272274
"\"cublas\";\n";
@@ -355,7 +357,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
355357
auto ty = argTypeMap.lookup(actArgs[i]);
356358
os << " if (";
357359
if (ty == ArgType::fp)
358-
os << "byRef && ";
360+
os << "byRefFloat && ";
359361
os << "active_" << name << ") {\n"
360362
<< " auto shadow_" << name << " = gutils->invertPointerM(orig_"
361363
<< name << ", BuilderZ);\n"
@@ -385,7 +387,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
385387
auto ty = argTypeMap.lookup(actArgs[i]);
386388
os << " if (";
387389
if (ty == ArgType::fp)
388-
os << "byRef && ";
390+
os << "byRefFloat && ";
389391
os << "active_" << name << ") {\n"
390392
<< " rt_inactive_" << name << " = BuilderZ.CreateOr(rt_inactive_"
391393
<< name << ", rt_inactive_out, \"rt.inactive.\" \"" << name << "\");\n"
@@ -406,7 +408,8 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
406408
}
407409
}
408410
if (!hasFP)
409-
os << " Type* blasFPType = byRef ? (Type*)PointerType::getUnqual(fpType) "
411+
os << " Type* blasFPType = byRefFloat ? "
412+
"(Type*)PointerType::getUnqual(fpType) "
410413
": (Type*)fpType;\n";
411414

412415
bool hasChar = false;
@@ -609,25 +612,29 @@ void emit_extract_calls(const TGPattern &pattern, raw_ostream &os) {
609612
<< " if (Mode != DerivativeMode::ForwardModeSplit)\n"
610613
<< " cacheval = lookup(cacheval, Builder2);\n"
611614
<< " }\n"
612-
<< "\n"
613-
<< " if (byRef) {\n";
615+
<< "\n";
614616

615617
for (size_t i = 0; i < nameVec.size(); i++) {
616618
auto ty = typeMap.lookup(i);
617619
auto name = nameVec[i];
618620
// this branch used "true_" << name everywhere instead of "arg_" << name
619621
// before. probably randomly, but check to make sure
620622
if (ty == ArgType::len || ty == ArgType::vincInc || ty == ArgType::mldLD) {
623+
os << " if (byRef) {\n";
621624
extract_scalar(name, "intType", os);
625+
os << " }\n";
622626
} else if (ty == ArgType::fp) {
627+
os << " if (byRefFloat) {\n";
623628
extract_scalar(name, "fpType", os);
629+
os << " }\n";
624630
} else if (ty == ArgType::trans) {
625631
// we are in the byRef branch and trans only exists in lv23.
626632
// So just unconditionally assume that no layout exists and use i-1
633+
os << " if (byRef) {\n";
627634
extract_scalar(name, "charType", os);
635+
os << " }\n";
628636
}
629637
}
630-
os << " }\n";
631638

632639
std::string input_var = "";
633640
size_t actVar = 0;
@@ -1207,8 +1214,8 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
12071214
} else if (Def->isSubClassOf("Constant")) {
12081215
auto val = Def->getValueAsString("value");
12091216
os << "{to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val
1210-
<< "), byRef, blasFPType, allocationBuilder, \"constant.fp." << val
1211-
<< "\")}";
1217+
<< "), byRefFloat, blasFPType, allocationBuilder, \"constant.fp."
1218+
<< val << "\")}";
12121219
} else if (Def->isSubClassOf("Char")) {
12131220
auto val = Def->getValueAsString("value");
12141221
if (val == "N") {
@@ -1382,7 +1389,7 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
13821389
<< bb << ".CreateCall(derivcall_" << dfnc_name << ", " << argName
13831390
<< ", Defs));\n";
13841391
}
1385-
os << " if (byRef) {\n"
1392+
os << " if (byRefFloat) {\n"
13861393
<< " ((DiffeGradientUtils *)gutils)"
13871394
<< "->addToInvertedPtrDiffe(&call, nullptr, fpType, 0, "
13881395
<< "(called->getParent()->getDataLayout().getTypeSizeInBits(fpType)/8), "
@@ -1401,7 +1408,7 @@ void emit_runtime_condition(DagInit *ruleDag, StringRef name, StringRef tab,
14011408
StringRef B, bool isFP, raw_ostream &os) {
14021409
os << tab << "BasicBlock *nextBlock_" << name << " = nullptr;\n"
14031410
<< tab << "if (EnzymeRuntimeActivityCheck && cacheMode"
1404-
<< (isFP ? " && byRef" : "") << ") {\n"
1411+
<< (isFP ? " && byRefFloat" : "") << ") {\n"
14051412
<< tab << " BasicBlock *current = Builder2.GetInsertBlock();\n"
14061413
<< tab << " auto activeBlock = gutils->addReverseBlock(current,"
14071414
<< "bb_name + \"." << name << ".active\");\n"
@@ -1415,7 +1422,8 @@ void emit_runtime_condition(DagInit *ruleDag, StringRef name, StringRef tab,
14151422

14161423
void emit_runtime_continue(DagInit *ruleDag, StringRef name, StringRef tab,
14171424
StringRef B, bool isFP, raw_ostream &os) {
1418-
os << tab << "if (nextBlock_" << name << (isFP ? " && byRef" : "") << ") {\n"
1425+
os << tab << "if (nextBlock_" << name << (isFP ? " && byRefFloat" : "")
1426+
<< ") {\n"
14191427
<< tab << " " << B << ".CreateBr(nextBlock_" << name << ");\n"
14201428
<< tab << " " << B << ".SetInsertPoint(nextBlock_" << name << ");\n"
14211429
<< tab << "}\n";

enzyme/tools/enzyme-tblgen/blasDeclUpdater.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
2323
os << " return;\n";
2424
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
2525
"\"cublas_\";\n";
26+
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
27+
os << "(void)byRefFloat;\n";
2628
if (lv23)
2729
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
2830
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
@@ -104,26 +106,26 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
104106
}
105107
}
106108

107-
os << " if (byRef) {\n";
108-
109109
for (size_t argPos = 0; argPos < numArgs; argPos++) {
110110
const auto typeOfArg = argTypeMap.lookup(argPos);
111111
size_t i = (lv23 ? argPos - 1 : argPos);
112112

113113
if (is_char_arg(typeOfArg) || typeOfArg == ArgType::len ||
114114
typeOfArg == ArgType::vincInc || typeOfArg == ArgType::fp ||
115115
typeOfArg == ArgType::mldLD) {
116+
os << " if (" << (typeOfArg == ArgType::fp ? "byRefFloat" : "byRef")
117+
<< ") {\n";
116118
os << " F->removeParamAttr(" << i << " + offset"
117119
<< ", llvm::Attribute::ReadNone);\n"
118120
<< " F->addParamAttr(" << i << " + offset"
119121
<< ", llvm::Attribute::ReadOnly);\n"
120122
<< " F->addParamAttr(" << i << " + offset"
121123
<< ", llvm::Attribute::NoCapture);\n";
124+
os << " }\n";
122125
}
123126
}
124127

125-
os << " }\n"
126-
<< " // Julia declares double* pointers as Int64,\n"
128+
os << " // Julia declares double* pointers as Int64,\n"
127129
<< " // so LLVM won't let us add these Attributes.\n"
128130
<< " if (!julia_decl) {\n";
129131
for (size_t argPos = 0; argPos < numArgs; argPos++) {

enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
1919

2020
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
2121
"\"cublas_\";\n";
22+
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
23+
os << "(void)byRefFloat;\n";
2224
if (lv23)
2325
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
2426
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
@@ -77,7 +79,7 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
7779

7880
// We need the shadow of the value we're updating
7981
if (typeMap[argPos] == ArgType::fp) {
80-
os << " if (shadow && byRef && active_" << argname
82+
os << " if (shadow && byRefFloat && active_" << argname
8183
<< ") return true;\n";
8284
} else if (typeMap[argPos] == ArgType::vincData ||
8385
typeMap[argPos] == ArgType::mldData) {

enzyme/tools/enzyme-tblgen/blasTAUpdater.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
void emit_BLASTypes(raw_ostream &os) {
44
os << "const bool byRef = blas.prefix == \"\" || blas.prefix == "
55
"\"cublas_\";\n";
6+
os << "const bool byRefFloat = byRef || blas.prefix == "
7+
"\"cublas\";\n";
8+
os << "(void)byRefFloat;\n";
69
os << "const bool cblas = blas.prefix == \"cblas_\";\n";
710
os << "const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
811
"\"cublas\";\n";
@@ -18,7 +21,7 @@ void emit_BLASTypes(raw_ostream &os) {
1821
<< "} else {\n"
1922
<< " llvm_unreachable(\"unknown float type of blas\");\n"
2023
<< "}\n"
21-
<< "if (byRef) {\n"
24+
<< "if (byRefFloat) {\n"
2225
<< " ttFloat.insert({-1},BaseType::Pointer);\n"
2326
<< " ttFloat.insert({-1,0},floatType);\n"
2427
<< "} else { \n"

0 commit comments

Comments
 (0)