@@ -293,7 +293,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
293
293
end
294
294
# Generic fallback: map a size value to its signed representation via `signed`.
function _canonicalize(x)
  return signed(x)
end
295
295
# A `StaticInt{N}` argument is returned as a `StaticInt{N}` instance, so the
# value stays encoded in the type.
function _canonicalize(::StaticInt{N}) where {N}
  return StaticInt{N}()
end
296
- function div_dispatch! (C:: AbstractMatrix{T} , A, U, :: Val{UNIT} , :: Val{THREAD } ) where {UNIT,T,THREAD }
296
+ function div_dispatch! (C:: AbstractMatrix{T} , A, U, nthread , :: Val{UNIT } ) where {UNIT,T}
297
297
_M, _N = size (A)
298
298
M = _canonicalize (_M)
299
299
N = _canonicalize (_N)
@@ -305,8 +305,8 @@ function div_dispatch!(C::AbstractMatrix{T}, A, U, ::Val{UNIT}, ::Val{THREAD}) w
305
305
spc = zero_offsets (_spc)
306
306
spu = zero_offsets (_spu)
307
307
GC. @preserve spap spcp spup begin
308
- mtb = m_thread_block_size (M, N, Val (T))
309
- if THREAD && (VectorizationBase . num_threads () > 1 )
308
+ mtb = m_thread_block_size (M, N, nthread, Val (T))
309
+ if nthread > 1
310
310
(M > mtb) && return multithread_rdiv! (spc, spa, spu, M, N, mtb, Val (UNIT), VectorizationBase. contiguous_axis (A))
311
311
elseif N > block_size (Val (T))
312
312
return rdiv_block_MandN! (spc, spa, spu, M, N, Val (UNIT), VectorizationBase. contiguous_axis (A))
@@ -315,36 +315,69 @@ function div_dispatch!(C::AbstractMatrix{T}, A, U, ::Val{UNIT}, ::Val{THREAD}) w
315
315
end
316
316
end
317
317
318
- function rdiv! (A:: AbstractMatrix{T} , U:: UpperTriangular{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
319
- div_dispatch! (A, A, parent (U), Val (false ), Val (THREAD))
318
# Thread count used by the threaded entry points: the smaller of the core
# count reported by `VectorizationBase.num_cores()` and `Threads.nthreads()`.
function _nthreads()
  ncores = Int(VectorizationBase.num_cores())::Int
  njulia = Threads.nthreads()::Int
  return min(ncores, njulia)
end
319
# In-place method for `UpperTriangular` `U`, threaded variant (the default).
# `A` is passed as both the first and second argument of `div_dispatch!`,
# the thread count comes from `_nthreads()`, and the `UNIT` flag is `Val(false)`.
# Returns `A`.
function rdiv!(
  A::AbstractMatrix{T},
  U::UpperTriangular{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  nthread = _nthreads()
  div_dispatch!(A, A, parent(U), nthread, Val(false))
  return A
end
322
- function rdiv! (C:: AbstractMatrix{T} , A:: AbstractMatrix{T} , U:: UpperTriangular{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
323
- div_dispatch! (C, A, parent (U), Val (false ), Val (THREAD))
323
# In-place method for `UpperTriangular` `U` with threading disabled.
# `A` is passed as both the first and second argument of `div_dispatch!`,
# with `Val(false)` as the `UNIT` flag. Returns `A`.
function rdiv!(
  A::AbstractMatrix{T},
  U::UpperTriangular{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(A, A, parent(U), static(1), Val(false))
  return A
end
327
# Out-of-place method for `UpperTriangular` `U`, threaded variant (the default).
# Forwards `C`, `A` and `parent(U)` to `div_dispatch!` with the thread count
# from `_nthreads()` and `Val(false)` as the `UNIT` flag. Returns `C`.
function rdiv!(
  C::AbstractMatrix{T},
  A::AbstractMatrix{T},
  U::UpperTriangular{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  nthread = _nthreads()
  div_dispatch!(C, A, parent(U), nthread, Val(false))
  return C
end
326
- function rdiv! (A:: AbstractMatrix{T} , U:: UnitUpperTriangular{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
327
- div_dispatch! (A, A, parent (U), Val (true ), Val (THREAD))
331
# Out-of-place method for `UpperTriangular` `U` with threading disabled.
# Forwards `C`, `A` and `parent(U)` to `div_dispatch!` with `Val(false)` as
# the `UNIT` flag. Returns `C`.
function rdiv!(
  C::AbstractMatrix{T},
  A::AbstractMatrix{T},
  U::UpperTriangular{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(C, A, parent(U), static(1), Val(false))
  return C
end
335
# In-place method for `UnitUpperTriangular` `U`, threaded variant (the default).
# `A` is passed as both the first and second argument of `div_dispatch!`,
# the thread count comes from `_nthreads()`, and the `UNIT` flag is `Val(true)`.
# Returns `A`.
function rdiv!(
  A::AbstractMatrix{T},
  U::UnitUpperTriangular{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  nthread = _nthreads()
  div_dispatch!(A, A, parent(U), nthread, Val(true))
  return A
end
330
- function rdiv! (C:: AbstractMatrix{T} , A:: AbstractMatrix{T} , U:: UnitUpperTriangular{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
331
- div_dispatch! (C, A, parent (U), Val (true ), Val (THREAD))
339
# In-place method for `UnitUpperTriangular` `U` with threading disabled.
# `A` is passed as both the first and second argument of `div_dispatch!`,
# with `Val(true)` as the `UNIT` flag. Returns `A`.
function rdiv!(
  A::AbstractMatrix{T},
  U::UnitUpperTriangular{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(A, A, parent(U), static(1), Val(true))
  return A
end
343
# Out-of-place method for `UnitUpperTriangular` `U`, threaded variant (the
# default). Forwards `C`, `A` and `parent(U)` to `div_dispatch!` with the
# thread count from `_nthreads()` and `Val(true)` as the `UNIT` flag.
# Returns `C`.
function rdiv!(
  C::AbstractMatrix{T},
  A::AbstractMatrix{T},
  U::UnitUpperTriangular{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  nthread = _nthreads()
  div_dispatch!(C, A, parent(U), nthread, Val(true))
  return C
end
347
# Out-of-place method for `UnitUpperTriangular` `U` with threading disabled.
# Forwards `C`, `A` and `parent(U)` to `div_dispatch!` with `Val(true)` as
# the `UNIT` flag. Returns `C`.
function rdiv!(
  C::AbstractMatrix{T},
  A::AbstractMatrix{T},
  U::UnitUpperTriangular{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(C, A, parent(U), static(1), Val(true))
  return C
end
334
- function ldiv! (U:: LowerTriangular{T} , A:: AbstractMatrix{T} , :: Val{THREAD } = Val (true )) where {T<: Union{Float32,Float64} ,THREAD }
335
- div_dispatch! (transpose (A), transpose (A), transpose (parent (U)), Val ( false ), Val (THREAD ))
351
# In-place method for `LowerTriangular` `U`, threaded variant (the default).
# The problem is handed to `div_dispatch!` in transposed form: `transpose(A)`
# is both the first and second argument, followed by `transpose(parent(U))`,
# the thread count from `_nthreads()`, and `Val(false)` as the `UNIT` flag.
# Returns `A`.
function ldiv!(
  U::LowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  At = transpose(A)
  Ut = transpose(parent(U))
  div_dispatch!(At, At, Ut, _nthreads(), Val(false))
  return A
end
338
- function ldiv! (C:: AbstractMatrix{T} , U:: LowerTriangular{T} , A:: AbstractMatrix{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
339
- div_dispatch! (transpose (C), transpose (A), transpose (parent (U)), Val (false ), Val (THREAD))
355
# In-place method for `LowerTriangular` `U` with threading disabled.
# The problem is handed to `div_dispatch!` in transposed form: `transpose(A)`
# is both the first and second argument, followed by `transpose(parent(U))`
# and `Val(false)` as the `UNIT` flag. Returns `A`.
function ldiv!(
  U::LowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), static(1), Val(false))
  return A
end
359
# Out-of-place method for `LowerTriangular` `U`, threaded variant (the default).
# The problem is handed to `div_dispatch!` in transposed form: `transpose(C)`,
# `transpose(A)` and `transpose(parent(U))`, with the thread count from
# `_nthreads()` and `Val(false)` as the `UNIT` flag. Returns `C`.
function ldiv!(
  C::AbstractMatrix{T},
  U::LowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  Ct = transpose(C)
  At = transpose(A)
  Ut = transpose(parent(U))
  div_dispatch!(Ct, At, Ut, _nthreads(), Val(false))
  return C
end
363
# Out-of-place method for `LowerTriangular` `U` with threading disabled.
# The problem is handed to `div_dispatch!` in transposed form: `transpose(C)`,
# `transpose(A)` and `transpose(parent(U))`, with `Val(false)` as the `UNIT`
# flag. Returns `C`.
function ldiv!(
  C::AbstractMatrix{T},
  U::LowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), static(1), Val(false))
  return C
end
342
- function ldiv! (U:: UnitLowerTriangular{T} , A:: AbstractMatrix{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
343
- div_dispatch! (transpose (A), transpose (A), transpose (parent (U)), Val (true ), Val (THREAD))
367
# In-place method for `UnitLowerTriangular` `U`, threaded variant (the default).
# The problem is handed to `div_dispatch!` in transposed form: `transpose(A)`
# is both the first and second argument, followed by `transpose(parent(U))`,
# the thread count from `_nthreads()`, and `Val(true)` as the `UNIT` flag.
# Returns `A`.
function ldiv!(
  U::UnitLowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  At = transpose(A)
  Ut = transpose(parent(U))
  div_dispatch!(At, At, Ut, _nthreads(), Val(true))
  return A
end
371
# In-place method for `UnitLowerTriangular` `U` with threading disabled.
# The problem is handed to `div_dispatch!` in transposed form: `transpose(A)`
# is both the first and second argument, followed by `transpose(parent(U))`
# and `Val(true)` as the `UNIT` flag. Returns `A`.
function ldiv!(
  U::UnitLowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), static(1), Val(true))
  return A
end
346
- function ldiv! (C:: AbstractMatrix{T} , U:: UnitLowerTriangular{T} , A:: AbstractMatrix{T} , :: Val{THREAD} = Val (true )) where {T<: Union{Float32,Float64} ,THREAD}
347
- div_dispatch! (transpose (C), transpose (A), transpose (parent (U)), Val (true ), Val (THREAD))
375
# Out-of-place method for `UnitLowerTriangular` `U`, threaded variant (the
# default). The problem is handed to `div_dispatch!` in transposed form:
# `transpose(C)`, `transpose(A)` and `transpose(parent(U))`, with the thread
# count from `_nthreads()` and `Val(true)` as the `UNIT` flag. Returns `C`.
function ldiv!(
  C::AbstractMatrix{T},
  U::UnitLowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{true} = Val(true),
) where {T<:Union{Float32,Float64}}
  Ct = transpose(C)
  At = transpose(A)
  Ut = transpose(parent(U))
  div_dispatch!(Ct, At, Ut, _nthreads(), Val(true))
  return C
end
379
# Out-of-place method for `UnitLowerTriangular` `U` with threading disabled.
# The problem is handed to `div_dispatch!` in transposed form: `transpose(C)`,
# `transpose(A)` and `transpose(parent(U))`, with `Val(true)` as the `UNIT`
# flag. Returns `C`.
function ldiv!(
  C::AbstractMatrix{T},
  U::UnitLowerTriangular{T},
  A::AbstractMatrix{T},
  ::Val{false},
) where {T<:Union{Float32,Float64}}
  # Pass a thread count of 1 rather than 0: `div_dispatch!` forwards the count
  # to `m_thread_block_size`, where it is the upper bound of
  # `clamp(..., 1, nthreads)`. An upper bound of 0 inverts the clamp range and
  # can yield a zero block count that is then used as a divisor. A count of 1
  # still fails the `nthread > 1` check, so the threaded path stays disabled.
  div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), static(1), Val(true))
  return C
end
350
383
@@ -423,15 +456,10 @@ function rdiv_block_MandN!(
423
456
end
424
457
nothing
425
458
end
426
- function _nthreads ()
427
- nc = VectorizationBase. num_cores ()
428
- nt = VectorizationBase. num_threads ()
429
- ifelse (Static. lt (nc,nt),nc,nt)
430
- end
431
- function m_thread_block_size (M, N, :: Val{T} ) where {T}
459
# Compute the block size along `M` used when work is split across threads.
#
# Arguments:
# - `M`, `N`: problem dimensions (`div_dispatch!` passes the canonicalized
#   `size(A)`).
# - `nthreads`: upper bound on the number of blocks.
# - `::Val{T}`: element type, used to pick the vector width `W`.
#
# The block count `nb` is `M * N` divided by `256 * W`, clamped to
# `1:max(nthreads, 1)`. The result is `vcld(M, nb * W) * W`, capped at `M`.
function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
  W = VectorizationBase.pick_vector_width(T)
  # Callers may pass a non-positive thread count to mean "no threading". Keep
  # the clamp's upper bound at least 1 so that `nb` can never be 0, since
  # `nb * W` is used as a divisor below.
  max_blocks = max(nthreads, 1)
  nb = clamp(VectorizationBase.vdiv(M * N, StaticInt{256}() * W), 1, max_blocks)
  return min(M, VectorizationBase.vcld(M, nb * W) * W)
end
437
465
0 commit comments