Skip to content

Commit 2e164a0

Browse files
authored
Strengthen any type checks (#1864)
1 parent 7d09d5e commit 2e164a0

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,8 +3330,8 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
33303330
auto PNtype = PNtypeT[{-1}];
33313331

33323332
// TODO remove explicit type check and only use PNtype
3333-
if (PNtype == BaseType::Anything || PNtype == BaseType::Pointer ||
3334-
PNtype == BaseType::Integer || orig->getType()->isPointerTy())
3333+
if (!gutils->TR.anyFloat(orig, /*anythingIsFloat*/ false) ||
3334+
orig->getType()->isPointerTy())
33353335
continue;
33363336

33373337
Type *PNfloatType = PNtype.isFloat();

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5888,29 +5888,60 @@ TypeTree TypeResults::query(Value *val) const {
58885888
return analyzer->getAnalysis(val);
58895889
}
58905890

5891-
bool TypeResults::anyFloat(Value *val) const {
5891+
// Returns last non-padding/alignment location of the corresponding subtype T.
5892+
size_t skippedBytes(SmallSet<size_t, 8> &offs, Type *T, const DataLayout &DL,
5893+
size_t offset = 0) {
5894+
auto ST = dyn_cast<StructType>(T);
5895+
if (!ST)
5896+
return (DL.getTypeSizeInBits(T) + 7) / 8;
5897+
5898+
auto SL = DL.getStructLayout(ST);
5899+
size_t prevOff = 0;
5900+
for (size_t idx = 0; idx < ST->getNumElements(); idx++) {
5901+
auto off = SL->getElementOffset(idx);
5902+
if (off > prevOff)
5903+
for (size_t i = prevOff; i < off; i++)
5904+
offs.insert(offset + i);
5905+
size_t subSize = skippedBytes(offs, ST->getElementType(idx), DL, prevOff);
5906+
prevOff = off + subSize;
5907+
}
5908+
return prevOff;
5909+
}
5910+
5911+
bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
58925912
assert(val);
58935913
assert(val->getType());
58945914
auto q = query(val);
58955915
auto dt = q[{-1}];
5916+
if (!anythingIsFloat && dt == BaseType::Anything)
5917+
return false;
58965918
if (dt != BaseType::Anything && dt != BaseType::Unknown)
58975919
return dt.isFloat();
58985920

5899-
size_t ObjSize = 1;
5921+
if (val->getType()->isTokenTy())
5922+
return false;
59005923
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
5901-
if (val->getType()->isSized())
5902-
ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8;
5924+
SmallSet<size_t, 8> offs;
5925+
size_t ObjSize = skippedBytes(offs, val->getType(), dl);
59035926

59045927
for (size_t i = 0; i < ObjSize;) {
59055928
dt = q[{(int)i}];
59065929
if (dt == BaseType::Integer) {
59075930
i++;
59085931
continue;
59095932
}
5933+
if (!anythingIsFloat && dt == BaseType::Integer) {
5934+
i++;
5935+
continue;
5936+
}
59105937
if (dt == BaseType::Pointer) {
59115938
i += dl.getPointerSize(0);
59125939
continue;
59135940
}
5941+
if (offs.count(i)) {
5942+
i++;
5943+
continue;
5944+
}
59145945
return true;
59155946
}
59165947
return false;
@@ -5923,11 +5954,12 @@ bool TypeResults::anyPointer(Value *val) const {
59235954
auto dt = q[{-1}];
59245955
if (dt != BaseType::Anything && dt != BaseType::Unknown)
59255956
return dt == BaseType::Pointer;
5957+
if (val->getType()->isTokenTy())
5958+
return false;
59265959

5927-
size_t ObjSize = 1;
59285960
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
5929-
if (val->getType()->isSized())
5930-
ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8;
5961+
SmallSet<size_t, 8> offs;
5962+
size_t ObjSize = skippedBytes(offs, val->getType(), dl);
59315963

59325964
for (size_t i = 0; i < ObjSize;) {
59335965
dt = q[{(int)i}];
@@ -5939,6 +5971,10 @@ bool TypeResults::anyPointer(Value *val) const {
59395971
i += (dl.getTypeSizeInBits(FT) + 7) / 8;
59405972
continue;
59415973
}
5974+
if (offs.count(i)) {
5975+
i++;
5976+
continue;
5977+
}
59425978
return true;
59435979
}
59445980
return false;

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ class TypeResults {
178178
/// Whether any part of the top level register can contain a float
179179
/// e.g. { i64, float } can contain a float, but { i64, i8* } would not.
180180
// Of course, here we compute with type analysis rather than llvm type
181-
bool anyFloat(llvm::Value *val) const;
181+
// The flag `anythingIsFloat` specifies whether an anything should
182+
// be considered a float.
183+
bool anyFloat(llvm::Value *val, bool anythingIsFloat = true) const;
182184

183185
/// Whether any part of the top level register can contain a pointer
184186
/// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not.

0 commit comments

Comments
 (0)