Skip to content

Commit 8e2d7e4

Browse files
authored
Merge pull request #25 from JuliaSIMD/liballocsbutiamgivingup
trying to cut down heisen allocations to no avail
2 parents 7fa5823 + f2ce9ef commit 8e2d7e4

File tree

1 file changed

+23
-159
lines changed

1 file changed

+23
-159
lines changed

src/TriangularSolve.jl

+23-159
Original file line numberDiff line numberDiff line change
@@ -31,89 +31,13 @@ using Polyester
3131
end
3232
end
3333

34-
# @generated function nmuladd(A::VecUnroll{Nm1},B::AbstractStridedPointer,C::VecUnroll{Nm1}) where {Nm1}
35-
# N = Nm1 + 1
36-
# quote
37-
# $(Expr(:meta,:inline))
38-
# Ad = VectorizationBase.data(A);
39-
# Cd = VectorizationBase.data(C);
40-
# bp = stridedpointer(B)
41-
# Base.Cartesian.@nexprs $N n -> C_n = Cd[n]
42-
# Base.Cartesian.@nexprs $N k -> begin
43-
# A_k = Ad[k]
44-
# Base.Cartesian.@nexprs $N n -> begin
45-
# C_n = Base.FastMath.sub_fast(C_n, Base.FastMath.mul_fast(A_k, vload(B, (k-1,n-1))))
46-
# end
47-
# end
48-
# VecUnroll(Base.Cartesian.@ntuple $N C)
49-
# end
50-
# end
51-
52-
# @inline function solve_Wx3W(A11::V, A12::V, A13::V, U::AbstractMatrix, ::StaticInt{W}) where {V<:VecUnroll,W}
53-
# WS = StaticInt{W}()
54-
55-
# U11 = view(U,StaticInt(1):WS,StaticInt(1):WS)
56-
# A11 = solve_AU(A11, U11)
57-
58-
# U12 = view(U,StaticInt(1):WS, StaticInt(1)+WS:WS*StaticInt(2))
59-
# A12 = nmuladd(A11, U12, A12)
60-
# U22 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS:WS*StaticInt(2))
61-
# A12 = solve_AU(A12, U22)
62-
63-
# U13 = view(U,StaticInt(1):WS, StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
64-
# A13 = nmuladd(A11, U13, A13)
65-
# U23 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
66-
# A13 = nmuladd(A12, U23, A13)
67-
# U33 = view(U,StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
68-
# A13 = solve_AU(A13, U33)
69-
70-
# return A11, A12, A13
71-
# end
72-
73-
# @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset) where {T}
74-
# WS = VectorizationBase.pick_vector_width(T)
75-
# W = Int(WS)
76-
# A11 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
77-
# A12 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
78-
# A13 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
79-
80-
# A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
81-
82-
# vstore!(ap, A11, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
83-
# vstore!(ap, A12, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
84-
# vstore!(ap, A13, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
85-
# end
86-
# @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset, m::VectorizationBase.AbstractMask) where {T}
87-
# WS = VectorizationBase.pick_vector_width(T)
88-
# W = Int(WS)
89-
# A11 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
90-
# A12 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
91-
# A13 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
92-
93-
# A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
94-
95-
# vstore!(ap, A11, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
96-
# vstore!(ap, A12, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
97-
# vstore!(ap, A13, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
98-
# end
99-
100-
# solve_3Wx3W!(A,B,U::UpperTriangular) = solve_3Wx3W!(A,B,parent(U))
101-
# function solve_3Wx3W!(A::AbstractMatrix{T},B,U) where {T}
102-
# W = VectorizationBase.pick_vector_width(T)
103-
# ap = stridedpointer(A);
104-
# bp = stridedpointer(B);
105-
# solve_Wx3W!(ap, bp, U, StaticInt(1), StaticInt(1))
106-
# solve_Wx3W!(ap, bp, U, StaticInt(1) + W, StaticInt(1))
107-
# solve_Wx3W!(ap, bp, U, StaticInt(1) + W + W, StaticInt(1))
108-
# end
109-
11034
@inline maybestore!(p, v, i) = vstore!(p, v, i)
11135
@inline maybestore!(::Nothing, v, i) = nothing
11236

11337
@inline maybestore!(p, v, i, m) = vstore!(p, v, i, m)
11438
@inline maybestore!(::Nothing, v, i, m) = nothing
11539

