Skip to content

Commit ef56b92

Browse files
committed
try to avoid GC problems
1 parent 06d197f commit ef56b92

File tree

2 files changed

+88
-51
lines changed

2 files changed

+88
-51
lines changed

Project.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TriangularSolve"
22
uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
33
authors = ["chriselrod <[email protected]> and contributors"]
4-
version = "0.1.20"
4+
version = "0.1.21"
55

66
[deps]
77
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"
@@ -14,15 +14,18 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1414
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515

1616
[compat]
17+
Aqua = "0.8"
18+
CPUSummary = "0.2"
1719
CloseOpenIntervals = "0.1"
1820
IfElse = "0.1"
1921
LayoutPointers = "0.1.2"
2022
LinearAlgebra = "1"
2123
LoopVectorization = "0.12.30"
2224
Polyester = "0.4, 0.5, 0.6, 0.7"
2325
Static = "0.2, 0.3, 0.4, 0.6, 0.7, 0.8"
26+
Test = "1"
2427
VectorizationBase = "0.21"
25-
julia = "1.5"
28+
julia = "1.10"
2629

2730
[extras]
2831
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

src/TriangularSolve.jl

+83-49
Original file line numberDiff line numberDiff line change
@@ -261,19 +261,59 @@ end
261261
nothing
262262
end
263263

