@@ -1084,8 +1084,14 @@ class AdjointGenerator
10841084
10851085 auto dt = vd[{-1 }];
10861086 for (size_t i = start; i < size; ++i) {
1087+ auto nex = vd[{(int )i}];
1088+ if ((nex == BaseType::Anything && dt.isFloat ()) ||
1089+ (dt == BaseType::Anything && nex.isFloat ())) {
1090+ nextStart = i;
1091+ break ;
1092+ }
10871093 bool Legal = true ;
1088- dt.checkedOrIn (vd[{( int )i}] , /* PointerIntSame*/ true , Legal);
1094+ dt.checkedOrIn (nex , /* PointerIntSame*/ true , Legal);
10891095 if (!Legal) {
10901096 nextStart = i;
10911097 break ;
@@ -1199,7 +1205,8 @@ class AdjointGenerator
11991205 Builder2, align, start, size, isVolatile, ordering, syncScope,
12001206 mask, prevNoAlias, prevScopes);
12011207 ((DiffeGradientUtils *)gutils)
1202- ->addToDiffe (orig_val, diff, Builder2, FT, start, size, mask);
1208+ ->addToDiffe (orig_val, diff, Builder2, FT, start, size, {},
1209+ mask);
12031210 }
12041211 break ;
12051212 }
@@ -1909,72 +1916,160 @@ class AdjointGenerator
19091916
19101917 if (!gutils->isConstantValue (orig_inserted)) {
19111918 auto TT = TR.query (orig_inserted);
1912- auto it = TT[{-1 }];
1913- bool Legal = true ;
1914- for (size_t i = 0 ; i < size0; ++i) {
1915- bool LegalOr = true ;
1916- it.checkedOrIn (TT[{(int )i}], /* pointerIntSame*/ true , LegalOr);
1917- Legal &= LegalOr;
1918- }
1919- Type *flt = it.isFloat ();
1920- if (!it.isKnown () || !Legal) {
1921- bool found = false ;
1922-
1923- if (looseTypeAnalysis && !Legal) {
1924- if (orig_inserted->getType ()->isFPOrFPVectorTy ()) {
1925- flt = orig_inserted->getType ()->getScalarType ();
1926- found = true ;
1927- } else if (orig_inserted->getType ()->isIntOrIntVectorTy () ||
1928- orig_inserted->getType ()->isPointerTy ()) {
1929- flt = nullptr ;
1930- found = true ;
1919+
1920+ unsigned start = 0 ;
1921+ Value *dindex = nullptr ;
1922+
1923+ while (1 ) {
1924+ unsigned nextStart = size0;
1925+
1926+ auto dt = TT[{-1 }];
1927+ for (size_t i = start; i < size0; ++i) {
1928+ auto nex = TT[{(int )i}];
1929+ if ((nex == BaseType::Anything && dt.isFloat ()) ||
1930+ (dt == BaseType::Anything && nex.isFloat ())) {
1931+ nextStart = i;
1932+ break ;
1933+ }
1934+ bool Legal = true ;
1935+ dt.checkedOrIn (nex, /* PointerIntSame*/ true , Legal);
1936+ if (!Legal) {
1937+ nextStart = i;
1938+ break ;
19311939 }
19321940 }
1933- if (!found) {
1934- std::string str;
1935- raw_string_ostream ss (str);
1936- ss << " Cannot deduce type of insertvalue " << IVI
1937- << " size: " << size0 << " TT: " << TT.str ();
1938- if (CustomErrorHandler) {
1939- CustomErrorHandler (str.c_str (), wrap (&IVI), ErrorType::NoType,
1940- &TR.analyzer , nullptr , wrap (&Builder2));
1941- } else {
1942- EmitFailure (" CannotDeduceType" , IVI.getDebugLoc (), &IVI,
1943- ss.str ());
1941+ Type *flt = dt.isFloat ();
1942+ if (!dt.isKnown ()) {
1943+ bool found = false ;
1944+ if (looseTypeAnalysis) {
1945+ if (orig_inserted->getType ()->isFPOrFPVectorTy ()) {
1946+ flt = orig_inserted->getType ()->getScalarType ();
1947+ found = true ;
1948+ } else if (orig_inserted->getType ()->isIntOrIntVectorTy () ||
1949+ orig_inserted->getType ()->isPointerTy ()) {
1950+ flt = nullptr ;
1951+ found = true ;
1952+ }
1953+ }
1954+ if (!found) {
1955+ std::string str;
1956+ raw_string_ostream ss (str);
1957+ ss << " Cannot deduce type of insertvalue ins " << IVI
1958+ << " size: " << size0 << " TT: " << TT.str ();
1959+ if (CustomErrorHandler) {
1960+ CustomErrorHandler (str.c_str (), wrap (&IVI), ErrorType::NoType,
1961+ &TR.analyzer , nullptr , wrap (&Builder2));
1962+ } else {
1963+ EmitFailure (" CannotDeduceType" , IVI.getDebugLoc (), &IVI,
1964+ ss.str ());
1965+ }
19441966 }
19451967 }
1946- }
1947- if (flt) {
1948- auto rule = [&](Value *prediff) {
1949- return Builder2.CreateExtractValue (prediff, IVI.getIndices ());
1950- };
1951- auto prediff = diffe (&IVI, Builder2);
1952- auto dindex =
1953- applyChainRule (orig_inserted->getType (), Builder2, rule, prediff);
1954- addToDiffe (orig_inserted, dindex, Builder2, flt);
1968+
1969+ if (flt) {
1970+ if (!dindex) {
1971+ auto rule = [&](Value *prediff) {
1972+ return Builder2.CreateExtractValue (prediff, IVI.getIndices ());
1973+ };
1974+ auto prediff = diffe (&IVI, Builder2);
1975+ dindex = applyChainRule (orig_inserted->getType (), Builder2, rule,
1976+ prediff);
1977+ }
1978+
1979+ auto TT = TR.query (orig_inserted);
1980+
1981+ ((DiffeGradientUtils *)gutils)
1982+ ->addToDiffe (orig_inserted, dindex, Builder2, flt, start,
1983+ nextStart - start);
1984+ }
1985+ if (nextStart == size0)
1986+ break ;
1987+ start = nextStart;
19551988 }
19561989 }
19571990
19581991 size_t size1 = 1 ;
1959- if (orig_agg->getType ()->isSized () &&
1960- (orig_agg->getType ()->isIntOrIntVectorTy () ||
1961- orig_agg->getType ()->isFPOrFPVectorTy ()))
1992+ if (orig_agg->getType ()->isSized ())
19621993 size1 =
19631994 (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
19641995 orig_agg->getType ()) +
19651996 7 ) /
19661997 8 ;
19671998
19681999 if (!gutils->isConstantValue (orig_agg)) {
1969- auto rule = [&](Value *prediff) {
1970- return Builder2.CreateInsertValue (
1971- prediff, Constant::getNullValue (orig_inserted->getType ()),
1972- IVI.getIndices ());
1973- };
1974- auto prediff = diffe (&IVI, Builder2);
1975- auto dindex =
1976- applyChainRule (orig_agg->getType (), Builder2, rule, prediff);
1977- addToDiffe (orig_agg, dindex, Builder2, TR.addingType (size1, orig_agg));
2000+
2001+ auto TT = TR.query (orig_agg);
2002+
2003+ unsigned start = 0 ;
2004+
2005+ Value *dindex = nullptr ;
2006+
2007+ while (1 ) {
2008+ unsigned nextStart = size1;
2009+
2010+ auto dt = TT[{-1 }];
2011+ for (size_t i = start; i < size1; ++i) {
2012+ auto nex = TT[{(int )i}];
2013+ if ((nex == BaseType::Anything && dt.isFloat ()) ||
2014+ (dt == BaseType::Anything && nex.isFloat ())) {
2015+ nextStart = i;
2016+ break ;
2017+ }
2018+ bool Legal = true ;
2019+ dt.checkedOrIn (nex, /* PointerIntSame*/ true , Legal);
2020+ if (!Legal) {
2021+ nextStart = i;
2022+ break ;
2023+ }
2024+ }
2025+ Type *flt = dt.isFloat ();
2026+ if (!dt.isKnown ()) {
2027+ bool found = false ;
2028+ if (looseTypeAnalysis) {
2029+ if (orig_agg->getType ()->isFPOrFPVectorTy ()) {
2030+ flt = orig_agg->getType ()->getScalarType ();
2031+ found = true ;
2032+ } else if (orig_agg->getType ()->isIntOrIntVectorTy () ||
2033+ orig_agg->getType ()->isPointerTy ()) {
2034+ flt = nullptr ;
2035+ found = true ;
2036+ }
2037+ }
2038+ if (!found) {
2039+ std::string str;
2040+ raw_string_ostream ss (str);
2041+ ss << " Cannot deduce type of insertvalue agg " << IVI
2042+ << " start: " << start << " size: " << size1
2043+ << " TT: " << TT.str ();
2044+ if (CustomErrorHandler) {
2045+ CustomErrorHandler (str.c_str (), wrap (&IVI), ErrorType::NoType,
2046+ &TR.analyzer , nullptr , wrap (&Builder2));
2047+ } else {
2048+ EmitFailure (" CannotDeduceType" , IVI.getDebugLoc (), &IVI,
2049+ ss.str ());
2050+ }
2051+ }
2052+ }
2053+
2054+ if (flt) {
2055+ if (!dindex) {
2056+ auto rule = [&](Value *prediff) {
2057+ return Builder2.CreateInsertValue (
2058+ prediff, Constant::getNullValue (orig_inserted->getType ()),
2059+ IVI.getIndices ());
2060+ };
2061+ auto prediff = diffe (&IVI, Builder2);
2062+ dindex =
2063+ applyChainRule (orig_agg->getType (), Builder2, rule, prediff);
2064+ }
2065+ ((DiffeGradientUtils *)gutils)
2066+ ->addToDiffe (orig_agg, dindex, Builder2, flt, start,
2067+ nextStart - start);
2068+ }
2069+ if (nextStart == size1)
2070+ break ;
2071+ start = nextStart;
2072+ }
19782073 }
19792074
19802075 setDiffe (&IVI,
0 commit comments