-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Triton -> C-Lisp compiler #79
Comments
Contributors should first complete the Triton tutorials here: https://triton-lang.org/main/getting-started/tutorials/index.html
A way to inspect the generated LLVM IR is described here: https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking
IR dump for the vector-add tutorial kernel (01-vector-add.py):
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# There are multiple 'programs' processing different data. We identify which program
# we are here:
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
# This program will process inputs that are offset from the initial data.
# For instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers:
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses.
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM.
tl.store(output_ptr + offsets, output, mask=mask)
Only pasting the first IR dump:
; *** IR Dump After Annotation2MetadataPass on [module] ***
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
@global_smem = external addrspace(3) global [0 x i8], align 16
define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3) !dbg !7 {
%5 = call i32 asm "mov.u32 $0, %ctaid.x;", "=r"(), !dbg !10
%6 = mul i32 %5, 1024, !dbg !11
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
%8 = urem i32 %7, 32, !dbg !12
%9 = udiv i32 %7, 32, !dbg !12
%10 = and i32 %8, 1, !dbg !12
%11 = icmp eq i32 %10, 0, !dbg !12
%12 = select i1 %11, i32 0, i32 4, !dbg !12
%13 = xor i32 0, %12, !dbg !12
%14 = and i32 %8, 2, !dbg !12
%15 = icmp eq i32 %14, 0, !dbg !12
%16 = select i1 %15, i32 0, i32 8, !dbg !12
%17 = xor i32 %13, %16, !dbg !12
%18 = and i32 %8, 4, !dbg !12
%19 = icmp eq i32 %18, 0, !dbg !12
%20 = select i1 %19, i32 0, i32 16, !dbg !12
%21 = xor i32 %17, %20, !dbg !12
%22 = and i32 %8, 8, !dbg !12
%23 = icmp eq i32 %22, 0, !dbg !12
%24 = select i1 %23, i32 0, i32 32, !dbg !12
%25 = xor i32 %21, %24, !dbg !12
%26 = and i32 %8, 16, !dbg !12
%27 = icmp eq i32 %26, 0, !dbg !12
%28 = select i1 %27, i32 0, i32 64, !dbg !12
%29 = xor i32 %25, %28, !dbg !12
%30 = and i32 %9, 1, !dbg !12
%31 = icmp eq i32 %30, 0, !dbg !12
%32 = select i1 %31, i32 0, i32 128, !dbg !12
%33 = xor i32 %29, %32, !dbg !12
%34 = and i32 %9, 2, !dbg !12
%35 = icmp eq i32 %34, 0, !dbg !12
%36 = select i1 %35, i32 0, i32 256, !dbg !12
%37 = xor i32 %33, %36, !dbg !12
%38 = xor i32 512, %12, !dbg !12
%39 = xor i32 %38, %16, !dbg !12
%40 = xor i32 %39, %20, !dbg !12
%41 = xor i32 %40, %24, !dbg !12
%42 = xor i32 %41, %28, !dbg !12
%43 = xor i32 %42, %32, !dbg !12
%44 = xor i32 %43, %36, !dbg !12
%45 = add i32 %37, 0, !dbg !12
%46 = add i32 %44, 0, !dbg !12
%47 = add i32 %6, %45, !dbg !13
%48 = add i32 %6, %46, !dbg !13
%49 = icmp slt i32 %47, %3, !dbg !14
%50 = icmp slt i32 %48, %3, !dbg !14
%51 = getelementptr float, ptr addrspace(1) %0, i32 %47, !dbg !15
%52 = getelementptr float, ptr addrspace(1) %0, i32 %48, !dbg !15
%53 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %51, i1 %49), !dbg !16
%54 = extractvalue { i32, i32, i32, i32 } %53, 0, !dbg !16
%55 = bitcast i32 %54 to <1 x float>, !dbg !16
%56 = extractvalue { i32, i32, i32, i32 } %53, 1, !dbg !16
%57 = bitcast i32 %56 to <1 x float>, !dbg !16
%58 = extractvalue { i32, i32, i32, i32 } %53, 2, !dbg !16
%59 = bitcast i32 %58 to <1 x float>, !dbg !16
%60 = extractvalue { i32, i32, i32, i32 } %53, 3, !dbg !16
%61 = bitcast i32 %60 to <1 x float>, !dbg !16
%62 = extractelement <1 x float> %55, i32 0, !dbg !16
%63 = extractelement <1 x float> %57, i32 0, !dbg !16
%64 = extractelement <1 x float> %59, i32 0, !dbg !16
%65 = extractelement <1 x float> %61, i32 0, !dbg !16
%66 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %52, i1 %50), !dbg !16
%67 = extractvalue { i32, i32, i32, i32 } %66, 0, !dbg !16
%68 = bitcast i32 %67 to <1 x float>, !dbg !16
%69 = extractvalue { i32, i32, i32, i32 } %66, 1, !dbg !16
%70 = bitcast i32 %69 to <1 x float>, !dbg !16
%71 = extractvalue { i32, i32, i32, i32 } %66, 2, !dbg !16
%72 = bitcast i32 %71 to <1 x float>, !dbg !16
%73 = extractvalue { i32, i32, i32, i32 } %66, 3, !dbg !16
%74 = bitcast i32 %73 to <1 x float>, !dbg !16
%75 = extractelement <1 x float> %68, i32 0, !dbg !16
%76 = extractelement <1 x float> %70, i32 0, !dbg !16
%77 = extractelement <1 x float> %72, i32 0, !dbg !16
%78 = extractelement <1 x float> %74, i32 0, !dbg !16
%79 = getelementptr float, ptr addrspace(1) %1, i32 %47, !dbg !17
%80 = getelementptr float, ptr addrspace(1) %1, i32 %48, !dbg !17
%81 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %79, i1 %49), !dbg !18
%82 = extractvalue { i32, i32, i32, i32 } %81, 0, !dbg !18
%83 = bitcast i32 %82 to <1 x float>, !dbg !18
%84 = extractvalue { i32, i32, i32, i32 } %81, 1, !dbg !18
%85 = bitcast i32 %84 to <1 x float>, !dbg !18
%86 = extractvalue { i32, i32, i32, i32 } %81, 2, !dbg !18
%87 = bitcast i32 %86 to <1 x float>, !dbg !18
%88 = extractvalue { i32, i32, i32, i32 } %81, 3, !dbg !18
%89 = bitcast i32 %88 to <1 x float>, !dbg !18
%90 = extractelement <1 x float> %83, i32 0, !dbg !18
%91 = extractelement <1 x float> %85, i32 0, !dbg !18
%92 = extractelement <1 x float> %87, i32 0, !dbg !18
%93 = extractelement <1 x float> %89, i32 0, !dbg !18
%94 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %80, i1 %50), !dbg !18
%95 = extractvalue { i32, i32, i32, i32 } %94, 0, !dbg !18
%96 = bitcast i32 %95 to <1 x float>, !dbg !18
%97 = extractvalue { i32, i32, i32, i32 } %94, 1, !dbg !18
%98 = bitcast i32 %97 to <1 x float>, !dbg !18
%99 = extractvalue { i32, i32, i32, i32 } %94, 2, !dbg !18
%100 = bitcast i32 %99 to <1 x float>, !dbg !18
%101 = extractvalue { i32, i32, i32, i32 } %94, 3, !dbg !18
%102 = bitcast i32 %101 to <1 x float>, !dbg !18
%103 = extractelement <1 x float> %96, i32 0, !dbg !18
%104 = extractelement <1 x float> %98, i32 0, !dbg !18
%105 = extractelement <1 x float> %100, i32 0, !dbg !18
%106 = extractelement <1 x float> %102, i32 0, !dbg !18
%107 = fadd float %62, %90, !dbg !19
%108 = fadd float %63, %91, !dbg !19
%109 = fadd float %64, %92, !dbg !19
%110 = fadd float %65, %93, !dbg !19
%111 = fadd float %75, %103, !dbg !19
%112 = fadd float %76, %104, !dbg !19
%113 = fadd float %77, %105, !dbg !19
%114 = fadd float %78, %106, !dbg !19
%115 = getelementptr float, ptr addrspace(1) %2, i32 %47, !dbg !20
%116 = getelementptr float, ptr addrspace(1) %2, i32 %48, !dbg !20
%117 = insertelement <1 x float> undef, float %107, i32 0, !dbg !21
%118 = bitcast <1 x float> %117 to i32, !dbg !21
%119 = insertelement <1 x float> undef, float %108, i32 0, !dbg !21
%120 = bitcast <1 x float> %119 to i32, !dbg !21
%121 = insertelement <1 x float> undef, float %109, i32 0, !dbg !21
%122 = bitcast <1 x float> %121 to i32, !dbg !21
%123 = insertelement <1 x float> undef, float %110, i32 0, !dbg !21
%124 = bitcast <1 x float> %123 to i32, !dbg !21
%125 = and i1 true, %49, !dbg !21
call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %118, i32 %120, i32 %122, i32 %124, ptr addrspace(1) %115, i1 %125), !dbg !21
%126 = insertelement <1 x float> undef, float %111, i32 0, !dbg !21
%127 = bitcast <1 x float> %126 to i32, !dbg !21
%128 = insertelement <1 x float> undef, float %112, i32 0, !dbg !21
%129 = bitcast <1 x float> %128 to i32, !dbg !21
%130 = insertelement <1 x float> undef, float %113, i32 0, !dbg !21
%131 = bitcast <1 x float> %130 to i32, !dbg !21
%132 = insertelement <1 x float> undef, float %114, i32 0, !dbg !21
%133 = bitcast <1 x float> %132 to i32, !dbg !21
%134 = and i1 true, %50, !dbg !21
call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %127, i32 %129, i32 %131, i32 %133, ptr addrspace(1) %116, i1 %134), !dbg !21
ret void, !dbg !22
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!nvvm.annotations = !{!4, !5}
!llvm.ident = !{!6}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!3 = !DIFile(filename: "01-vector-add.py", directory: "/home/sasank/code")
!4 = !{ptr @add_kernel, !"kernel", i32 1}
!5 = !{ptr @add_kernel, !"maxntidx", i32 128}
!6 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
!7 = distinct !DISubprogram(name: "add_kernel", linkageName: "add_kernel", scope: !3, file: !3, line: 28, type: !8, scopeLine: 28, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2)
!8 = !DISubroutineType(cc: DW_CC_normal, types: !9)
!9 = !{}
!10 = !DILocation(line: 37, column: 24, scope: !7)
!11 = !DILocation(line: 42, column: 24, scope: !7)
!12 = !DILocation(line: 43, column: 41, scope: !7)
!13 = !DILocation(line: 43, column: 28, scope: !7)
!14 = !DILocation(line: 45, column: 21, scope: !7)
!15 = !DILocation(line: 48, column: 24, scope: !7)
!16 = !DILocation(line: 48, column: 16, scope: !7)
!17 = !DILocation(line: 49, column: 24, scope: !7)
!18 = !DILocation(line: 49, column: 16, scope: !7)
!19 = !DILocation(line: 50, column: 17, scope: !7)
!20 = !DILocation(line: 52, column: 26, scope: !7)
!21 = !DILocation(line: 52, column: 35, scope: !7)
!22 = !DILocation(line: 52, column: 4, scope: !7) |
This doesn't look very nice. We have
|
I checked out version 1.1.1 and hacked around to recover the LLVM IR from the dumped pickle binary.
Triton code:
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
Triton IR:
def void add_kernel(f32* x_ptr .aligned(16) , f32* y_ptr .aligned(16) , f32* output_ptr .aligned(16) , i32 n_elements .multipleof(16) )
{
entry:
%0 = get_program_id(0) i32;
%1 = mul i32 %0, 1024;
%3 = make_range[0 : 1024] i32<1024>;
%4 = splat i32<1024> %1;
%6 = add i32<1024> %4, %3;
%9 = splat i32<1024> n_elements;
%11 = icmp_slt i1<1024> %6, %9;
%14 = splat f32*<1024> x_ptr;
%16 = getelementptr f32*<1024> %14, %6;
%19 = splat f32<1024> undef;
%20 = masked_load f32<1024> %16, %11, %19;
%24 = splat f32*<1024> y_ptr;
%26 = getelementptr f32*<1024> %24, %6;
%29 = splat f32<1024> undef;
%30 = masked_load f32<1024> %26, %11, %29;
%34 = fadd f32<1024> %20, %30;
%37 = splat f32*<1024> output_ptr;
%39 = getelementptr f32*<1024> %37, %6;
masked_store void %39, %34, %11;
ret void;
}
LLVM IR:
; ModuleID = 'add_kernel'
source_filename = "add_kernel"
define void @add_kernel(float addrspace(1)* align 16 %0, float addrspace(1)* align 16 %1, float addrspace(1)* align 16 %2, i32 %3) {
entry:
%4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%5 = urem i32 %4, 32
%6 = udiv i32 %4, 32
%7 = mul i32 %6, 32
%8 = add i32 %7, %5
%9 = mul i32 %8, 4
%idx_0_0 = add i32 %9, 0
%idx_0_1 = add i32 %9, 1
%idx_0_2 = add i32 %9, 2
%idx_0_3 = add i32 %9, 3
%idx_0_4 = add i32 %9, 512
%idx_0_5 = add i32 %9, 513
%idx_0_6 = add i32 %9, 514
%idx_0_7 = add i32 %9, 515
%10 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%11 = mul i32 %10, 1024
%12 = add i32 0, %9
%13 = add i32 %12, 0
%14 = add i32 0, %9
%15 = add i32 %14, 1
%16 = add i32 0, %9
%17 = add i32 %16, 2
%18 = add i32 0, %9
%19 = add i32 %18, 3
%20 = add i32 0, %9
%21 = add i32 %20, 512
%22 = add i32 0, %9
%23 = add i32 %22, 513
%24 = add i32 0, %9
%25 = add i32 %24, 514
%26 = add i32 0, %9
%27 = add i32 %26, 515
%28 = add i32 %11, %12
%29 = add i32 %28, 0
%30 = add i32 %11, %14
%31 = add i32 %30, 1
%32 = add i32 %11, %16
%33 = add i32 %32, 2
%34 = add i32 %11, %18
%35 = add i32 %34, 3
%36 = add i32 %11, %20
%37 = add i32 %36, 512
%38 = add i32 %11, %22
%39 = add i32 %38, 513
%40 = add i32 %11, %24
%41 = add i32 %40, 514
%42 = add i32 %11, %26
%43 = add i32 %42, 515
%44 = icmp slt i32 %29, %3
%45 = icmp slt i32 %31, %3
%46 = icmp slt i32 %33, %3
%47 = icmp slt i32 %35, %3
%48 = icmp slt i32 %37, %3
%49 = icmp slt i32 %39, %3
%50 = icmp slt i32 %41, %3
%51 = icmp slt i32 %43, %3
%52 = getelementptr float, float addrspace(1)* %0, i32 %28
%53 = getelementptr float, float addrspace(1)* %52, i32 0
%54 = getelementptr float, float addrspace(1)* %0, i32 %30
%55 = getelementptr float, float addrspace(1)* %54, i32 1
%56 = getelementptr float, float addrspace(1)* %0, i32 %32
%57 = getelementptr float, float addrspace(1)* %56, i32 2
%58 = getelementptr float, float addrspace(1)* %0, i32 %34
%59 = getelementptr float, float addrspace(1)* %58, i32 3
%60 = getelementptr float, float addrspace(1)* %0, i32 %36
%61 = getelementptr float, float addrspace(1)* %60, i32 512
%62 = getelementptr float, float addrspace(1)* %0, i32 %38
%63 = getelementptr float, float addrspace(1)* %62, i32 513
%64 = getelementptr float, float addrspace(1)* %0, i32 %40
%65 = getelementptr float, float addrspace(1)* %64, i32 514
%66 = getelementptr float, float addrspace(1)* %0, i32 %42
%67 = getelementptr float, float addrspace(1)* %66, i32 515
%68 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %52)
%69 = extractvalue { i32, i32, i32, i32 } %68, 0
%70 = bitcast i32 %69 to <1 x float>
%71 = extractvalue { i32, i32, i32, i32 } %68, 1
%72 = bitcast i32 %71 to <1 x float>
%73 = extractvalue { i32, i32, i32, i32 } %68, 2
%74 = bitcast i32 %73 to <1 x float>
%75 = extractvalue { i32, i32, i32, i32 } %68, 3
%76 = bitcast i32 %75 to <1 x float>
%77 = extractelement <1 x float> %70, i64 0
%78 = extractelement <1 x float> %72, i64 0
%79 = extractelement <1 x float> %74, i64 0
%80 = extractelement <1 x float> %76, i64 0
%81 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %60)
%82 = extractvalue { i32, i32, i32, i32 } %81, 0
%83 = bitcast i32 %82 to <1 x float>
%84 = extractvalue { i32, i32, i32, i32 } %81, 1
%85 = bitcast i32 %84 to <1 x float>
%86 = extractvalue { i32, i32, i32, i32 } %81, 2
%87 = bitcast i32 %86 to <1 x float>
%88 = extractvalue { i32, i32, i32, i32 } %81, 3
%89 = bitcast i32 %88 to <1 x float>
%90 = extractelement <1 x float> %83, i64 0
%91 = extractelement <1 x float> %85, i64 0
%92 = extractelement <1 x float> %87, i64 0
%93 = extractelement <1 x float> %89, i64 0
%94 = getelementptr float, float addrspace(1)* %1, i32 %28
%95 = getelementptr float, float addrspace(1)* %94, i32 0
%96 = getelementptr float, float addrspace(1)* %1, i32 %30
%97 = getelementptr float, float addrspace(1)* %96, i32 1
%98 = getelementptr float, float addrspace(1)* %1, i32 %32
%99 = getelementptr float, float addrspace(1)* %98, i32 2
%100 = getelementptr float, float addrspace(1)* %1, i32 %34
%101 = getelementptr float, float addrspace(1)* %100, i32 3
%102 = getelementptr float, float addrspace(1)* %1, i32 %36
%103 = getelementptr float, float addrspace(1)* %102, i32 512
%104 = getelementptr float, float addrspace(1)* %1, i32 %38
%105 = getelementptr float, float addrspace(1)* %104, i32 513
%106 = getelementptr float, float addrspace(1)* %1, i32 %40
%107 = getelementptr float, float addrspace(1)* %106, i32 514
%108 = getelementptr float, float addrspace(1)* %1, i32 %42
%109 = getelementptr float, float addrspace(1)* %108, i32 515
%110 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %94)
%111 = extractvalue { i32, i32, i32, i32 } %110, 0
%112 = bitcast i32 %111 to <1 x float>
%113 = extractvalue { i32, i32, i32, i32 } %110, 1
%114 = bitcast i32 %113 to <1 x float>
%115 = extractvalue { i32, i32, i32, i32 } %110, 2
%116 = bitcast i32 %115 to <1 x float>
%117 = extractvalue { i32, i32, i32, i32 } %110, 3
%118 = bitcast i32 %117 to <1 x float>
%119 = extractelement <1 x float> %112, i64 0
%120 = extractelement <1 x float> %114, i64 0
%121 = extractelement <1 x float> %116, i64 0
%122 = extractelement <1 x float> %118, i64 0
%123 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %102)
%124 = extractvalue { i32, i32, i32, i32 } %123, 0
%125 = bitcast i32 %124 to <1 x float>
%126 = extractvalue { i32, i32, i32, i32 } %123, 1
%127 = bitcast i32 %126 to <1 x float>
%128 = extractvalue { i32, i32, i32, i32 } %123, 2
%129 = bitcast i32 %128 to <1 x float>
%130 = extractvalue { i32, i32, i32, i32 } %123, 3
%131 = bitcast i32 %130 to <1 x float>
%132 = extractelement <1 x float> %125, i64 0
%133 = extractelement <1 x float> %127, i64 0
%134 = extractelement <1 x float> %129, i64 0
%135 = extractelement <1 x float> %131, i64 0
%136 = fadd float %77, %119
%137 = fadd float %78, %120
%138 = fadd float %79, %121
%139 = fadd float %80, %122
%140 = fadd float %90, %132
%141 = fadd float %91, %133
%142 = fadd float %92, %134
%143 = fadd float %93, %135
%144 = getelementptr float, float addrspace(1)* %2, i32 %28
%145 = getelementptr float, float addrspace(1)* %144, i32 0
%146 = getelementptr float, float addrspace(1)* %2, i32 %30
%147 = getelementptr float, float addrspace(1)* %146, i32 1
%148 = getelementptr float, float addrspace(1)* %2, i32 %32
%149 = getelementptr float, float addrspace(1)* %148, i32 2
%150 = getelementptr float, float addrspace(1)* %2, i32 %34
%151 = getelementptr float, float addrspace(1)* %150, i32 3
%152 = getelementptr float, float addrspace(1)* %2, i32 %36
%153 = getelementptr float, float addrspace(1)* %152, i32 512
%154 = getelementptr float, float addrspace(1)* %2, i32 %38
%155 = getelementptr float, float addrspace(1)* %154, i32 513
%156 = getelementptr float, float addrspace(1)* %2, i32 %40
%157 = getelementptr float, float addrspace(1)* %156, i32 514
%158 = getelementptr float, float addrspace(1)* %2, i32 %42
%159 = getelementptr float, float addrspace(1)* %158, i32 515
%160 = bitcast float addrspace(1)* %145 to <4 x float> addrspace(1)*
%161 = insertelement <4 x float> undef, float %136, i64 0
%162 = insertelement <4 x float> %161, float %137, i64 1
%163 = insertelement <4 x float> %162, float %138, i64 2
%164 = insertelement <4 x float> %163, float %139, i64 3
br i1 %44, label %165, label %166
165: ; preds = %entry
store <4 x float> %164, <4 x float> addrspace(1)* %160, align 16
br label %166
166: ; preds = %entry, %165
%167 = bitcast float addrspace(1)* %153 to <4 x float> addrspace(1)*
%168 = insertelement <4 x float> undef, float %140, i64 0
%169 = insertelement <4 x float> %168, float %141, i64 1
%170 = insertelement <4 x float> %169, float %142, i64 2
%171 = insertelement <4 x float> %170, float %143, i64 3
br i1 %48, label %172, label %173
172: ; preds = %166
store <4 x float> %171, <4 x float> addrspace(1)* %167, align 16
br label %173
173: ; preds = %166, %172
call void @llvm.donothing()
call void @llvm.donothing()
ret void
}
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
; Function Attrs: nounwind readnone willreturn
declare void @llvm.donothing() #1
attributes #0 = { nounwind readnone }
attributes #1 = { nounwind readnone willreturn }
!nvvm.annotations = !{!0, !1}
!0 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"kernel", i32 1}
!1 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"maxntidx", i32 128}
PTX
|
I don't like seeing inline assembly here :(. Why do we need inline assembly? Can't LLVM intrinsics get us there?
Had a quick look at the IR. My observations are the following:
|
Thanks a lot for your analysis. I wonder why these inline assembly instructions are used instead of LLVM intrinsics.
For kernel programming, we should be able to create a Triton backend for NVIDIA GPUs, etc. The paper can be found here: https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf. The implementation is here: https://github.com/triton-lang/triton
The text was updated successfully, but these errors were encountered: