@@ -149,19 +149,22 @@ class LayoutRematerialization {
149149  getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
150150                          SetVector<Value> &slice,
151151                          DenseMap<Value, Attribute> &layout,
152-                           std::function<bool (Operation *)> stopPropagation);
152+                           std::function<bool (Operation *)> stopPropagation,
153+                           bool  includeForOp = false );
153154
154155  LogicalResult getRematerializableSlice (
155156      OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
156157      DenseMap<Value, Attribute> &layout,
157-       std::function<bool (Operation *)> stopPropagation = nullptr);
158+       std::function<bool (Operation *)> stopPropagation = nullptr,
159+       bool includeForOp = false);
158160
159161private: 
160162  void  updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
161163  //  Existing tuples of (value, layout) that needs to be updated when recreating
162164  //  scf ops. This prevents keeping track of Values that have been delete when
163-   //  rewriting slices.
164-   DenseMap<Value, Attribute> mappedValues;
165+   //  rewriting slices. The Value maybe mapped to different attributes in remove
166+   //  layout.
167+   DenseMap<Value, SmallVector<Attribute>> mappedValues;
165168  //  map of the values remat based on encoding.
166169  DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
167170  //  DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
174177                                            Value newV) {
175178  LDBG (" addRematValue "   << old << "  encoding "   << encoding << "  "   << newV);
176179  rematMapping[{old, encoding}] = newV;
177-   mappedValues[old] = encoding;
180+   if  (mappedValues.contains (old)) {
181+     mappedValues[old].push_back (encoding);
182+   } else  {
183+     mappedValues[old] = {encoding};
184+   }
178185}
179186
180187//  Remove unneeded values now that we are done with the rematMapping.
@@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping(
955962  for  (auto  [old, newV] : values) {
956963    auto  it = mappedValues.find (old);
957964    if  (it != mappedValues.end ()) {
958-       Attribute encoding = it->second ;
959-       auto  rematIt = rematMapping.find ({old, it->second });
960-       assert (rematIt != rematMapping.end ());
961-       Value replacedValue = rematIt->second ;
962-       rematMapping.erase (rematIt);
963-       mappedValues.erase (it);
964-       //  Loop through the replacement value to find the new version of remat
965-       //  value. This should be okay as the number of values should be small.
966-       for  (auto  [before, after] : values) {
967-         if  (before == replacedValue) {
968-           replacedValue = after;
969-           break ;
965+       SmallVector<Attribute> encodings = it->second ;
966+       for  (auto  encoding : encodings) {
967+         auto  rematIt = rematMapping.find ({old, encoding});
968+         assert (rematIt != rematMapping.end ());
969+         Value replacedValue = rematIt->second ;
970+         rematMapping.erase (rematIt);
971+         //  Loop through the replacement value to find the new version of remat
972+         //  value. This should be okay as the number of values should be small.
973+         for  (auto  [before, after] : values) {
974+           if  (before == replacedValue) {
975+             replacedValue = after;
976+             break ;
977+           }
970978        }
979+         rematMapping[{newV, encoding}] = replacedValue;
980+       }
981+       mappedValues.erase (it);
982+       if  (mappedValues.contains (newV)) {
983+         mappedValues[newV].append (encodings);
984+       } else  {
985+         mappedValues[newV] = std::move (encodings);
971986      }
972-       rematMapping[{newV, encoding}] = replacedValue;
973-       mappedValues[newV] = encoding;
974987    }
975988  }
976989}
@@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10451058      deadOps.push_back (forOp.getOperation ());
10461059      Block &loopBody = *newForOp.getBody ();
10471060      for  (auto  m : argMapping) {
1061+         mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10481062        mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10491063        int  numIndVars = newForOp.getNumInductionVars ();
10501064        mapping.map (loopBody.getArgument (m.first  + numIndVars),
@@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11611175    builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11621176  }
11631177
1164-   for  (Operation *op : deadOps)
1165-     opToDelete.insert (op);
1178+   for  (Operation *op : deadOps) {
1179+     if  (!isa<scf::ForOp>(op))
1180+       opToDelete.insert (op);
1181+     else 
1182+       op->erase ();
1183+   }
11661184}
11671185
11681186void  LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11751193LogicalResult LayoutRematerialization::getConvertBackwardSlice (
11761194    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
11771195    DenseMap<Value, Attribute> &layout,
1178-     std::function<bool (Operation *)> stopPropagation) {
1196+     std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
11791197  //  Allow re-using existing conversions for a value. Check dominance of any
11801198  //  reusable materializations against the root value. This is sufficient
11811199  //  because the conversions are processed in post-order.
@@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12041222  };
12051223
12061224  return  ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1207-                                        stopPropagation, getExistingConversion);
1225+                                        stopPropagation, getExistingConversion,
1226+                                        includeForOp);
12081227}
12091228
12101229LogicalResult LayoutRematerialization::getRematerializableSlice (
12111230    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12121231    DenseMap<Value, Attribute> &layout,
1213-     std::function<bool (Operation *)> stopPropagation) {
1214-   LogicalResult result = getConvertBackwardSlice (
1215-       root, rootEncoding, slice, layout, std::move (stopPropagation));
1232+     std::function<bool (Operation *)> stopPropagation, bool includeForOp) {
1233+ 
1234+   LogicalResult result =
1235+       getConvertBackwardSlice (root, rootEncoding, slice, layout,
1236+                               std::move (stopPropagation), includeForOp);
12161237  if  (result.failed () || slice.empty ())
12171238    return  failure ();
12181239
@@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization(
13011322  //  rematerialized.
13021323  SetVector<Value> slice;
13031324  DenseMap<Value, Attribute> layout;
1304-   LogicalResult result = getRematerializableSlice (
1305-       convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1325+   LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1326+                                                   targetType.getEncoding (),
1327+                                                   slice, layout, nullptr , true );
13061328  if  (result.failed ()) {
13071329    LDBG ("   getRematerializableSlice failed"  );
13081330    return ;
0 commit comments