import tempfile

import triton
from triton.compiler import IRSource, make_backend
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()
backend = make_backend(target)


def test_mlir_attribute_parsing() -> None:
    '''
    Tests that MLIR attributes are parsed correctly from input ttir/ttgir.

    Checks for the following:
    1. Name and type signature are parsed correctly
    2. _get_num_warps_from_ir_str() works
    3. tt.nv_tma_desc attribute is parsed correctly
    '''

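    # The TTGIR below carries everything the assertions rely on: a
    # "triton_gpu.num-warps" = 8 module attribute and tt.nv_tma_desc
    # annotations on the last two kernel arguments.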
    sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg3: i32 {tt.divisibility = 16 : i32},
                                %arg4: i32 {tt.divisibility = 16 : i32},
                                %arg5: i32 {tt.divisibility = 16 : i32},
                                %arg6: i32 {tt.divisibility = 16 : i32},
                                %arg7: i32 {tt.divisibility = 16 : i32},
                                %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
                                %desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
    tt.return
  }
}
"""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(sample_ttgir)
        f.flush()
        context = ir.context()
        src = IRSource(f.name, context, backend)

        # Check the kernel name and the parsed type signature; the type strings
        # should match what ty_to_cpp(...) expects.
        assert src.signature == {
            0: "*f32", 1: "*f32", 2: "*f32", 3: "i32",
            4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc",
        }
        assert src.name == "@matmul_kernel"

        # check num warps
        assert src.parse_options()['num_warps'] == 8
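        # (num_warps is taken from the "triton_gpu.num-warps" module attribute
        # in the sample TTGIR above, which sets it to 8.)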

    sample_ttgir_vector_add = r"""
    #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
    module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
       tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
       %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
       %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
       %arg3: i32 {tt.divisibility = 16 : i32})
        attributes {noinline = false} {
         %c1024_i32 = arith.constant 1024 : i32
         %0 = tt.get_program_id x : i32
         %1 = arith.muli %0, %c1024_i32 : i32
         %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
         %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
         %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
         %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
         %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
         %7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
         %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
         %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
         %10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
         %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
         %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
         %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
         %14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
         %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
         tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
         tt.return
       }
     }
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(sample_ttgir_vector_add)
        f.flush()
        context = ir.context()
        src = IRSource(f.name, context, backend)

        # now test compilation
        triton.compile(f.name, target=target)