Fork of NNop.jl for registration in MurrellGroupRegistry (and possibly experimentation)

MurrellGroup/NNop.jl

NNop.jl

GPU backends tested in CI: AMDGPU, CUDA.

Kernels (with ChainRules.jl integration):

  • Flash Attention
  • Fused Softmax
  • Fused RMS Norm
  • Fused Layer Norm

Benchmarking

See benchmarks/main.jl for scripts that compare the naïve and fused versions.

Flash Attention

Implementation of the algorithm from the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness".

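using AMDGPU, NNop, Zygote

# Embedding size, sequence length, number of heads, batch size.
E, L, H, B = 64, 4096, 4, 4
causal = false

q = ROCArray(rand(Float32, E, L, H, B))
k = ROCArray(rand(Float32, E, L, H, B))
v = ROCArray(rand(Float32, E, L, H, B))

# Forward pass.
o = NNop.flash_attention(q, k, v; causal)
# Forward + backward pass: gradients with respect to q, k and v.
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(NNop.flash_attention(q, k, v; causal))
end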

| Pass | Metric | Naïve attention | Flash Attention |
|---|---|---|---|
| FWD | Execution time | 60.987 ms | 18.380 ms |
| FWD | Peak memory usage | 5.044 GiB | 16.500 MiB |
| FWD + BWD | Execution time | 1.154 s | 306.960 ms |
| FWD + BWD | Peak memory usage | 19.164 GiB | 80.813 MiB |

Features:

  • Forward & backward passes.
  • Arbitrary sequence length.
  • FP32, FP16, BF16 support.
  • Variable sequence length.
  • Causal masking.
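
For comparison, the naïve baseline from the table can be sketched in plain Julia on the CPU. This is an illustration only, not the code in benchmarks/main.jl: the name `naive_attention` is hypothetical, and the conventions (one E×L slice per head and batch element, softmax over keys, scaling by 1/√E) are assumptions. The sketch materialises the full L×L score matrix. For L = 4096 in Float32 that is 64 MiB per head and batch element, or 1 GiB across H × B = 16, which is the kind of intermediate a fused kernel avoids.

```julia
# Reference (non-fused) attention for one head of one batch element.
# Assumption: `q`, `k`, `v` are E×L matrices (embedding × sequence),
# i.e. one `[:, :, h, b]` slice of the 4-D arrays used above.
function naive_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix;
                         causal::Bool=false)
    E, L = size(q)
    scores = (k' * q) ./ sqrt(Float32(E))  # L×L; scores[i, j] = kᵢ ⋅ qⱼ / √E
    if causal
        for j in 1:L, i in (j + 1):L
            scores[i, j] = -Inf32          # query j must not see keys i > j
        end
    end
    scores .-= maximum(scores; dims=1)     # stabilise before exponentiating
    p = exp.(scores)
    p ./= sum(p; dims=1)                   # softmax over keys: columns sum to 1
    return v * p                           # E×L output
end
```

With `causal = true` the first query can only attend to the first key, so the first output column equals the first column of `v`, which makes a convenient sanity check.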

Fused Softmax

Implementation of the algorithm from the paper "Online normalizer calculation for softmax".

x = ROCArray(rand(Float32, 8192, 1024))
y = NNop.online_softmax(x)

| Metric | Naïve Softmax | Online Softmax |
|---|---|---|
| Execution time | 745.123 μs | 61.600 μs |
| Peak memory usage | 64.258 MiB | 32.000 MiB |
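
The idea behind the online normaliser can be sketched on the CPU for a single vector. This is a minimal illustration, not the GPU kernel: the name `online_softmax_ref` is hypothetical, and it assumes finite inputs. Instead of one pass for the maximum and another for the sum, it keeps a running maximum `m` and a running sum `d`, rescaling `d` whenever the maximum changes.

```julia
# Softmax of a vector, computing the normaliser in a single pass.
function online_softmax_ref(x::AbstractVector{T}) where {T<:AbstractFloat}
    m = typemin(T)  # running maximum, starts at -Inf
    d = zero(T)     # running sum of exp(xᵢ - m)
    for xi in x
        m_new = max(m, xi)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * exp(m - m_new) + exp(xi - m_new)
        m = m_new
    end
    return exp.(x .- m) ./ d
end
```

The result should agree with the usual two-pass formulation `exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x)))` up to floating-point rounding.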

Fused RMS Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))
y = NNop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(NNop.rms_norm(x, w))
end

| Pass | Metric | Naïve RMS Norm | Fused RMS Norm |
|---|---|---|---|
| FWD | Execution time | 171.124 μs | 48.432 μs |
| FWD | Peak memory usage | 8.004 MiB | 4.004 MiB |
| FWD + BWD | Execution time | 902.919 μs | 241.838 μs |
| FWD + BWD | Peak memory usage | 44.043 MiB | 13.008 MiB |
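
As a point of reference, an unfused RMS norm can be written in a few lines of plain Julia. This is a sketch under assumptions, not the package's implementation: the name `rms_norm_ref` is hypothetical, the normalisation is assumed to run along the first dimension, and the `ϵ` default is an illustrative choice.

```julia
using Statistics: mean

# y = x / sqrt(mean(x²) + ϵ) * w, computed independently for each column.
# Assumptions: features live in the first dimension; ϵ is a hypothetical default.
function rms_norm_ref(x::AbstractMatrix, w::AbstractVector; ϵ=1f-6)
    rms = sqrt.(mean(abs2, x; dims=1) .+ ϵ)  # 1×N, one RMS value per column
    return x ./ rms .* w                      # scale each feature by its weight
end
```

Each broadcasted step here allocates a temporary array, which is one plausible reason an unfused version uses more memory than a fused kernel.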

Fused Layer Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))  # scale
b = ROCArray(rand(Float32, 1024))  # bias
y = NNop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(NNop.layer_norm(x, w, b))
end

| Pass | Metric | Naïve Layer Norm | Fused Layer Norm |
|---|---|---|---|
| FWD | Execution time | 188.392 μs | 48.175 μs |
| FWD | Peak memory usage | 4.008 MiB | 4.004 MiB |
| FWD + BWD | Execution time | 1.150 ms | 293.969 μs |
| FWD + BWD | Peak memory usage | 52.055 MiB | 14.016 MiB |
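
An unfused layer norm can likewise be sketched in plain Julia for checking results on the CPU. The name `layer_norm_ref`, the choice of the first dimension as the feature dimension, and the `ϵ` default are all assumptions made for illustration, not the package's implementation.

```julia
using Statistics: mean

# y = (x - μ) / sqrt(σ² + ϵ) * w + b, computed independently for each column.
# Assumptions: features live in the first dimension; ϵ is a hypothetical default.
function layer_norm_ref(x::AbstractMatrix, w::AbstractVector, b::AbstractVector;
                        ϵ=1f-6)
    μ = mean(x; dims=1)                 # 1×N column means
    σ² = mean(abs2, x .- μ; dims=1)     # 1×N biased (population) variances
    return (x .- μ) ./ sqrt.(σ² .+ ϵ) .* w .+ b
end
```

With unit scale and zero bias, each output column should have mean close to 0 and variance close to 1.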