264-
const LDIVBUFFERS = Vector{UInt8}[]
265-
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}
266-
buff = LDIVBUFFERS[Threads.threadid()]
267-
RSUF = StaticInt{UF}() * VectorizationBase.register_size()
264+
const buffer = Ref{Ptr{Cvoid}}(C_NULL)
265+
266+
function __init__()
267+
bp_size = 2 * sizeof(Int) * Threads.nthreads()
268+
buffer[] = bp = Libc.malloc(bp_size)
269+
Libc.memset(bp, 0, bp_size)
270+
end
271+
272+
function _get_buffer_pointer(::StaticInt{UF}, N) where {UF}
273+
RS = VectorizationBase.register_size()
274+
RSUF = StaticInt{UF}() * RS
268275
L = RSUF * N
269-
L > length(buff) && resize!(buff, L % UInt)
270-
ptr = Base.unsafe_convert(Ptr{T}, pointer(buff))
276+
tid = Threads.threadid() - 1
277+
bp = Ptr{Pair{Ptr{Cvoid},Int}}(buffer[]) + 2sizeof(Int) * tid
278+
(p, buff_current) = unsafe_load(bp)
279+
if buff_current < L
280+
p == C_NULL || Libc.free(p)
281+
buff_size = max(RSUF * 128, L)
282+
p = Libc.malloc(Int(buff_size + RS - 1))
283+
unsafe_store!(bp, p => buff_size)
284+
end
285+
return VectorizationBase.align(p, RS)
286+
end
287+
288+
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}
289+
RS = VectorizationBase.register_size()
290+
RSUF = StaticInt{UF}() * RS
291+
ptr = Ptr{T}(_get_buffer_pointer(StaticInt{UF}(), N))
271292
si = StrideIndex{2,(1, 2),1}(
272293
(VectorizationBase.static_sizeof(T), RSUF),
273294
(StaticInt(0), StaticInt(0))
274295
)
275-
stridedpointer(ptr, si, StaticInt{0}())
296+
stridedpointer(ptr, si, StaticInt{0}()), nothing
297+
end
298+
@inline function lubuffer(
299+
::Val{T},
300+
::StaticInt{UF},
301+
::StaticInt{N}
302+
) where {T,UF,N}
303+
RSUF = StaticInt{UF}() * VectorizationBase.pick_vector_width(T)
304+
L = RSUF * N
305+
buf = Ref{NTuple{L,T}}()
306+
ptr = Base.unsafe_convert(Ptr{T}, buf)
307+
si = StrideIndex{2,(1, 2),1}(
308+
(
309+
VectorizationBase.static_sizeof(T),
310+
RSUF * VectorizationBase.static_sizeof(T)
311+
),
312+
(StaticInt(0), StaticInt(0))
313+
)
314+
stridedpointer(ptr, si, StaticInt{0}()), buf
276315
end
316+
@inline _free(p::Ptr) = Libc.free(p)
277317
_canonicalize(x) = signed(x)
278318
_canonicalize(::StaticInt{N}) where {N} = StaticInt{N}()
279319
function div_dispatch!(
@@ -528,12 +568,12 @@ function block_size(::Val{T}) where {T}
528568
end
529569

530570
nmuladd!(C, A, U, M, K, N) = @turbo for n CloseOpen(N), m CloseOpen(M)
531-
Cmn = A[m, n]
532-
for k CloseOpen(K)
533-
Cmn -= C[m, k] * U[k, n]
534-
end
535-
C[m, K+n] = Cmn
571+
Cmn = A[m, n]
572+
for k CloseOpen(K)
573+
Cmn -= C[m, k] * U[k, n]
536574
end
575+
C[m, K+n] = Cmn
576+
end
537577

538578
function rdiv_block_N!(
539579
spc::AbstractStridedPointer{T},
@@ -695,50 +735,44 @@ function rdiv_U!(
695735
WU = UF * WS
696736
MU = UF > 1 ? M : 0
697737
Nd, Nr = VectorizationBase.vdivrem(N, WS)
698-
spb = lubuffer(Val(T), UF, N)
738+
spb, preserve = lubuffer(Val(T), UF, N)
699739
m = 0
700-
while m < MU - WU + 1
701-
n = Nr
702-
if n > 0
703-
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT))
704-
end
705-
for _ 1:Nd
706-
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
707-
n += W
708-
end
709-
m += WU
710-
spa = gesp(spa, (WU, StaticInt(0)))
711-
spc = gesp(spc, (WU, StaticInt(0)))
712-
end
713-
finalmask = VectorizationBase.mask(WS, M)
714-
while m < M
715-
ubm = m + W
716-
nomaskiter = ubm < M
717-
mask = nomaskiter ? VectorizationBase.max_mask(WS) : finalmask
718-
n = Nr
719-
if n > 0
720-
BdivU_small_kern!(spb, spc, spa, spu, n, mask, Val(UNIT))
740+
GC.@preserve preserve begin
741+
while m < MU - WU + 1
742+
n = Nr
743+
if n > 0
744+
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT))
745+
end
746+
for _ 1:Nd
747+
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
748+
n += W
749+
end
750+
m += WU
751+
spa = gesp(spa, (WU, StaticInt(0)))
752+
spc = gesp(spc, (WU, StaticInt(0)))
721753
end
722-
for i 1:Nd
723-
# @show C, n
724-
rdiv_solve_W!(spb, spc, spa, spu, n, i Nd, mask, Val(UNIT))
725-
n += W
754+
finalmask = VectorizationBase.mask(WS, M)
755+
while m < M
756+
ubm = m + W
757+
nomaskiter = ubm < M
758+
mask = nomaskiter ? VectorizationBase.max_mask(WS) : finalmask
759+
n = Nr
760+
if n > 0
761+
BdivU_small_kern!(spb, spc, spa, spu, n, mask, Val(UNIT))
762+
end
763+
for i 1:Nd
764+
# @show C, n
765+
rdiv_solve_W!(spb, spc, spa, spu, n, i Nd, mask, Val(UNIT))
766+
n += W
767+
end
768+
spa = gesp(spa, (WS, StaticInt(0)))
769+
spc = gesp(spc, (WS, StaticInt(0)))
770+
m = ubm
726771
end
727-
spa = gesp(spa, (WS, StaticInt(0)))
728-
spc = gesp(spc, (WS, StaticInt(0)))
729-
m = ubm
730772
end
731773
nothing
732774
end
733775

734-
function __init__()
735-
nthread = Threads.nthreads()
736-
resize!(LDIVBUFFERS, nthread)
737-
for i 1:nthread
738-
LDIVBUFFERS[i] =
739-
Vector{UInt8}(undef, 3VectorizationBase.register_size() * 128)
740-
end
741-
end
742776
#=
743777
using PrecompileTools
744778
@static if VERSION >= v"1.8.0-beta1"

0 commit comments

Comments
 (0)