Skip to content

Commit 1135f76

Browse files
authored
Handle insert value of multi type (#1560)
* Handle insert value of multi type * Fix multi agg * now with separated * fix insertion index math * Allow pointer in double addTo * fix * fix ins typetree * fix erasure
1 parent 6732d3d commit 1135f76

File tree

10 files changed

+361
-133
lines changed

10 files changed

+361
-133
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 148 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,8 +1084,14 @@ class AdjointGenerator
10841084

10851085
auto dt = vd[{-1}];
10861086
for (size_t i = start; i < size; ++i) {
1087+
auto nex = vd[{(int)i}];
1088+
if ((nex == BaseType::Anything && dt.isFloat()) ||
1089+
(dt == BaseType::Anything && nex.isFloat())) {
1090+
nextStart = i;
1091+
break;
1092+
}
10871093
bool Legal = true;
1088-
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1094+
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
10891095
if (!Legal) {
10901096
nextStart = i;
10911097
break;
@@ -1199,7 +1205,8 @@ class AdjointGenerator
11991205
Builder2, align, start, size, isVolatile, ordering, syncScope,
12001206
mask, prevNoAlias, prevScopes);
12011207
((DiffeGradientUtils *)gutils)
1202-
->addToDiffe(orig_val, diff, Builder2, FT, start, size, mask);
1208+
->addToDiffe(orig_val, diff, Builder2, FT, start, size, {},
1209+
mask);
12031210
}
12041211
break;
12051212
}
@@ -1909,72 +1916,160 @@ class AdjointGenerator
19091916

19101917
if (!gutils->isConstantValue(orig_inserted)) {
19111918
auto TT = TR.query(orig_inserted);
1912-
auto it = TT[{-1}];
1913-
bool Legal = true;
1914-
for (size_t i = 0; i < size0; ++i) {
1915-
bool LegalOr = true;
1916-
it.checkedOrIn(TT[{(int)i}], /*pointerIntSame*/ true, LegalOr);
1917-
Legal &= LegalOr;
1918-
}
1919-
Type *flt = it.isFloat();
1920-
if (!it.isKnown() || !Legal) {
1921-
bool found = false;
1922-
1923-
if (looseTypeAnalysis && !Legal) {
1924-
if (orig_inserted->getType()->isFPOrFPVectorTy()) {
1925-
flt = orig_inserted->getType()->getScalarType();
1926-
found = true;
1927-
} else if (orig_inserted->getType()->isIntOrIntVectorTy() ||
1928-
orig_inserted->getType()->isPointerTy()) {
1929-
flt = nullptr;
1930-
found = true;
1919+
1920+
unsigned start = 0;
1921+
Value *dindex = nullptr;
1922+
1923+
while (1) {
1924+
unsigned nextStart = size0;
1925+
1926+
auto dt = TT[{-1}];
1927+
for (size_t i = start; i < size0; ++i) {
1928+
auto nex = TT[{(int)i}];
1929+
if ((nex == BaseType::Anything && dt.isFloat()) ||
1930+
(dt == BaseType::Anything && nex.isFloat())) {
1931+
nextStart = i;
1932+
break;
1933+
}
1934+
bool Legal = true;
1935+
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
1936+
if (!Legal) {
1937+
nextStart = i;
1938+
break;
19311939
}
19321940
}
1933-
if (!found) {
1934-
std::string str;
1935-
raw_string_ostream ss(str);
1936-
ss << "Cannot deduce type of insertvalue " << IVI
1937-
<< " size: " << size0 << " TT: " << TT.str();
1938-
if (CustomErrorHandler) {
1939-
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
1940-
&TR.analyzer, nullptr, wrap(&Builder2));
1941-
} else {
1942-
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
1943-
ss.str());
1941+
Type *flt = dt.isFloat();
1942+
if (!dt.isKnown()) {
1943+
bool found = false;
1944+
if (looseTypeAnalysis) {
1945+
if (orig_inserted->getType()->isFPOrFPVectorTy()) {
1946+
flt = orig_inserted->getType()->getScalarType();
1947+
found = true;
1948+
} else if (orig_inserted->getType()->isIntOrIntVectorTy() ||
1949+
orig_inserted->getType()->isPointerTy()) {
1950+
flt = nullptr;
1951+
found = true;
1952+
}
1953+
}
1954+
if (!found) {
1955+
std::string str;
1956+
raw_string_ostream ss(str);
1957+
ss << "Cannot deduce type of insertvalue ins " << IVI
1958+
<< " size: " << size0 << " TT: " << TT.str();
1959+
if (CustomErrorHandler) {
1960+
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
1961+
&TR.analyzer, nullptr, wrap(&Builder2));
1962+
} else {
1963+
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
1964+
ss.str());
1965+
}
19441966
}
19451967
}
1946-
}
1947-
if (flt) {
1948-
auto rule = [&](Value *prediff) {
1949-
return Builder2.CreateExtractValue(prediff, IVI.getIndices());
1950-
};
1951-
auto prediff = diffe(&IVI, Builder2);
1952-
auto dindex =
1953-
applyChainRule(orig_inserted->getType(), Builder2, rule, prediff);
1954-
addToDiffe(orig_inserted, dindex, Builder2, flt);
1968+
1969+
if (flt) {
1970+
if (!dindex) {
1971+
auto rule = [&](Value *prediff) {
1972+
return Builder2.CreateExtractValue(prediff, IVI.getIndices());
1973+
};
1974+
auto prediff = diffe(&IVI, Builder2);
1975+
dindex = applyChainRule(orig_inserted->getType(), Builder2, rule,
1976+
prediff);
1977+
}
1978+
1979+
auto TT = TR.query(orig_inserted);
1980+
1981+
((DiffeGradientUtils *)gutils)
1982+
->addToDiffe(orig_inserted, dindex, Builder2, flt, start,
1983+
nextStart - start);
1984+
}
1985+
if (nextStart == size0)
1986+
break;
1987+
start = nextStart;
19551988
}
19561989
}
19571990

19581991
size_t size1 = 1;
1959-
if (orig_agg->getType()->isSized() &&
1960-
(orig_agg->getType()->isIntOrIntVectorTy() ||
1961-
orig_agg->getType()->isFPOrFPVectorTy()))
1992+
if (orig_agg->getType()->isSized())
19621993
size1 =
19631994
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
19641995
orig_agg->getType()) +
19651996
7) /
19661997
8;
19671998

