@@ -261,19 +261,59 @@ end
261
261
nothing
262
262
end
263
263
264
- const LDIVBUFFERS = Vector{UInt8}[]
265
- @inline function lubuffer (:: Val{T} , :: StaticInt{UF} , N) where {T,UF}
266
- buff = LDIVBUFFERS[Threads. threadid ()]
267
- RSUF = StaticInt {UF} () * VectorizationBase. register_size ()
264
# Process-wide table of per-thread scratch-buffer records. Each record occupies
# `2 * sizeof(Int)` bytes and is read elsewhere in this file as a
# `Pair{Ptr{Cvoid},Int}` (buffer pointer => capacity). It is `C_NULL` until
# `__init__` has run.
const buffer = Ref{Ptr{Cvoid}}(C_NULL)

"""
    __init__()

Allocate the zero-initialised table of per-thread buffer records and publish
its address in `buffer[]`.

The table gets one record per thread id. Where `Threads.maxthreadid` exists it
is used instead of `Threads.nthreads()`, because `Threads.threadid()` may
exceed the size of the default thread pool (e.g. when interactive threads are
configured). Threads adopted after initialisation are still not covered.

# Throws
- `OutOfMemoryError` if the table cannot be allocated.
"""
function __init__()
  nslots =
    isdefined(Threads, :maxthreadid) ? Threads.maxthreadid() : Threads.nthreads()
  bp_size = 2 * sizeof(Int) * nslots
  # `calloc` returns zeroed memory, so every record starts as (C_NULL, 0).
  bp = Libc.calloc(1, bp_size)
  bp == C_NULL && throw(OutOfMemoryError())
  buffer[] = bp
  return nothing
end
271
+
272
"""
    _get_buffer_pointer(::StaticInt{UF}, N) where {UF}

Return the calling thread's scratch buffer, passed through
`VectorizationBase.align(p, RS)` with `RS = VectorizationBase.register_size()`.
The buffer has capacity for at least `UF * RS * N` bytes.

The record for the current thread is read from the table at `buffer[]`, at
byte offset `2sizeof(Int) * (Threads.threadid() - 1)`. If the recorded
capacity is too small, the old allocation is freed and a new one of
`max(UF * RS * 128, UF * RS * N)` bytes (plus `RS - 1` bytes of slack) is
created and recorded. Previous buffer contents are not preserved.

NOTE(review): the record index is not bounds-checked against the table size
chosen in `__init__`.

# Throws
- `OutOfMemoryError` if a required allocation fails.
"""
function _get_buffer_pointer(::StaticInt{UF}, N) where {UF}
  RS = VectorizationBase.register_size()
  RSUF = StaticInt{UF}() * RS
  L = RSUF * N
  tid = Threads.threadid() - 1
  bp = Ptr{Pair{Ptr{Cvoid},Int}}(buffer[]) + 2sizeof(Int) * tid
  (p, buff_current) = unsafe_load(bp)
  if buff_current < L
    if p != C_NULL
      Libc.free(p)
      # Clear the record first so that a failed allocation below cannot leave
      # a dangling pointer with a stale capacity behind.
      unsafe_store!(bp, C_NULL => 0)
    end
    buff_size = max(RSUF * 128, L)
    # `RS - 1` extra bytes leave room for the alignment adjustment applied to
    # the raw pointer on return.
    p = Libc.malloc(Int(buff_size + RS - 1))
    p == C_NULL && throw(OutOfMemoryError())
    # The raw (unaligned) pointer is recorded; it is the one passed to `free`.
    unsafe_store!(bp, p => buff_size)
  end
  return VectorizationBase.align(p, RS)
end
287
+
288
"""
    lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}

Build a strided pointer with element type `T` over the buffer returned by
`_get_buffer_pointer(StaticInt{UF}(), N)`.

The two strides are `VectorizationBase.static_sizeof(T)` and
`UF * VectorizationBase.register_size()`; both offsets are static zero.

Returns the tuple `(strided_pointer, nothing)`. The caller in this file hands
the second element to `GC.@preserve`.
"""
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}
  unroll = StaticInt{UF}()
  col_stride = unroll * VectorizationBase.register_size()
  base = Ptr{T}(_get_buffer_pointer(unroll, N))
  strides = (VectorizationBase.static_sizeof(T), col_stride)
  offsets = (StaticInt(0), StaticInt(0))
  index = StrideIndex{2,(1, 2),1}(strides, offsets)
  return stridedpointer(base, index, StaticInt{0}()), nothing
end
298
"""
    lubuffer(::Val{T}, ::StaticInt{UF}, ::StaticInt{N}) where {T,UF,N}

Variant for a statically known `N`: the scratch storage is a
`Ref{NTuple{L,T}}` with `L = UF * VectorizationBase.pick_vector_width(T) * N`
instead of the shared per-thread buffer.

The two strides are `static_sizeof(T)` and
`UF * pick_vector_width(T) * static_sizeof(T)`; both offsets are static zero.

Returns the tuple `(strided_pointer, ref)`. The caller in this file hands the
second element to `GC.@preserve` while the pointer is in use.
"""
@inline function lubuffer(
  ::Val{T},
  ::StaticInt{UF},
  ::StaticInt{N}
) where {T,UF,N}
  col_len = StaticInt{UF}() * VectorizationBase.pick_vector_width(T)
  nelems = col_len * N
  storage = Ref{NTuple{nelems,T}}()
  base = Base.unsafe_convert(Ptr{T}, storage)
  elbytes = VectorizationBase.static_sizeof(T)
  strides = (elbytes, col_len * VectorizationBase.static_sizeof(T))
  offsets = (StaticInt(0), StaticInt(0))
  index = StrideIndex{2,(1, 2),1}(strides, offsets)
  return stridedpointer(base, index, StaticInt{0}()), storage
end
316
"""
    _free(p::Ptr)

Forward `p` to `Libc.free` and return its result.
"""
@inline function _free(p::Ptr)
  return Libc.free(p)
end
277
317
# Generic fallback: return `signed(x)`.
function _canonicalize(x)
  return signed(x)
end
278
318
# A `StaticInt{N}` argument yields a `StaticInt{N}()` instance, so the value
# stays in the type domain instead of going through `signed`.
function _canonicalize(::StaticInt{N}) where {N}
  return StaticInt{N}()
end
279
319
function div_dispatch! (
@@ -528,12 +568,12 @@ function block_size(::Val{T}) where {T}
528
568
end
529
569
530
570
"""
    nmuladd!(C, A, U, M, K, N)

For each `n ∈ CloseOpen(N)` and `m ∈ CloseOpen(M)`, start from `A[m, n]`,
subtract `C[m, k] * U[k, n]` for every `k ∈ CloseOpen(K)`, and store the
result in `C[m, K+n]`. The loop nest is run under `@turbo`.
"""
function nmuladd!(C, A, U, M, K, N)
  @turbo for n ∈ CloseOpen(N), m ∈ CloseOpen(M)
    acc = A[m, n]
    for k ∈ CloseOpen(K)
      acc -= C[m, k] * U[k, n]
    end
    C[m, K+n] = acc
  end
end
537
577
538
578
function rdiv_block_N! (
539
579
spc:: AbstractStridedPointer{T} ,
@@ -695,50 +735,44 @@ function rdiv_U!(
695
735
WU = UF * WS
696
736
MU = UF > 1 ? M : 0
697
737
Nd, Nr = VectorizationBase. vdivrem (N, WS)
698
- spb = lubuffer (Val (T), UF, N)
738
+ spb, preserve = lubuffer (Val (T), UF, N)
699
739
m = 0
700
- while m < MU - WU + 1
701
- n = Nr
702
- if n > 0
703
- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
704
- end
705
- for _ ∈ 1 : Nd
706
- rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
707
- n += W
708
- end
709
- m += WU
710
- spa = gesp (spa, (WU, StaticInt (0 )))
711
- spc = gesp (spc, (WU, StaticInt (0 )))
712
- end
713
- finalmask = VectorizationBase. mask (WS, M)
714
- while m < M
715
- ubm = m + W
716
- nomaskiter = ubm < M
717
- mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
718
- n = Nr
719
- if n > 0
720
- BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
740
+ GC. @preserve preserve begin
741
+ while m < MU - WU + 1
742
+ n = Nr
743
+ if n > 0
744
+ BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
745
+ end
746
+ for _ ∈ 1 : Nd
747
+ rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
748
+ n += W
749
+ end
750
+ m += WU
751
+ spa = gesp (spa, (WU, StaticInt (0 )))
752
+ spc = gesp (spc, (WU, StaticInt (0 )))
721
753
end
722
- for i ∈ 1 : Nd
723
- # @show C, n
724
- rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
725
- n += W
754
+ finalmask = VectorizationBase. mask (WS, M)
755
+ while m < M
756
+ ubm = m + W
757
+ nomaskiter = ubm < M
758
+ mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
759
+ n = Nr
760
+ if n > 0
761
+ BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
762
+ end
763
+ for i ∈ 1 : Nd
764
+ # @show C, n
765
+ rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
766
+ n += W
767
+ end
768
+ spa = gesp (spa, (WS, StaticInt (0 )))
769
+ spc = gesp (spc, (WS, StaticInt (0 )))
770
+ m = ubm
726
771
end
727
- spa = gesp (spa, (WS, StaticInt (0 )))
728
- spc = gesp (spc, (WS, StaticInt (0 )))
729
- m = ubm
730
772
end
731
773
nothing
732
774
end
733
775
734
- function __init__ ()
735
- nthread = Threads. nthreads ()
736
- resize! (LDIVBUFFERS, nthread)
737
- for i ∈ 1 : nthread
738
- LDIVBUFFERS[i] =
739
- Vector {UInt8} (undef, 3 VectorizationBase. register_size () * 128 )
740
- end
741
- end
742
776
#=
743
777
using PrecompileTools
744
778
@static if VERSION >= v"1.8.0-beta1"
0 commit comments