@@ -157,19 +157,22 @@ class LayoutRematerialization {
157157  getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
158158                          SetVector<Value> &slice,
159159                          DenseMap<Value, Attribute> &layout,
160-                           std::function<bool (Operation *)> stopPropagation);
160+                           std::function<bool (Operation *)> stopPropagation,
161+                           bool  includeForOp = false );
161162
162163  LogicalResult getRematerializableSlice (
163164      OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
164165      DenseMap<Value, Attribute> &layout,
165-       std::function<bool (Operation *)> stopPropagation = nullptr);
166+       std::function<bool (Operation *)> stopPropagation = nullptr,
167+       bool includeForOp = false);
166168
167169private: 
168170  void  updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
169171  //  Existing tuples of (value, layout) that needs to be updated when recreating
170172  //  scf ops. This prevents keeping track of Values that have been delete when
171-   //  rewriting slices.
172-   DenseMap<Value, Attribute> mappedValues;
173+   //  rewriting slices. The Value maybe mapped to different attributes in remove
174+   //  layout.
175+   DenseMap<Value, SmallVector<Attribute>> mappedValues;
173176  //  map of the values remat based on encoding.
174177  DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
175178  //  DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -183,7 +186,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
183186                                            Value newV) {
184187  LDBG (" addRematValue "   << old << "  encoding "   << encoding << "  "   << newV);
185188  rematMapping[{old, encoding}] = newV;
186-   mappedValues[old] = encoding;
189+   if  (mappedValues.contains (old)) {
190+     mappedValues[old].push_back (encoding);
191+   } else  {
192+     mappedValues[old] = {encoding};
193+   }
187194}
188195
189196//  Remove unneeded values now that we are done with the rematMapping.
@@ -988,22 +995,28 @@ void LayoutRematerialization::updateRematMapping(
988995  for  (auto  [old, newV] : values) {
989996    auto  it = mappedValues.find (old);
990997    if  (it != mappedValues.end ()) {
991-       Attribute encoding = it->second ;
992-       auto  rematIt = rematMapping.find ({old, it->second });
993-       assert (rematIt != rematMapping.end ());
994-       Value replacedValue = rematIt->second ;
995-       rematMapping.erase (rematIt);
996-       mappedValues.erase (it);
997-       //  Loop through the replacement value to find the new version of remat
998-       //  value. This should be okay as the number of values should be small.
999-       for  (auto  [before, after] : values) {
1000-         if  (before == replacedValue) {
1001-           replacedValue = after;
1002-           break ;
998+       SmallVector<Attribute> encodings = it->second ;
999+       for  (auto  encoding : encodings) {
1000+         auto  rematIt = rematMapping.find ({old, encoding});
1001+         assert (rematIt != rematMapping.end ());
1002+         Value replacedValue = rematIt->second ;
1003+         rematMapping.erase (rematIt);
1004+         //  Loop through the replacement value to find the new version of remat
1005+         //  value. This should be okay as the number of values should be small.
1006+         for  (auto  [before, after] : values) {
1007+           if  (before == replacedValue) {
1008+             replacedValue = after;
1009+             break ;
1010+           }
10031011        }
1012+         rematMapping[{newV, encoding}] = replacedValue;
1013+       }
1014+       mappedValues.erase (it);
1015+       if  (mappedValues.contains (newV)) {
1016+         mappedValues[newV].append (encodings);
1017+       } else  {
1018+         mappedValues[newV] = std::move (encodings);
10041019      }
1005-       rematMapping[{newV, encoding}] = replacedValue;
1006-       mappedValues[newV] = encoding;
10071020    }
10081021  }
10091022}
@@ -1078,6 +1091,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10781091      deadOps.push_back (forOp.getOperation ());
10791092      Block &loopBody = *newForOp.getBody ();
10801093      for  (auto  m : argMapping) {
1094+         mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10811095        mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10821096        int  numIndVars = newForOp.getNumInductionVars ();
10831097        mapping.map (loopBody.getArgument (m.first  + numIndVars),
@@ -1188,8 +1202,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11881202    builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11891203  }
11901204
1191-   for  (Operation *op : deadOps)
1192-     opToDelete.insert (op);
1205+   for  (Operation *op : deadOps) {
1206+     if  (!isa<scf::ForOp>(op))
1207+       opToDelete.insert (op);
1208+     else 
1209+       op->erase ();
1210+   }
11931211}
11941212
11951213void  LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1202,7 +1220,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
12021220LogicalResult LayoutRematerialization::getConvertBackwardSlice (
12031221    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12041222    DenseMap<Value, Attribute> &layout,
1205-     std::function<bool (Operation *)> stopPropagation) {
1223+     std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
12061224  //  Allow re-using existing conversions for a value. Check dominance of any
12071225  //  reusable materializations against the root value. This is sufficient
12081226  //  because the conversions are processed in post-order.
@@ -1231,15 +1249,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12311249  };
12321250
12331251  return  ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1234-                                        stopPropagation, getExistingConversion);
1252+                                        stopPropagation, getExistingConversion,
1253+                                        includeForOp);
12351254}
12361255
12371256LogicalResult LayoutRematerialization::getRematerializableSlice (
12381257    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12391258    DenseMap<Value, Attribute> &layout,
1240-     std::function<bool (Operation *)> stopPropagation) {
1241-   LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice, 
1242-                                                   layout, stopPropagation);
1259+     std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1260+   LogicalResult result = getConvertBackwardSlice (
1261+       root, rootEncoding, slice,  layout, stopPropagation, includeForOp );
12431262  if  (result.failed () || slice.empty ())
12441263    return  failure ();
12451264
@@ -1362,8 +1381,9 @@ void LayoutRematerialization::backwardRematerialization(
13621381  //  rematerialized.
13631382  SetVector<Value> slice;
13641383  DenseMap<Value, Attribute> layout;
1365-   LogicalResult result = getRematerializableSlice (
1366-       convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1384+   LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1385+                                                   targetType.getEncoding (),
1386+                                                   slice, layout, nullptr , true );
13671387  if  (result.failed ()) {
13681388    LDBG ("   getRematerializableSlice failed"  );
13691389    return ;
0 commit comments