@@ -6410,12 +6410,25 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
6410
6410
if numTmpSgpr == 4:
6411
6411
tmpSgprX3 = tmpSgprInfo.idx+3
6412
6412
6413
- sparseInputBlocksA = 1
6414
- sparseInputBlocksB = 1
6415
- if kernel["MIInputPerThread"] * kernel["ProblemType"]["DataTypeB"].numBytes() > 16 :
6416
- sparseInputBlocksA = 2 if kernel["ProblemType"]["Sparse"] == 2 else 1
6417
- sparseInputBlocksB = 2 if kernel["ProblemType"]["Sparse"] == 1 else 1
6418
-
6413
def findSparseOffset(isA:bool):
    """Describe how one operand's per-thread K input is split into blocks for SMFMA.

    Reads the enclosing scope's ``kernel`` dictionary.

    Args:
        isA: True when querying operand A, False when querying operand B.

    Returns:
        A tuple ``(blocksPerTGroupSMFMA, elementsPerBlockSMFMA, blockOffsetSMFMA)``.
        It is ``(1, 1, 1)`` whenever no split applies, i.e. when any of these hold:
          * ``kernel["ProblemType"]["Sparse"]`` is 0,
          * ``MIInputPerThread * DataTypeB.numBytes()`` does not exceed 16,
          * the operand is on the sparse track (A when Sparse == 1,
            B when Sparse == 2).
        Otherwise the operand uses 2 blocks; the element count per block is
        ``MIInputPerThread // 2`` and the offset is one block stride
        (elements per block times ``MatrixInstK // MIInputPerThread``) minus
        the elements of a single block.
    """
    unsplit = (1, 1, 1)

    sparseMode = kernel["ProblemType"]["Sparse"]
    if sparseMode == 0:
        return unsplit

    # double K: only inputs wider than 16 bytes per thread are split
    if not (kernel["MIInputPerThread"] * kernel["ProblemType"]["DataTypeB"].numBytes() > 16):
        return unsplit

    # gfx950 sparse track only has one block for each thread group.
    # TODO adjust this value for other arch.
    onSparseTrack = (sparseMode == 1 and isA) or (sparseMode == 2 and not isA)
    if onSparseTrack:
        return unsplit

    numBlocks = 2
    inputPerThread = kernel["MIInputPerThread"]
    groupsAlongK = kernel["MatrixInstK"] // inputPerThread
    perBlock = inputPerThread // numBlocks
    strideBetweenBlocks = perBlock * groupsAlongK
    return numBlocks, perBlock, strideBetweenBlocks - perBlock
6429
+
6430
+ blocksPerTGroupSMFMAA, elementsPerBlockSMFMAA, blockOffsetSMFMAA = findSparseOffset(True)
6431
+ blocksPerTGroupSMFMAB, elementsPerBlockSMFMAB, blockOffsetSMFMAB = findSparseOffset(False)
6419
6432
# replace 0 for different thread
6420
6433
if kernel["ProblemType"]["Sparse"] == 1 and numMIInput//8 >= 1:
6421
6434
vgprPerSet0Group = 1
@@ -6436,16 +6449,15 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
6436
6449
for group in range(0, numSet0GroupA):
6437
6450
if numSet0GroupA > 1 or (is_wmma_v2 and vgprPerInputA > 2):
6438
6451
if group == 0:
6439
- if kernel["ProblemType"]["Sparse"] == 1 :
6440
- multiplyBy = numMIInput//sparseInputBlocksA
6452
+ if kernel["ProblemType"]["Sparse"]:
6453
+ multiplyBy = numMIInput//blocksPerTGroupSMFMAA
6441
6454
else:
6442
6455
multiplyBy = numMIInput//2 if vgprPerInputA == 8 else numMIInput
6443
6456
shiftK.add(vectorStaticMultiply(vgpr(kReg_first), vgpr(kReg_first), multiplyBy, tmpSgprInfo))
6444
6457
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg_first), 0, ""))
6445
- elif kernel["ProblemType"]["Sparse"] == 2 and group == 2 and sparseInputBlocksA == 2:
6446
- is8bits = kernel["ProblemType"]["DataType"].numBytes() == 1
6447
- strideK = 3 if kernel["MatrixInstK"]== 32 or (is8bits and kernel["MatrixInstK"]== 64) else 7
6448
- shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), numMIInput//numSet0GroupA*strideK, "add part of K"))
6458
+ elif blocksPerTGroupSMFMAA == 2 and (group * vgprPerSet0Group) == (elementsPerBlockSMFMAA * numRegistersIn):
6459
+ kIncA = blockOffsetSMFMAA + (numMIInput//numSet0GroupA) * max(group - 1, 0)
6460
+ shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), kIncA, "add part of K"))
6449
6461
else:
6450
6462
kIncA = numMIInput//numSet0GroupA
6451
6463
if self.states.asmCaps["HasMFMA_f8f6f4"]:
@@ -6482,16 +6494,15 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
6482
6494
for group in range(0, numSet0GroupB):
6483
6495
if numSet0GroupB > 1 or (is_wmma_v2 and vgprPerInputB > 2):
6484
6496
if group == 0:
6485
- if kernel["ProblemType"]["Sparse"] == 1 :
6486
- multiplyBy = numMIInput//sparseInputBlocksB
6497
+ if kernel["ProblemType"]["Sparse"]:
6498
+ multiplyBy = numMIInput//blocksPerTGroupSMFMAB
6487
6499
else:
6488
6500
multiplyBy = numMIInput//2 if vgprPerInputB == 8 else numMIInput
6489
6501
shiftK.add(vectorStaticMultiply(vgpr(kReg_first), vgpr(kReg_first), multiplyBy, tmpSgprInfo))
6490
6502
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg_first), 0, ""))
6491
- elif kernel["ProblemType"]["Sparse"] == 1 and group == 2 and sparseInputBlocksB == 2:
6492
- is8bits = kernel["ProblemType"]["DataType"].numBytes() == 1
6493
- strideK = 3 if kernel["MatrixInstK"]== 32 or (is8bits and kernel["MatrixInstK"]== 64) else 7
6494
- shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), numMIInput//numSet0GroupB*strideK, "add part of K"))
6503
+ elif blocksPerTGroupSMFMAB == 2 and (group * vgprPerSet0Group) == (elementsPerBlockSMFMAB * numRegistersIn):
6504
+ kIncB = blockOffsetSMFMAB + (numMIInput//numSet0GroupB) * max(group - 1, 0)
6505
+ shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), kIncB, "add part of K"))
6495
6506
else:
6496
6507
kIncB = numMIInput//numSet0GroupB
6497
6508
if group == 2 and vgprPerInputB == 8:
0 commit comments