@@ -31,89 +31,13 @@ using Polyester
31
31
end
32
32
end
33
33
34
- # @generated function nmuladd(A::VecUnroll{Nm1},B::AbstractStridedPointer,C::VecUnroll{Nm1}) where {Nm1}
35
- # N = Nm1 + 1
36
- # quote
37
- # $(Expr(:meta,:inline))
38
- # Ad = VectorizationBase.data(A);
39
- # Cd = VectorizationBase.data(C);
40
- # bp = stridedpointer(B)
41
- # Base.Cartesian.@nexprs $N n -> C_n = Cd[n]
42
- # Base.Cartesian.@nexprs $N k -> begin
43
- # A_k = Ad[k]
44
- # Base.Cartesian.@nexprs $N n -> begin
45
- # C_n = Base.FastMath.sub_fast(C_n, Base.FastMath.mul_fast(A_k, vload(B, (k-1,n-1))))
46
- # end
47
- # end
48
- # VecUnroll(Base.Cartesian.@ntuple $N C)
49
- # end
50
- # end
51
-
52
- # @inline function solve_Wx3W(A11::V, A12::V, A13::V, U::AbstractMatrix, ::StaticInt{W}) where {V<:VecUnroll,W}
53
- # WS = StaticInt{W}()
54
-
55
- # U11 = view(U,StaticInt(1):WS,StaticInt(1):WS)
56
- # A11 = solve_AU(A11, U11)
57
-
58
- # U12 = view(U,StaticInt(1):WS, StaticInt(1)+WS:WS*StaticInt(2))
59
- # A12 = nmuladd(A11, U12, A12)
60
- # U22 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS:WS*StaticInt(2))
61
- # A12 = solve_AU(A12, U22)
62
-
63
- # U13 = view(U,StaticInt(1):WS, StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
64
- # A13 = nmuladd(A11, U13, A13)
65
- # U23 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
66
- # A13 = nmuladd(A12, U23, A13)
67
- # U33 = view(U,StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
68
- # A13 = solve_AU(A13, U33)
69
-
70
- # return A11, A12, A13
71
- # end
72
-
73
- # @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset) where {T}
74
- # WS = VectorizationBase.pick_vector_width(T)
75
- # W = Int(WS)
76
- # A11 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
77
- # A12 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
78
- # A13 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
79
-
80
- # A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
81
-
82
- # vstore!(ap, A11, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
83
- # vstore!(ap, A12, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
84
- # vstore!(ap, A13, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
85
- # end
86
- # @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset, m::VectorizationBase.AbstractMask) where {T}
87
- # WS = VectorizationBase.pick_vector_width(T)
88
- # W = Int(WS)
89
- # A11 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
90
- # A12 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
91
- # A13 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
92
-
93
- # A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
94
-
95
- # vstore!(ap, A11, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
96
- # vstore!(ap, A12, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
97
- # vstore!(ap, A13, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
98
- # end
99
-
100
- # solve_3Wx3W!(A,B,U::UpperTriangular) = solve_3Wx3W!(A,B,parent(U))
101
- # function solve_3Wx3W!(A::AbstractMatrix{T},B,U) where {T}
102
- # W = VectorizationBase.pick_vector_width(T)
103
- # ap = stridedpointer(A);
104
- # bp = stridedpointer(B);
105
- # solve_Wx3W!(ap, bp, U, StaticInt(1), StaticInt(1))
106
- # solve_Wx3W!(ap, bp, U, StaticInt(1) + W, StaticInt(1))
107
- # solve_Wx3W!(ap, bp, U, StaticInt(1) + W + W, StaticInt(1))
108
- # end
109
-
110
34
# `maybestore!` forwards a store to `vstore!` unless the destination is `nothing`,
# in which case the store is skipped and `nothing` is returned.

# Unmasked store of `v` at index `i` through `p`.
@inline function maybestore!(p, v, i)
    return vstore!(p, v, i)
end

# No destination: skip the unmasked store.
@inline function maybestore!(::Nothing, v, i)
    return nothing
end

# Masked store of `v` at index `i` through `p`, using mask `m`.
@inline function maybestore!(p, v, i, m)
    return vstore!(p, v, i, m)
end

# No destination: skip the masked store.
@inline function maybestore!(::Nothing, v, i, m)
    return nothing
end
115
39
116
# `Val{true}` method of `store_small_kern!`: writes `v` at index `i` under `mask`
# to both `spa` and `sp`. The fourth positional argument and `n` are accepted so
# the signature lines up with the other call sites, but they are not read here.
# Returns whatever the second `vstore!` call returns.
@inline function store_small_kern!(spa, sp, v, _, i, n, mask, ::Val{true})
    vstore!(spa, v, i, mask)
    return vstore!(sp, v, i, mask)