19681999
if (!gutils->isConstantValue(orig_agg)) {
1969-
auto rule = [&](Value *prediff) {
1970-
return Builder2.CreateInsertValue(
1971-
prediff, Constant::getNullValue(orig_inserted->getType()),
1972-
IVI.getIndices());
1973-
};
1974-
auto prediff = diffe(&IVI, Builder2);
1975-
auto dindex =
1976-
applyChainRule(orig_agg->getType(), Builder2, rule, prediff);
1977-
addToDiffe(orig_agg, dindex, Builder2, TR.addingType(size1, orig_agg));
2000+
2001+
auto TT = TR.query(orig_agg);
2002+
2003+
unsigned start = 0;
2004+
2005+
Value *dindex = nullptr;
2006+
2007+
while (1) {
2008+
unsigned nextStart = size1;
2009+
2010+
auto dt = TT[{-1}];
2011+
for (size_t i = start; i < size1; ++i) {
2012+
auto nex = TT[{(int)i}];
2013+
if ((nex == BaseType::Anything && dt.isFloat()) ||
2014+
(dt == BaseType::Anything && nex.isFloat())) {
2015+
nextStart = i;
2016+
break;
2017+
}
2018+
bool Legal = true;
2019+
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
2020+
if (!Legal) {
2021+
nextStart = i;
2022+
break;
2023+
}
2024+
}
2025+
Type *flt = dt.isFloat();
2026+
if (!dt.isKnown()) {
2027+
bool found = false;
2028+
if (looseTypeAnalysis) {
2029+
if (orig_agg->getType()->isFPOrFPVectorTy()) {
2030+
flt = orig_agg->getType()->getScalarType();
2031+
found = true;
2032+
} else if (orig_agg->getType()->isIntOrIntVectorTy() ||
2033+
orig_agg->getType()->isPointerTy()) {
2034+
flt = nullptr;
2035+
found = true;
2036+
}
2037+
}
2038+
if (!found) {
2039+
std::string str;
2040+
raw_string_ostream ss(str);
2041+
ss << "Cannot deduce type of insertvalue agg " << IVI
2042+
<< " start: " << start << " size: " << size1
2043+
<< " TT: " << TT.str();
2044+
if (CustomErrorHandler) {
2045+
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
2046+
&TR.analyzer, nullptr, wrap(&Builder2));
2047+
} else {
2048+
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
2049+
ss.str());
2050+
}
2051+
}
2052+
}
2053+
2054+
if (flt) {
2055+
if (!dindex) {
2056+
auto rule = [&](Value *prediff) {
2057+
return Builder2.CreateInsertValue(
2058+
prediff, Constant::getNullValue(orig_inserted->getType()),
2059+
IVI.getIndices());
2060+
};
2061+
auto prediff = diffe(&IVI, Builder2);
2062+
dindex =
2063+
applyChainRule(orig_agg->getType(), Builder2, rule, prediff);
2064+
}
2065+
((DiffeGradientUtils *)gutils)
2066+
->addToDiffe(orig_agg, dindex, Builder2, flt, start,
2067+
nextStart - start);
2068+
}
2069+
if (nextStart == size1)
2070+
break;
2071+
start = nextStart;
2072+
}
19782073
}
19792074

19802075
setDiffe(&IVI,

0 commit comments

Comments
 (0)