@@ -101,39 +101,13 @@ static void interleave(TileSwizzle &swizzle, size_t srcIdx, int expandedIdx) {
 template <typename MMAIntrinsicTy>
 static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
                                        unsigned operandIdx) {
-  IREE::GPU::MMASingleSubgroupLayout layout;
+  IREE::GPU::MMASingleSubgroupLayout layout =
+      IREE::GPU::getSingleSubgroupLayout(intrinsic, operandIdx);
   const bool isScaled =
       std::is_same<MMAIntrinsicTy, IREE::GPU::ScaledMMAIntrinsic>::value;
-  const unsigned lhsIdx = 0;
-  const unsigned rhsIdx = 1;
-  const unsigned lhsScalesIdx = 2;
-  const unsigned rhsScalesIdx = 3;
-  const bool isLHSorRHS = operandIdx == lhsIdx || operandIdx == rhsIdx;
-  if (isScaled) {
-    // The operand mapping for `getSingleSubgroupLayout` follows a different
-    // operand order than is used for TileSwizzle, so we need to remap the
-    // operandIdx to get the right layout. The layouts for TileSwizzle vs.
-    // `getSingleSubgroupLayout` are shown below:
-    //             | TileSwizzle | getSingleSubgroupLayout
-    //         LHS | 0           | 0
-    //         RHS | 1           | 2
-    //  LHS Scales | 2           | 1
-    //  RHS Scales | 3           | 3
-    //         ACC | 4           | 4
-    // TODO(Max191): Decide on a consistent operand order for both.
-    int64_t layoutOperandIdx = operandIdx;
-    if (operandIdx == rhsIdx) {
-      layoutOperandIdx = 2;
-    } else if (operandIdx == lhsScalesIdx) {
-      layoutOperandIdx = 1;
-    }
-    layout = IREE::GPU::getSingleSubgroupLayout(
-        static_cast<ScaledMMAIntrinsic>(intrinsic), layoutOperandIdx);
-  } else {
-    layout = IREE::GPU::getSingleSubgroupLayout(
-        static_cast<MMAIntrinsic>(intrinsic),
-        static_cast<IREE::GPU::MMAFragment>(operandIdx));
-  }
+  const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);
 
   // MMASingleSubgroupLayout has non-transposed RHS and RHS scales, but
   // TileSwizzle has transposed RHS and RHS scales, so reorder the `layout`
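The `isIntrinsicLhs`/`isIntrinsicRhs`/`isIntrinsicRhsScale` predicates that replace the index remapping are introduced elsewhere in this change. Here is a minimal sketch of what that helper family could look like, assuming the operand order was unified to the TileSwizzle side of the deleted table (LHS = 0, RHS = 1, LHS scales = 2, RHS scales = 3, ACC = 4 for scaled MMA; LHS = 0, RHS = 1, ACC = 2 otherwise). Names and bodies are illustrative, not the actual definitions from the patch:

```cpp
// Illustrative only: assumes the unified operand order LHS=0, RHS=1,
// LhsScales=2, RhsScales=3, ACC=4 (scaled) and LHS=0, RHS=1, ACC=2 (plain).
template <typename MMAIntrinsicTy>
static constexpr bool isScaledIntrinsic() {
  return std::is_same<MMAIntrinsicTy, IREE::GPU::ScaledMMAIntrinsic>::value;
}
template <typename MMAIntrinsicTy>
static bool isIntrinsicLhs(unsigned operandIdx) {
  return operandIdx == 0;
}
template <typename MMAIntrinsicTy>
static bool isIntrinsicRhs(unsigned operandIdx) {
  return operandIdx == 1;
}
template <typename MMAIntrinsicTy>
static bool isIntrinsicLhsScale(unsigned operandIdx) {
  return isScaledIntrinsic<MMAIntrinsicTy>() && operandIdx == 2;
}
template <typename MMAIntrinsicTy>
static bool isIntrinsicRhsScale(unsigned operandIdx) {
  return isScaledIntrinsic<MMAIntrinsicTy>() && operandIdx == 3;
}
template <typename MMAIntrinsicTy>
static bool isIntrinsicAcc(unsigned operandIdx) {
  return operandIdx == (isScaledIntrinsic<MMAIntrinsicTy>() ? 4u : 2u);
}
```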
@@ -143,7 +117,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
     // rotate right by 1 element to swap [K, Kb] and N.
     std::rotate(v.begin(), v.end() - 1, v.end());
   };
-  if (operandIdx == rhsIdx || (isScaled && operandIdx == rhsScalesIdx)) {
+  if (isRhs || isRhsScale) {
     swapRHSKAndN(layout.outer);
     swapRHSKAndN(layout.thread);
     swapRHSKAndN(layout.tstrides);
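For readers unfamiliar with this `std::rotate` idiom: passing `end() - 1` as the middle iterator rotates the last element to the front, which is exactly the right-by-one rotation the comment describes. A self-contained demo with made-up dimension sizes:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> dims = {4, 32, 16}; // non-transposed [K, Kb, N]
  // Rotate right by one: the last element (N) becomes the first.
  std::rotate(dims.begin(), dims.end() - 1, dims.end());
  assert((dims == std::vector<int>{16, 4, 32})); // transposed [N, K, Kb]
  return 0;
}
```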
@@ -155,7 +129,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
   // All other operands (and LHS/RHS for non-scaled matmuls) have 2 source
   // dimensions. These correspond to the arrays in `layout` all having a
   // matching size. Let's just guard that assumption with one assert here.
-  const unsigned numSrcDims = isScaled && isLHSorRHS ? 3 : 2;
+  const unsigned numSrcDims = isScaled && (isLhs || isRhs) ? 3 : 2;
   assert(layout.thread.size() == numSrcDims &&
          "expected layout rank to match the number of source dims");
   swizzle.expandShape.resize(numSrcDims);
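The rule the assert guards is small enough to pin down as a sketch (mine, not the patch's): LHS and RHS of a scaled matmul carry a third source dimension for the scale block (Kb), while the scales, the accumulator, and all non-scaled operands have two:

```cpp
// Sketch only: mirrors the numSrcDims selection above.
constexpr unsigned numSrcDimsFor(bool isScaled, bool isLhsOrRhs) {
  return (isScaled && isLhsOrRhs) ? 3u : 2u;
}
static_assert(numSrcDimsFor(/*isScaled=*/true, /*isLhsOrRhs=*/true) == 3,
              "scaled LHS/RHS is e.g. [M, K, Kb] or [N, K, Kb]");
static_assert(numSrcDimsFor(/*isScaled=*/true, /*isLhsOrRhs=*/false) == 2,
              "scales and ACC stay 2-D");
```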
@@ -233,16 +207,14 @@ static size_t getInnermostNonInternalDimIdx(
 template <typename MMAAttrTy>
 static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
   TileSwizzle swizzle = getIntrinsicSwizzle(mma.getIntrinsic(), operandIdx);
-  const bool isScaled =
-      std::is_same<MMAAttrTy, IREE::GPU::DataTiledScaledMMAAttr>::value;
-  const unsigned lhsIdx = 0;
-  const unsigned rhsIdx = 1;
-  const unsigned lhsScalesIdx = 2;
-  const unsigned rhsScalesIdx = 3;
-  const unsigned accIdx = isScaled ? 4 : 2;
-  const bool isRhsScales = isScaled && operandIdx == rhsScalesIdx;
-  const bool isLhsScales = isScaled && operandIdx == lhsScalesIdx;
-  if (operandIdx == lhsIdx || isLhsScales) {
+  using MMAIntrinsicTy = decltype(mma.getIntrinsic());
+  const bool isScaled = std::is_same<MMAIntrinsicTy, ScaledMMAIntrinsic>::value;
+  const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+  const bool isAcc = isIntrinsicAcc<MMAIntrinsicTy>(operandIdx);
+  const bool isLhsScale = isIntrinsicLhsScale<MMAIntrinsicTy>(operandIdx);
+  const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);
+  if (isLhs || isLhsScale) {
     // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
     // Unroll on K with interleaving, then on M.
     if (mma.getIntrinsicsK() > 1) {
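A detail worth noting in the rewritten `getSwizzleImpl`: `isScaled` is now derived from the intrinsic type recovered via `decltype(mma.getIntrinsic())` rather than from the attribute type, so one template body serves both attribute kinds. A standalone model of that dispatch, assuming `getIntrinsic()` returns the intrinsic enum by value (all types below are stand-ins, not the real IREE definitions):

```cpp
#include <type_traits>
#include <utility>

// Stand-in types modeling the real attributes and intrinsic enums.
enum class MMAIntrinsic { SomePlainIntrinsic };
enum class ScaledMMAIntrinsic { SomeScaledIntrinsic };
struct DataTiledMMAAttr {
  MMAIntrinsic getIntrinsic() const { return MMAIntrinsic::SomePlainIntrinsic; }
};
struct DataTiledScaledMMAAttr {
  ScaledMMAIntrinsic getIntrinsic() const {
    return ScaledMMAIntrinsic::SomeScaledIntrinsic;
  }
};

template <typename MMAAttrTy>
constexpr bool isScaledAttr() {
  // decltype of a by-value call expression is the plain return type,
  // so std::is_same compares against the bare enum type.
  using MMAIntrinsicTy = decltype(std::declval<MMAAttrTy>().getIntrinsic());
  return std::is_same<MMAIntrinsicTy, ScaledMMAIntrinsic>::value;
}
static_assert(!isScaledAttr<DataTiledMMAAttr>(), "plain MMA is not scaled");
static_assert(isScaledAttr<DataTiledScaledMMAAttr>(), "scaled MMA is scaled");
```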
@@ -253,10 +225,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
       // the unrolled scales with each vector load, so we need to interleave at
       // the very last dimension for the scales. For the LHS, we load in blocks,
       // so we don't need to interleave.
-      if (isLhsScales) {
+      if (isLhsScale) {
         interleavingIdx = swizzle.expandShape[1].size() - 1;
       }
-      if (!isScaled || isLhsScales) {
+      if (!isScaled || isLhsScale) {
         interleave(swizzle, 1, interleavingIdx);
       }
     }
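The "load all the unrolled scales with each vector load" rationale is a width argument. A back-of-envelope sketch under assumed parameters (the scale element width and unrolling factor below are illustrative, not taken from the patch):

```cpp
#include <cstdint>

// Assumed numbers for illustration: 1-byte (e.g. e8m0) scales, K unrolled 4x.
// If interleaving places the unrolled scales contiguously in the innermost
// dimension, a single 32-bit load fetches all of them; left strided, they
// would take four separate scalar loads.
constexpr unsigned kScaleBytes = 1;
constexpr unsigned kIntrinsicsK = 4;
static_assert(kScaleBytes * kIntrinsicsK == sizeof(uint32_t),
              "unrolled scales fit one 32-bit vector load");
```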
@@ -272,7 +244,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
                           mma.getSubgroupsM() * mma.getSubgroupsN());
       expand(swizzle, 0, dim);
     }
-  } else if (operandIdx == rhsIdx || isRhsScales) {
+  } else if (isRhs || isRhsScale) {
     // B-matrix (RHS). Since the pack ops already took care of transposing B,
     // source dimensions are N (index 0) and K (index 1).
     // Unroll on K with interleaving, then on N.
@@ -282,10 +254,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
           getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
       // Like with the LHS above, we want to interleave such that we load all
       // the unrolled scales with each vector load.
-      if (isRhsScales) {
+      if (isRhsScale) {
         interleavingIdx = swizzle.expandShape[1].size() - 1;
       }
-      if (!isScaled || isRhsScales) {
+      if (!isScaled || isRhsScale) {
         interleave(swizzle, 1, interleavingIdx);
       }
     }
@@ -295,7 +267,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
     if (mma.getSubgroupsN() > 1) {
       expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsN()});
     }
-  } else if (operandIdx == accIdx) {
+  } else if (isAcc) {
     // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
     // 1). Unroll on N, then on M.
     if (mma.getIntrinsicsN() > 1) {
@@ -319,9 +291,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledScaledMMAAttr scaledMma,
   return getSwizzleImpl(scaledMma, operandIdx);
 }
 
-TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
-                       IREE::GPU::MMAFragment fragment) {
-  return getSwizzleImpl(mma, static_cast<unsigned>(fragment));
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, int operandIndex) {
+  return getSwizzleImpl(mma, operandIndex);
 }
 
 /// Remove the expanded dimensions for this index and update the permutation by
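With the `MMAFragment` overload gone, both `getSwizzle` entry points take a plain operand index. A hypothetical call site, where the attribute variables are placeholders and the index values assume the unified order from the deleted mapping table:

```cpp
// Hypothetical usage; mmaAttr and scaledMmaAttr are placeholder variables.
TileSwizzle accSwizzle =
    getSwizzle(mmaAttr, /*operandIndex=*/2); // ACC of a plain data-tiled MMA
TileSwizzle rhsScalesSwizzle =
    getSwizzle(scaledMmaAttr, /*operandIdx=*/3); // RHS scales of a scaled MMA
```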