end
160
84
store_small_kern! (spa, sp, Amn, spu, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ),n)), n, Val {UNIT} ())
161
85
end
162
86
end
163
- # function BdivU_small!(A::AbstractMatrix{T}, B::AbstractMatrix{T}, U::AbstractMatrix{T}) where {T}
164
- # W = VectorizationBase.pick_vector_width(T)
165
- # M, N = size(A)
166
- # m = 0
167
- # spa = stridedpointer(A)
168
- # spb = stridedpointer(B)
169
- # spu = stridedpointer(U)
170
- # while m < M
171
- # ml = m+1
172
- # mu = m+W
173
- # maskiter = mu > M
174
- # mask = maskiter ? VectorizationBase.mask(W, M) : VectorizationBase.max_mask(W)
175
- # for n ∈ 1:N
176
- # Amn = vload(spb, (MM(W, ml),n), mask)
177
- # for k ∈ 1:n-1
178
- # Amn = vfnmadd_fast(vload(spa, (MM(W, ml),k), mask), vload(spu, (k,n)), Amn)
179
- # end
180
- # vstore!(spa, Amn / vload(spu, (n,n)), (MM(W, ml),n), mask)
181
- # end
182
- # m = mu
183
- # end
184
- # # @inbounds @fastmath for m ∈ 1:M
185
- # # for n ∈ 1:N
186
- # # Amn = B[m,n]
187
- # # for k ∈ 1:n-1
188
- # # Amn -= A[m,k]*U[k,n]
189
- # # end
190
- # # A[m,n] = Amn / U[n,n]
191
- # # end
192
- # # end
193
- # end
194
- # function nmuladd!(C,A,B,D)
195
- # @turbo for n ∈ axes(C,2), m ∈ axes(C,1)
196
- # Cmn = D[m,n]
197
- # for k ∈ axes(B,1)
198
- # Cmn -= A[m,k]*B[k,n]
199
- # end
200
- # C[m,n] = Cmn
201
- # end
202
- # end
203
87
204
88
@generated function rdiv_solve_W_u! (spc, spb, spa, spu, n, :: StaticInt{W} , :: StaticInt{U} , :: Val{UNIT} ) where {W, U, UNIT}
205
89
quote
@@ -286,7 +170,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
286
170
buff = LDIVBUFFERS[Threads. threadid ()]
287
171
RSUF = StaticInt {UF} ()* VectorizationBase. register_size ()
288
172
L = RSUF* N
289
- L > length (buff) && resize! (buff, L)
173
+ L > length (buff) && resize! (buff, L% UInt )
290
174
ptr = Base. unsafe_convert (Ptr{T}, buff)
291
175
si = StrideIndex {2,(1,2),1} ((VectorizationBase. static_sizeof (T), RSUF), (StaticInt (0 ),StaticInt (0 )))
292
176
stridedpointer (ptr, si, StaticInt {0} ())
@@ -412,24 +296,14 @@ function rdiv_block_N!(
412
296
N_temp = Core. ifelse (repeat, B_normalized, N)
413
297
while true
414
298
# println("Solve with N_temp = $N_temp and n = $n")
415
- rdiv_U! (spc, spa_rdiv, gesp (spu, (n,StaticInt {0} ())), M, N_temp, StaticInt {X} (), Val ( UNIT))
299
+ rdiv_U! (spc, spa_rdiv, gesp (spu, (n,StaticInt {0} ())), M, N_temp, StaticInt {X} (), Val { UNIT} ( ))
416
300
repeat || break
417
301
spa = gesp (spa, (StaticInt (0 ), B_normalized))
418
302
spc = gesp (spc, (StaticInt (0 ), B_normalized))
419
303
spu = gesp (spu, (StaticInt (0 ), B_normalized))
420
- nnext = n + B_normalized
421
- # N_temp =
422
304
n += B_normalized
423
305
repeat = n + B_normalized < N
424
306
N_temp = repeat ? N_temp : N - n
425
- # N_temp = min(n + B_normalized, N) - n
426
- # println("nmuladd with N_temp = $N_temp and n = $n")
427
- # mul!(
428
- # copyto!(view(C, :, n+1:n+N_temp), view(A, :, n+1:n+N_temp)),
429
- # view(C, :, 1:n),
430
- # view(U, 1:n, n+1:n+N_temp),
431
- # -1.0, 1.0
432
- # )
433
307
nmuladd! (spc_base, spa, spu, M, n, N_temp)
434
308
spa_rdiv = spc
435
309
end
@@ -439,15 +313,14 @@ function rdiv_block_MandN!(
439
313
) where {T,UNIT,X}
440
314
B = block_size (Val (T))
441
315
W = VectorizationBase. pick_vector_width (T)
442
- B_normalized = VectorizationBase. vcld (N, VectorizationBase. vcld (N, B)* W)* W
443
316
WUF = W* unroll_factor (W)
444
317
B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B)* WUF)* WUF
445
318
m = 0
446
319
while m < M
447
320
mu = m + B_m
448
321
Mtemp = min (M, mu) - m
449
322
rdiv_block_N! (
450
- spc, spa, spu, Mtemp, N, Val ( UNIT), StaticInt {X} (),
323
+ spc, spa, spu, Mtemp, N, Val { UNIT} ( ), StaticInt {X} (),
451
324
VectorizationBase. vcld (N, VectorizationBase. vcld (N, B)* W)* W
452
325
)
453
326
spa = gesp (spa, (B_m, StaticInt {0} ()))
@@ -458,42 +331,33 @@ function rdiv_block_MandN!(
458
331
end
459
332
# Pick the number of rows handed to each thread for an `M`-by-`N` problem with
# element type `T` and at most `nthreads` threads.
#
# The block count is `vdiv(M * N, 256 * W)` clamped to `1:nthreads`, where `W` is
# the vector width chosen for `T`. The result is `vcld(M, nb * W) * W`, capped at
# `M`.
# NOTE(review): `vdiv`/`vcld` are presumably floor and ceiling integer division
# from VectorizationBase — confirm against that package.
function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
    width = VectorizationBase.pick_vector_width(T)
    work_per_block = StaticInt{256}() * width
    nblocks = clamp(VectorizationBase.vdiv(M * N, work_per_block), 1, nthreads)
    rows_per_block = VectorizationBase.vcld(M, nblocks * width) * width
    return min(M, rows_per_block)
end
465
337
338
# Field-less callable carrying `UNIT` and `X` as type parameters, so they reach
# `rdiv_block_MandN!` as `Val{UNIT}()` and `static(X)`.
struct RDivBlockMandNv2{UNIT,X} end

# Process the 1-based, inclusive block range `blockstart:blockstop`.
#
# `allargs` unpacks to `(spc, spa, spu, N, Mrem, Nblock, mtb)`. Block `b`
# (0-based) starts `mtb * b` rows into `spc` and `spa`; every block covers `mtb`
# rows except the last one (`b == Nblock - 1`), which covers `Mrem` rows.
# An empty range does nothing. The loop's value, `nothing`, is returned.
function (::RDivBlockMandNv2{UNIT,X})(allargs, blockstart, blockstop) where {UNIT,X}
    spc, spa, spu, N, Mrem, Nblock, mtb = allargs
    last_block = Nblock - 1
    for block in (blockstart - 1):(blockstop - 1)
        row_offset = mtb * block
        rows = Core.ifelse(block == last_block, Mrem, mtb)
        rdiv_block_MandN!(
            gesp(spc, (row_offset, StaticInt{0}())),
            gesp(spa, (row_offset, StaticInt{0}())),
            spu, rows, N, Val{UNIT}(), static(X)
        )
    end
end
349
+
350
+
466
351
# Split the `M` rows into blocks of `mtb` rows and hand them to `batch`, which
# runs `RDivBlockMandNv2{UNIT,X}()` over the blocks.
#
# `Nblock` is the number of blocks (one extra when `M` is not a multiple of
# `mtb`) and `Mrem` is the row count of the final block. At most
# `Threads.nthreads()` threads are requested, never more than `Nblock`.
# Always returns `nothing`.
function multithread_rdiv!(
    spc::AbstractStridedPointer{TC},
    spa::AbstractStridedPointer{TA},
    spu::AbstractStridedPointer{TU},
    M::Int,
    N::Int,
    mtb::Int,
    ::Val{UNIT},
    ::StaticInt{X},
) where {X,UNIT,TC,TA,TU}
    full_blocks, leftover = VectorizationBase.vdivrem(M, mtb)
    has_partial = leftover ≠ 0
    Nblock = full_blocks + has_partial
    Mrem = Core.ifelse(has_partial, leftover, mtb)
    nthread = min(Nblock, Threads.nthreads())
    batch(
        RDivBlockMandNv2{UNIT,X}(), (Nblock, nthread),
        spc, spa, spu, N, Mrem, Nblock, mtb,
    )
    return nothing
end
498
362
499
363
# We're using `W x W` blocks, consuming `W` registers
@@ -521,7 +385,7 @@ function rdiv_U!(spc::AbstractStridedPointer{T}, spa::AbstractStridedPointer, sp
521
385
if n > 0
522
386
BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
523
387
end
524
- for i ∈ 1 : Nd
388
+ for _ ∈ 1 : Nd
525
389
rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
526
390
n += W
527
391
end
0 commit comments