@@ -1026,23 +1026,26 @@ class AdjointGenerator
10261026 if (Mode == DerivativeMode::ForwardMode) {
10271027
10281028 auto dt = vd[{-1 }];
1029- for (size_t i = 0 ; i < storeSize; ++i) {
1030- bool Legal = true ;
1031- dt.checkedOrIn (vd[{(int )i}], /* PointerIntSame*/ true , Legal);
1032- if (!Legal) {
1033- std::string str;
1034- raw_string_ostream ss (str);
1035- ss << " Cannot deduce single type of store " << I << vd.str ()
1036- << " size: " << storeSize;
1037- if (CustomErrorHandler) {
1038- CustomErrorHandler (str.c_str (), wrap (&I), ErrorType::NoType,
1039- &TR.analyzer , nullptr , wrap (&BuilderZ));
1040- } else {
1041- EmitFailure (" CannotDeduceType" , I.getDebugLoc (), &I, ss.str ());
1029+ // Only need the full type in forward mode, if storing a constant
1030+ // and therefore may need to zero some floats.
1031+ if (constantval)
1032+ for (size_t i = 0 ; i < storeSize; ++i) {
1033+ bool Legal = true ;
1034+ dt.checkedOrIn (vd[{(int )i}], /* PointerIntSame*/ true , Legal);
1035+ if (!Legal) {
1036+ std::string str;
1037+ raw_string_ostream ss (str);
1038+ ss << " Cannot deduce single type of store " << I << vd.str ()
1039+ << " size: " << storeSize;
1040+ if (CustomErrorHandler) {
1041+ CustomErrorHandler (str.c_str (), wrap (&I), ErrorType::NoType,
1042+ &TR.analyzer , nullptr , wrap (&BuilderZ));
1043+ } else {
1044+ EmitFailure (" CannotDeduceType" , I.getDebugLoc (), &I, ss.str ());
1045+ }
1046+ return ;
10421047 }
1043- return ;
10441048 }
1045- }
10461049
10471050 Value *diff = nullptr ;
10481051 if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && constantval) {
0 commit comments