Skip to content

Commit 17bb16a

Browse files
committed
Refine the method for calculating the sparse MFMA offset used when checking for over-K in the tile section.
1 parent 96c51cc commit 17bb16a

File tree

1 file changed

+29
-18
lines changed

1 file changed

+29
-18
lines changed

tensilelite/Tensile/KernelWriterAssembly.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6410,12 +6410,25 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
64106410
if numTmpSgpr == 4:
64116411
tmpSgprX3 = tmpSgprInfo.idx+3
64126412

6413-
sparseInputBlocksA = 1
6414-
sparseInputBlocksB = 1
6415-
if kernel["MIInputPerThread"] * kernel["ProblemType"]["DataTypeB"].numBytes() > 16 :
6416-
sparseInputBlocksA = 2 if kernel["ProblemType"]["Sparse"] == 2 else 1
6417-
sparseInputBlocksB = 2 if kernel["ProblemType"]["Sparse"] == 1 else 1
6418-
6413+
def findSparseOffset(isA: bool):
    """Return the sparse-MFMA block layout for operand A (isA=True) or B.

    The result is the tuple
    (blocksPerTGroupSMFMA, elementsPerBlockSMFMA, blockOffsetSMFMA).
    All three values are 1 unless the problem is sparse, the per-thread
    input exceeds 16 bytes (the "double K" case), and the requested
    operand is not the sparse track.
    """
    unsplit = (1, 1, 1)

    sparseMode = kernel["ProblemType"]["Sparse"]
    if sparseMode == 0:
        return unsplit

    # "double K": per-thread input is wider than 16 bytes.
    inputBytes = kernel["MIInputPerThread"] * kernel["ProblemType"]["DataTypeB"].numBytes()
    if not inputBytes > 16:
        return unsplit

    # Sparse == 1 puts the sparse track on A, Sparse == 2 puts it on B.
    onSparseTrack = (sparseMode == 1 and isA) or (sparseMode == 2 and not isA)
    # gfx950 sparse track only has one block for each thread group.
    # TODO adjust this value for other arch.
    if onSparseTrack:
        return unsplit

    # Operand opposite the sparse track: two blocks per thread group.
    numBlocks = 2
    numThreadGroups = kernel["MatrixInstK"] // kernel["MIInputPerThread"]
    blockElements = kernel["MIInputPerThread"] // numBlocks
    # Stride between blocks minus one block's worth of elements.
    blockOffset = blockElements * numThreadGroups - blockElements
    return numBlocks, blockElements, blockOffset
6429+
6430+
blocksPerTGroupSMFMAA, elementsPerBlockSMFMAA, blockOffsetSMFMAA = findSparseOffset(True)
6431+
blocksPerTGroupSMFMAB, elementsPerBlockSMFMAB, blockOffsetSMFMAB = findSparseOffset(False)
64196432
# replace 0 for different thread
64206433
if kernel["ProblemType"]["Sparse"] == 1 and numMIInput//8 >= 1:
64216434
vgprPerSet0Group = 1
@@ -6436,16 +6449,15 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
64366449
for group in range(0, numSet0GroupA):
64376450
if numSet0GroupA > 1 or (is_wmma_v2 and vgprPerInputA > 2):
64386451
if group == 0:
6439-
if kernel["ProblemType"]["Sparse"] == 1:
6440-
multiplyBy = numMIInput//sparseInputBlocksA
6452+
if kernel["ProblemType"]["Sparse"]:
6453+
multiplyBy = numMIInput//blocksPerTGroupSMFMAA
64416454
else:
64426455
multiplyBy = numMIInput//2 if vgprPerInputA == 8 else numMIInput
64436456
shiftK.add(vectorStaticMultiply(vgpr(kReg_first), vgpr(kReg_first), multiplyBy, tmpSgprInfo))
64446457
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg_first), 0, ""))
6445-
elif kernel["ProblemType"]["Sparse"] == 2 and group == 2 and sparseInputBlocksA == 2:
6446-
is8bits = kernel["ProblemType"]["DataType"].numBytes() == 1
6447-
strideK = 3 if kernel["MatrixInstK"]== 32 or (is8bits and kernel["MatrixInstK"]== 64) else 7
6448-
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), numMIInput//numSet0GroupA*strideK, "add part of K"))
6458+
elif blocksPerTGroupSMFMAA == 2 and (group * vgprPerSet0Group) == (elementsPerBlockSMFMAA * numRegistersIn):
6459+
kIncA = blockOffsetSMFMAA + (numMIInput//numSet0GroupA) * max(group - 1, 0)
6460+
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), kIncA, "add part of K"))
64496461
else:
64506462
kIncA = numMIInput//numSet0GroupA
64516463
if self.states.asmCaps["HasMFMA_f8f6f4"]:
@@ -6482,16 +6494,15 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT
64826494
for group in range(0, numSet0GroupB):
64836495
if numSet0GroupB > 1 or (is_wmma_v2 and vgprPerInputB > 2):
64846496
if group == 0:
6485-
if kernel["ProblemType"]["Sparse"] == 1:
6486-
multiplyBy = numMIInput//sparseInputBlocksB
6497+
if kernel["ProblemType"]["Sparse"]:
6498+
multiplyBy = numMIInput//blocksPerTGroupSMFMAB
64876499
else:
64886500
multiplyBy = numMIInput//2 if vgprPerInputB == 8 else numMIInput
64896501
shiftK.add(vectorStaticMultiply(vgpr(kReg_first), vgpr(kReg_first), multiplyBy, tmpSgprInfo))
64906502
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg_first), 0, ""))
6491-
elif kernel["ProblemType"]["Sparse"] == 1 and group == 2 and sparseInputBlocksB == 2:
6492-
is8bits = kernel["ProblemType"]["DataType"].numBytes() == 1
6493-
strideK = 3 if kernel["MatrixInstK"]== 32 or (is8bits and kernel["MatrixInstK"]== 64) else 7
6494-
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), numMIInput//numSet0GroupB*strideK, "add part of K"))
6503+
elif blocksPerTGroupSMFMAB == 2 and (group * vgprPerSet0Group) == (elementsPerBlockSMFMAB * numRegistersIn):
6504+
kIncB = blockOffsetSMFMAB + (numMIInput//numSet0GroupB) * max(group - 1, 0)
6505+
shiftK.add(VAddU32(vgpr(kReg), vgpr(kReg), kIncB, "add part of K"))
64956506
else:
64966507
kIncB = numMIInput//numSet0GroupB
64976508
if group == 2 and vgprPerInputB == 8:

0 commit comments

Comments
 (0)