Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Compute Optimal Block Size #9

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
julia 1.0
SIMD
Hwloc
6 changes: 4 additions & 2 deletions bench/bench1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.8
# Time both implementations for every problem size in `mnks`, storing each
# result at the matching index of `jbtimes` (addmul!) and `obtimes` (mul!).
for (i,siz) in enumerate(mnks)
# Three separate siz-by-siz zero matrices; the generator's `i` is local to the
# generator and does not clobber the loop index.
A,B,C = (zeros(siz,siz) for i in 1:3)
# The fourth argument is the `generic` flag of the `Block` constructor.
blk = Block(A,B,C,false)
# One-second pause before each measurement — presumably to let the machine
# settle between runs; TODO confirm the intent.
sleep(1)
jbtimes[i] = @belapsed addmul!($C,$A,$B,$blk)
sleep(1)
obtimes[i] = @belapsed mul!($C,$A,$B)
end

Expand All @@ -19,5 +21,5 @@ time2gflops(mnk, time) = 2 * mnk^3 / time / 10^9
jflops = time2gflops.(mnks, jbtimes)
oflops = time2gflops.(mnks, obtimes)
plot(mnks, jflops, lab="JuliaBLAS")
plot!(mnks, oflops, lab="OpenBLAS", ylabel="GFLOPS", xlabel="M=N=K", legend=:bottomright, dpi=400, ylims=(0,40), yticks=0:5:40)
savefig("bench1.png")
plot!(mnks, oflops, lab="OpenBLAS", ylabel="GFLOPS", xlabel="M=N=K", legend=:bottomright, dpi=400, ylims=(0,60), yticks=0:5:60)
savefig("bench8.png")
5 changes: 5 additions & 0 deletions src/JuliaBLAS.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module JuliaBLAS

using SIMD
import Base.Cartesian: @nexprs
import Hwloc

include("tune.jl")
include("gemm.jl")

export addmul!
Expand Down
21 changes: 12 additions & 9 deletions src/gemm.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using SIMD
import Base.Cartesian: @nexprs

struct Block{T1,T2,T3,T4,G}
Ac::T1
Expand All @@ -19,17 +17,22 @@ struct Block{T1,T2,T3,T4,G}
inc2C::Int
end

const Ac = Vector{UInt8}(undef, 110592)
const Bc = Vector{UInt8}(undef, 6266880)
const AB = Vector{UInt8}(undef, 12*4*8)

# Blocking parameters (mc, kc, nc, mr, nr), computed once when this file is loaded.
const block_attr = get_hw_params()
# Prints the selected parameters every time the file is loaded.
@show block_attr

# Raw byte buffers whose sizes are derived from the blocking parameters.
# NOTE(review): the factor 8 is presumably the byte size of a Float64 element —
# confirm before using element types wider than 8 bytes.
const Ac = Vector{UInt8}(undef, block_attr.mc * block_attr.kc * 8)
const Bc = Vector{UInt8}(undef, block_attr.kc * block_attr.nc * 8)
const AB = Vector{UInt8}(undef, block_attr.nr * block_attr.mr * 8)

function Block(A::X, B::W, C::Z, generic) where {X, W, Z}
global Ac, Bc, AB
mr=12; nr=4
global block_attr
mr=block_attr.mr; nr=block_attr.nr
m, n = size(C)
mc = 72
kc = 192
nc = 4080
mc = block_attr.mc
kc = block_attr.kc
nc = block_attr.nc
T = promote_type(eltype(X), eltype(W), eltype(Z))
siz = sizeof(T)
_Ac = unsafe_wrap(Array, Ptr{T}(pointer(Ac)), length(Ac)÷siz)
Expand Down
62 changes: 62 additions & 0 deletions src/tune.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

"""
    BlockAttributes(mc, kc, nc, mr, nr)

Immutable bundle of the five integer blocking parameters produced by
`get_hw_params`. The fields are positional, in the order shown above.
"""
struct BlockAttributes
    # Block sizes (`get_hw_params` currently fills these with fixed defaults).
    mc::Int
    kc::Int
    nc::Int
    # Tile sizes (`get_hw_params` derives these from vector width and FMA latency).
    mr::Int
    nr::Int
end

"""
    get_hw_params()

Load the machine topology through Hwloc and return a `BlockAttributes` holding
the blocking parameters.

Only `mr` and `nr` are computed, from the hard-coded vector width and FMA
figures below. `mc`, `kc` and `nc` are still the fixed defaults, and the cache
information gathered here does not influence the returned value yet.

Throws if a cache level cannot be located, because `findcache` then returns
`nothing` and the `.attr` access fails.
"""
function get_hw_params()
    topology = Hwloc.topology_load()

    # Only the first cache found per level is inspected; all caches of one
    # level are assumed to share the same attributes.
    l3 = findcache(topology, :L3Cache).attr
    l2 = findcache(topology, :L2Cache).attr
    l2ct = countcache(topology, :L2Cache)
    l1 = findcache(topology, :L1Cache).attr
    # Count L1 caches here (this previously counted :L2Cache by mistake).
    l1ct = countcache(topology, :L1Cache)

    # Defaults, used until they are derived from the cache parameters.
    mc = 72
    kc = 192
    nc = 4080

    # Total number of L1 cache lines across all L1 caches.
    wl1 = div(l1.size * l1ct, l1.linesize)

    # add + mul latency or FMA latency
    lvfma = 5 # TODO: detect automatically?
    nvec = 4
    sdata = 32 # AVX
    nvfma = 1

    # mr is rounded up to a multiple of the vector width `nvec`.
    mr = ceil(Int, sqrt(nvec * lvfma * nvfma) / nvec) * nvec
    nr = ceil(Int, (nvec * lvfma * nvfma) / mr)

    # Intermediate values for the pending kc derivation; not used in the result yet.
    car = floor(Int, (wl1 - 1) / (1 + nr / mr))
    cbr = ceil(Int, nr * car / mr)

    # TODO: kc = car*l1.size*l1.linesize/(mr*sdata)

    return BlockAttributes(mc, kc, nc, mr, nr)
end

"""
    findcache(v, sym)

Search the descendants of `v` depth-first and return the first node whose
`type_` equals `sym`, or `nothing` when no descendant matches.

Every node is expected to expose a `children` collection and a `type_` field.
`v` itself is never tested, only its descendants.
"""
function findcache(v, sym)
    for c in v.children
        c.type_ == sym && return c
        # Keep scanning the remaining siblings when this subtree has no match,
        # instead of giving up after the first child.
        found = findcache(c, sym)
        found === nothing || return found
    end
    return nothing
end

"""
    countcache(v, sym)

Return how many descendants of `v` have a `type_` equal to `sym`. The whole
subtree is walked recursively; `v` itself is not counted.
"""
function countcache(v, sym)
    return mapreduce(+, v.children; init = 0) do child
        # One for the child itself when it matches, plus all matches beneath it.
        (child.type_ == sym ? 1 : 0) + countcache(child, sym)
    end
end