@@ -2521,18 +2521,6 @@ struct LoadOpToBlockIOConversion
25212521    if  (tileHeight * tileWidth * packedElemSizeInBits / 8  < GRF_SIZE)
25222522      vBlocks = 1 ;
25232523
2524-     //  TODO: use the axis info to general the handling for both regular pointer
2525-     //  and block pointer.
2526-     const  bool  memoryRowMajor = isMemoryRowMajor (op);
2527-     //  FIXME: Add support of column major.
2528-     if  (!memoryRowMajor)
2529-       return  failure ();
2530- 
2531-     unsigned  contiguousDim = memoryRowMajor ? 1  : 0 ;
2532-     const  bool  isTransposeRequired = contiguousDim != colDim;
2533-     if  (isTransposeRequired)
2534-       return  matchAndRewriteTranspose (op, adaptor, rewriter);
2535- 
25362524    Location loc = op.getLoc ();
25372525    auto  b = TritonLLVMOpBuilder (loc, rewriter);
25382526    MLIRContext *ctx = rewriter.getContext ();
@@ -2661,6 +2649,55 @@ struct LoadOpToBlockIOConversion
26612649      }
26622650    }
26632651
2652+     //  TODO: use the axis info to generalize the handling for both regular
2653+     //  pointer and block pointer.
2654+     const  bool  memoryRowMajor = isMemoryRowMajor (op);
2655+     unsigned  contiguousDim = memoryRowMajor ? 1  : 0 ;
2656+     const  bool  isTransposeRequired = contiguousDim != colDim;
2657+ 
2658+     if  (isTransposeRequired) {
2659+       if  (numPackedVals > 1 )
2660+         return  failure ();
2661+       if  (elemSizeInBits > 32 )
2662+         return  failure ();
2663+       if  (tileWidth > 32 )
2664+         return  failure (); //  tileWidth is limited to 32 for transpose 2d load.
2665+ 
2666+       vBlocks = 1 ;
2667+ 
2668+       //  use the d32 for transpose 2d load.
2669+       packedElemSizeInBits = 32 ;
2670+       numPackedVals = packedElemSizeInBits / elemSizeInBits;
2671+       if  (numPackedVals > 1  && tileWidth != threadsPerWarp)
2672+         return  failure (); //  Couldn't use the transpose 2d load for un-packable
2673+                           //  along tile height dim.
2674+       tileHeight = std::min (tileHeight / numPackedVals, 8 );
2675+ 
2676+       if  (tileHeight * tileWidth < threadsPerWarp)
2677+         return  failure (); //  The tile size is not large enough for IGC scalar
2678+                           //  backend vectorization.
2679+       //  transpose the width and height of the tile
2680+       std::swap (tileHeight, tileWidth);
2681+       //  if (oneMatrixPerLoadForBT) {
2682+       //    // Only load 1 operand per inst on row.
2683+       //    numOperandsPer2DLoadM = 1;
2684+       //    tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2685+       //  } else {
2686+       //    // We can decompose the matrix returned by transposed large 2d load
2687+       //    // when threads per warp < column size. Otherwise we have to load one
2688+       //    // operand per inst.
2689+       //    // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2690+       //    // now.
2691+       //    numOperandsPer2DLoadM =
2692+       //        (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2693+       //  }
2694+       //  // The transpose 2d load only support 1 operand per inst on column.
2695+       //  // (vBlocks = 1)
2696+       //  numOperandsPer2DloadN = 1;
2697+       //  // TODO: support load column major data.
2698+       //  return failure();
2699+     }
2700+ 
26642701    int64_t  numElemsPerLoad = mlir::ceil (
26652702        tileHeight * tileWidth * numPackedVals * vBlocks, (int )threadsPerWarp);
26662703    unsigned  numValuesPerLoad = mlir::ceil ((int )numElemsPerLoad, numPackedVals);
@@ -2740,8 +2777,6 @@ struct LoadOpToBlockIOConversion
27402777        }
27412778      } break ;
27422779      case  DpasEncodingAttr::OpIdx::OperandB: {
2743-         assert (numPackedVals == 1  &&
2744-                " invalid number of packed values for DPAS operand B."  );
27452780        unsigned  elemsPerLanePerDPASInst =
27462781            product<unsigned >(dpasLayout.getDPASInstShapeB ()) / threadsPerWarp;
27472782        //  Block 2D contain at least one DotOp B.
@@ -2751,6 +2786,9 @@ struct LoadOpToBlockIOConversion
27512786          if  (tileHeight >= (opsPerChannel * sysDepth) &&
27522787              ((opsPerChannel == 4  && elemSizeInBits == 8 ) ||
27532788               (opsPerChannel == 2  && elemSizeInBits == 16 ))) {
2789+             assert (!isTransposeRequired ||
2790+                    opsPerChannel == numPackedVals &&
2791+                        " invalid opsPerChannel for transposed DotOp B"  );
27542792            //  Use the VNNI packing format for DotOp B layout.
27552793            numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27562794            packedType = i32_ty;
@@ -2814,8 +2852,8 @@ struct LoadOpToBlockIOConversion
28142852          /* tile_width*/   tileWidth,
28152853          /* tile_height*/   tileHeight,
28162854          /* v_blocks*/   vBlocks,
2817-           /* transpose*/   false ,
2818-           /* vnni_transform*/   useVNNIFormat);
2855+           /* transpose*/   isTransposeRequired ,
2856+           /* vnni_transform*/   !isTransposeRequired &&  useVNNIFormat);
28192857
28202858      //  When strides[0] is 0, we only want to load the first row, so we
28212859      //  set the base height to be 1. If tile height is bigger than 1,
0 commit comments