116-
@inline function store_small_kern!(spa, sp, v, spu, i, n, mask, ::Val{true})
40+
@inline function store_small_kern!(spa, sp, v, _, i, n, mask, ::Val{true})
11741
vstore!(spa, v, i, mask)
11842
vstore!(sp, v, i, mask)
11943
end
@@ -160,46 +84,6 @@ end
16084
store_small_kern!(spa, sp, Amn, spu, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0),n)), n, Val{UNIT}())
16185
end
16286
end
163-
# function BdivU_small!(A::AbstractMatrix{T}, B::AbstractMatrix{T}, U::AbstractMatrix{T}) where {T}
164-
# W = VectorizationBase.pick_vector_width(T)
165-
# M, N = size(A)
166-
# m = 0
167-
# spa = stridedpointer(A)
168-
# spb = stridedpointer(B)
169-
# spu = stridedpointer(U)
170-
# while m < M
171-
# ml = m+1
172-
# mu = m+W
173-
# maskiter = mu > M
174-
# mask = maskiter ? VectorizationBase.mask(W, M) : VectorizationBase.max_mask(W)
175-
# for n ∈ 1:N
176-
# Amn = vload(spb, (MM(W, ml),n), mask)
177-
# for k ∈ 1:n-1
178-
# Amn = vfnmadd_fast(vload(spa, (MM(W, ml),k), mask), vload(spu, (k,n)), Amn)
179-
# end
180-
# vstore!(spa, Amn / vload(spu, (n,n)), (MM(W, ml),n), mask)
181-
# end
182-
# m = mu
183-
# end
184-
# # @inbounds @fastmath for m ∈ 1:M
185-
# # for n ∈ 1:N
186-
# # Amn = B[m,n]
187-
# # for k ∈ 1:n-1
188-
# # Amn -= A[m,k]*U[k,n]
189-
# # end
190-
# # A[m,n] = Amn / U[n,n]
191-
# # end
192-
# # end
193-
# end
194-
# function nmuladd!(C,A,B,D)
195-
# @turbo for n ∈ axes(C,2), m ∈ axes(C,1)
196-
# Cmn = D[m,n]
197-
# for k ∈ axes(B,1)
198-
# Cmn -= A[m,k]*B[k,n]
199-
# end
200-
# C[m,n] = Cmn
201-
# end
202-
# end
20387

