Skip to content

Commit 77b4fff

Browse files
authored
[forwardmode] handle multi store on active store (#1564)
1 parent 1f1e996 commit 77b4fff

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,23 +1026,26 @@ class AdjointGenerator
10261026
if (Mode == DerivativeMode::ForwardMode) {
10271027

10281028
auto dt = vd[{-1}];
1029-
for (size_t i = 0; i < storeSize; ++i) {
1030-
bool Legal = true;
1031-
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1032-
if (!Legal) {
1033-
std::string str;
1034-
raw_string_ostream ss(str);
1035-
ss << "Cannot deduce single type of store " << I << vd.str()
1036-
<< " size: " << storeSize;
1037-
if (CustomErrorHandler) {
1038-
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
1039-
&TR.analyzer, nullptr, wrap(&BuilderZ));
1040-
} else {
1041-
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
1029+
// Only need the full type in forward mode, if storing a constant
1030+
// and therefore may need to zero some floats.
1031+
if (constantval)
1032+
for (size_t i = 0; i < storeSize; ++i) {
1033+
bool Legal = true;
1034+
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1035+
if (!Legal) {
1036+
std::string str;
1037+
raw_string_ostream ss(str);
1038+
ss << "Cannot deduce single type of store " << I << vd.str()
1039+
<< " size: " << storeSize;
1040+
if (CustomErrorHandler) {
1041+
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
1042+
&TR.analyzer, nullptr, wrap(&BuilderZ));
1043+
} else {
1044+
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
1045+
}
1046+
return;
10421047
}
1043-
return;
10441048
}
1045-
}
10461049

10471050
Value *diff = nullptr;
10481051
if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && constantval) {

0 commit comments

Comments
 (0)