@@ -2528,17 +2528,6 @@ struct LoadOpToBlockIOConversion
25282528    if  (tileHeight * tileWidth * packedElemSizeInBits / 8  < GRF_SIZE)
25292529      vBlocks = 1 ;
25302530
2531-     //  TODO: use the axis info to general the handling for both regular pointer
2532-     //  and block pointer.
2533-     const  bool  memoryRowMajor = isMemoryRowMajor (op);
2534-     unsigned  contiguousDim = memoryRowMajor ? 1  : 0 ;
2535-     const  bool  isTransposeRequired = contiguousDim != colDim;
2536- 
2537-     if  (isTransposeRequired) {
2538-       //  TODO: support load column major data.
2539-       return  failure ();
2540-     }
2541- 
25422531    Location loc = op.getLoc ();
25432532    auto  b = TritonLLVMOpBuilder (loc, rewriter);
25442533    MLIRContext *ctx = rewriter.getContext ();
@@ -2667,10 +2656,59 @@ struct LoadOpToBlockIOConversion
26672656      }
26682657    }
26692658
2659+     //  TODO: use the axis info to general the handling for both regular pointer
2660+     //  and block pointer.
2661+     const  bool  memoryRowMajor = isMemoryRowMajor (op);
2662+     unsigned  contiguousDim = memoryRowMajor ? 1  : 0 ;
2663+     const  bool  isTransposeRequired = contiguousDim != colDim;
2664+ 
2665+     if  (isTransposeRequired) {
2666+       if  (numPackedVals > 1 )
2667+         return  failure ();
2668+       if  (elemSizeInBits > 32 )
2669+         return  failure ();
2670+       if  (tileWidth > 32 )
2671+         return  failure (); //  tileWidth is limited to 32 for transpose 2d load.
2672+ 
2673+       vBlocks = 1 ;
2674+ 
2675+       //  use the d32 for transpose 2d load.
2676+       packedElemSizeInBits = 32 ;
2677+       numPackedVals = packedElemSizeInBits / elemSizeInBits;
2678+       if  (numPackedVals > 1  && tileWidth != threadsPerWarp)
2679+         return  failure (); //  Couldn't use the transpose 2d load for un-packable
2680+                           //  along tile height dim.
2681+       tileHeight = std::min (tileHeight / numPackedVals, 8 );
2682+ 
2683+       if  (tileHeight * tileWidth < threadsPerWarp)
2684+         return  failure (); //  The tile size is not large enough for IGC scalar
2685+                           //  backend vectorization.
2686+       //  transpose the width and height of the tile
2687+       std::swap (tileHeight, tileWidth);
2688+       //  if (oneMatrixPerLoadForBT) {
2689+       //    // Only load 1 operand per inst on row.
2690+       //    numOperandsPer2DLoadM = 1;
2691+       //    tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2692+       //  } else {
2693+       //    // We can decompose the matrix returned by transposed large 2d load
2694+       //    // when threads per warp < column size. Otherwise we have to load one
2695+       //    // operand per inst.
2696+       //    // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2697+       //    // now.
2698+       //    numOperandsPer2DLoadM =
2699+       //        (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2700+       //  }
2701+       //  // The transpose 2d load only support 1 operand per inst on column.
2702+       //  // (vBlocks = 1)
2703+       //  numOperandsPer2DloadN = 1;
2704+       //  // TODO: support load column major data.
2705+       //  return failure();
2706+     }
2707+ 
26702708    baseWidth = b.i32_val (
26712709        std::max (64u , vBlocks * tileWidth * (packedElemSizeInBits / 8 )));
26722710    //  If the stride is 0, we want to load only the first row.
2673-     int  stride = getStride (ptr, 0 );
2711+     int  stride = getStride (ptr, memoryRowMajor ?  0  :  1 );
26742712    baseHeightInt = (stride == 0  ? 1  : tileHeight);
26752713    baseHeight = b.i32_val (baseHeightInt);
26762714    pitch = getPitch (rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0  : 1 );
@@ -2739,17 +2777,19 @@ struct LoadOpToBlockIOConversion
27392777        }
27402778      } break ;
27412779      case  DpasEncodingAttr::OpIdx::OperandB: {
2742-         assert (numPackedVals == 1  &&
2743-                " invalid number of packed values for DPAS operand B."  );
2780+         //   assert(numPackedVals == 1 &&
2781+         //          "invalid number of packed values for DPAS operand B.");
27442782        unsigned  elemsPerLanePerDPASInst =
27452783            product<unsigned >(dpasLayout.getDPASInstShapeB ()) / threadsPerWarp;
27462784        //  Block 2D contain at least one DotOp B.
27472785        if  (numElemsPerLoad >= elemsPerLanePerDPASInst) {
27482786          unsigned  opsPerChannel = dpasLayout.getOpsPerChannel ();
27492787          unsigned  sysDepth = dpasLayout.getSystolicDepth ();
2750-           if  (tileHeight >= (opsPerChannel * sysDepth) &&
2751-               ((opsPerChannel == 4  && elemSizeInBits == 8 ) ||
2752-                (opsPerChannel == 2  && elemSizeInBits == 16 ))) {
2788+           if  ((opsPerChannel == 4  && elemSizeInBits == 8 ) ||
2789+               (opsPerChannel == 2  && elemSizeInBits == 16 )) {
2790+             assert (!isTransposeRequired ||
2791+                    opsPerChannel == numPackedVals &&
2792+                        " invalid opsPerChannel for transposed DotOp B"  );
27532793            //  Use the VNNI packing format for DotOp B layout.
27542794            numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27552795            packedType = i32_ty;
@@ -2816,8 +2856,8 @@ struct LoadOpToBlockIOConversion
28162856          /* tile_width*/   tileWidth,
28172857          /* tile_height*/   tileHeight,
28182858          /* v_blocks*/   vBlocks,
2819-           /* transpose*/   false ,
2820-           /* vnni_transform*/   useVNNIFormat);
2859+           /* transpose*/   isTransposeRequired ,
2860+           /* vnni_transform*/   !isTransposeRequired &&  useVNNIFormat);
28212861
28222862      //  When strides[0] is 0, we only want to load the first row, so we
28232863      //  set the base height to be 1. If tile height is bigger than 1,
0 commit comments