20488
@generated function rdiv_solve_W_u!(spc, spb, spa, spu, n, ::StaticInt{W}, ::StaticInt{U}, ::Val{UNIT}) where {W, U, UNIT}
20589
quote
@@ -286,7 +170,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
286170
buff = LDIVBUFFERS[Threads.threadid()]
287171
RSUF = StaticInt{UF}()*VectorizationBase.register_size()
288172
L = RSUF*N
289-
L > length(buff) && resize!(buff, L)
173+
L > length(buff) && resize!(buff, L%UInt)
290174
ptr = Base.unsafe_convert(Ptr{T}, buff)
291175
si = StrideIndex{2,(1,2),1}((VectorizationBase.static_sizeof(T), RSUF), (StaticInt(0),StaticInt(0)))
292176
stridedpointer(ptr, si, StaticInt{0}())
@@ -412,24 +296,14 @@ function rdiv_block_N!(
412296
N_temp = Core.ifelse(repeat, B_normalized, N)
413297
while true
414298
# println("Solve with N_temp = $N_temp and n = $n")
415-
rdiv_U!(spc, spa_rdiv, gesp(spu, (n,StaticInt{0}())), M, N_temp, StaticInt{X}(), Val(UNIT))
299+
rdiv_U!(spc, spa_rdiv, gesp(spu, (n,StaticInt{0}())), M, N_temp, StaticInt{X}(), Val{UNIT}())
416300
repeat || break
417301
spa = gesp(spa, (StaticInt(0), B_normalized))
418302
spc = gesp(spc, (StaticInt(0), B_normalized))
419303
spu = gesp(spu, (StaticInt(0), B_normalized))
420-
nnext = n + B_normalized
421-
# N_temp =
422304
n += B_normalized
423305
repeat = n + B_normalized < N
424306
N_temp = repeat ? N_temp : N - n
425-
# N_temp = min(n + B_normalized, N) - n
426-
# println("nmuladd with N_temp = $N_temp and n = $n")
427-
# mul!(
428-
# copyto!(view(C, :, n+1:n+N_temp), view(A, :, n+1:n+N_temp)),
429-
# view(C, :, 1:n),
430-
# view(U, 1:n, n+1:n+N_temp),
431-
# -1.0, 1.0
432-
# )
433307
nmuladd!(spc_base, spa, spu, M, n, N_temp)
434308
spa_rdiv = spc
435309
end
@@ -439,15 +313,14 @@ function rdiv_block_MandN!(
439313
) where {T,UNIT,X}
440314
B = block_size(Val(T))
441315
W = VectorizationBase.pick_vector_width(T)
442-
B_normalized = VectorizationBase.vcld(N, VectorizationBase.vcld(N, B)*W)*W
443316
WUF = W*unroll_factor(W)
444317
B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B)*WUF)*WUF
445318
m = 0
446319
while m < M
447320
mu = m + B_m
448321
Mtemp = min(M, mu) - m
449322
rdiv_block_N!(
450-
spc, spa, spu, Mtemp, N, Val(UNIT), StaticInt{X}(),
323+
spc, spa, spu, Mtemp, N, Val{UNIT}(), StaticInt{X}(),
451324
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B)*W)*W
452325
)
453326
spa = gesp(spa, (B_m, StaticInt{0}()))
@@ -458,42 +331,33 @@ function rdiv_block_MandN!(
458331
end
459332
function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
460333
W = VectorizationBase.pick_vector_width(T)
461-
WUF = W * unroll_factor(W)
462334
nb = clamp(VectorizationBase.vdiv(M * N, StaticInt{256}() * W), 1, nthreads)
463335
min(M, VectorizationBase.vcld(M, nb*W)*W)
464336
end
465337

338+
struct RDivBlockMandNv2{UNIT,X} end
339+
function (f::RDivBlockMandNv2{UNIT,X})(allargs, blockstart, blockstop) where {UNIT,X}
340+
spc, spa, spu, N, Mrem, Nblock, mtb = allargs
341+
for block = blockstart-1:blockstop-1
342+
rdiv_block_MandN!(
343+
gesp(spc, (mtb*block, StaticInt{0}())),
344+
gesp(spa, (mtb*block, StaticInt{0}())),
345+
spu, Core.ifelse(block == Nblock-1, Mrem, mtb), N, Val{UNIT}(), static(X)
346+
)
347+
end
348+
end
349+
350+
466351
function multithread_rdiv!(
467-
spc::AbstractStridedPointer{T}, spa, spu, M, N, mtb, ::Val{UNIT}, ::StaticInt{X}
468-
) where {X,T,UNIT}
469-
mtb = 8
352+
spc::AbstractStridedPointer{TC}, spa::AbstractStridedPointer{TA}, spu::AbstractStridedPointer{TU}, M::Int, N::Int, mtb::Int, ::Val{UNIT}, ::StaticInt{X}
353+
) where {X,UNIT,TC,TA,TU}
354+
# Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
470355
(Md, Mr) = VectorizationBase.vdivrem(M, mtb)
471356
Nblock = Md + (Mr 0)
472357
Mrem = Core.ifelse(Mr 0, Mr, mtb)
473-
# @show mtb, Nblock, Mrem, Md, Mr
474-
# return
475-
let Md = Md, Mr = Mr, Nblock = Md + (Mr 0), Mrem = Core.ifelse(Mr 0, Mr, mtb), VUNIT = Val{UNIT}(), StaticX = StaticInt{X}()
476-
@batch for block in CloseOpen(Nblock)
477-
# for block in CloseOpen(Nblock)
478-
# let block = 0
479-
rdiv_block_MandN!(
480-
# rdiv_block_N!(
481-
gesp(spc, (mtb*block, StaticInt{0}())),
482-
gesp(spa, (mtb*block, StaticInt{0}())),
483-
spu, Core.ifelse(block == Nblock-1, Mrem, mtb), N, VUNIT, StaticX
484-
# spu, M, N, Val{UNIT}(), StaticInt{X}()
485-
)
486-
end
487-
end
358+
f = RDivBlockMandNv2{UNIT,X}()
359+
batch(f, (Nblock,min(Nblock,Threads.nthreads())), spc, spa, spu, N, Mrem, Nblock, mtb)
488360
nothing
489-
# nlaunch = Md - (Mr == 0)
490-
# threads, torelease = Polyester.request_threads(Base.Threads.threadid(), nlaunch)
491-
# nthread = length(threads)
492-
# if (nthread % Int32) ≤ zero(Int32)
493-
# return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT), StaticInt{X}())
494-
# end
495-
# nbatch = nthread + one(nthread)
496-
497361
end
498362

499363
# We're using `W x W` blocks, consuming `W` registers
@@ -521,7 +385,7 @@ function rdiv_U!(spc::AbstractStridedPointer{T}, spa::AbstractStridedPointer, sp
521385
if n > 0
522386
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT))
523387
end
524-
for i 1:Nd
388+
for _ 1:Nd
525389
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
526390
n += W
527391
end

0 commit comments

Comments
 (0)