Skip to content

Commit 7fa5823

Browse files
authored
Merge pull request #24 from JuliaSIMD/fewerinvalidations
reduce invalidations by not using `CPUSummary.num_threads()`
2 parents 0d41240 + 72adc8c commit 7fa5823

File tree

2 files changed

+55
-27
lines changed

2 files changed

+55
-27
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TriangularSolve"
22
uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
33
authors = ["chriselrod <[email protected]> and contributors"]
4-
version = "0.1.13"
4+
version = "0.1.14"
55

66
[deps]
77
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"

src/TriangularSolve.jl

+54-26
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
293293
end
294294
_canonicalize(x) = signed(x)
295295
_canonicalize(::StaticInt{N}) where {N} = StaticInt{N}()
296-
function div_dispatch!(C::AbstractMatrix{T}, A, U, ::Val{UNIT}, ::Val{THREAD}) where {UNIT,T,THREAD}
296+
function div_dispatch!(C::AbstractMatrix{T}, A, U, nthread, ::Val{UNIT}) where {UNIT,T}
297297
_M, _N = size(A)
298298
M = _canonicalize(_M)
299299
N = _canonicalize(_N)
@@ -305,8 +305,8 @@ function div_dispatch!(C::AbstractMatrix{T}, A, U, ::Val{UNIT}, ::Val{THREAD}) w
305305
spc = zero_offsets(_spc)
306306
spu = zero_offsets(_spu)
307307
GC.@preserve spap spcp spup begin
308-
mtb = m_thread_block_size(M, N, Val(T))
309-
if THREAD && (VectorizationBase.num_threads() > 1)
308+
mtb = m_thread_block_size(M, N, nthread, Val(T))
309+
if nthread > 1
310310
(M > mtb) && return multithread_rdiv!(spc, spa, spu, M, N, mtb, Val(UNIT), VectorizationBase.contiguous_axis(A))
311311
elseif N > block_size(Val(T))
312312
return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT), VectorizationBase.contiguous_axis(A))
@@ -315,36 +315,69 @@ function div_dispatch!(C::AbstractMatrix{T}, A, U, ::Val{UNIT}, ::Val{THREAD}) w
315315
end
316316
end
317317

318-
function rdiv!(A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
319-
div_dispatch!(A, A, parent(U), Val(false), Val(THREAD))
318+
_nthreads() = min(Int(VectorizationBase.num_cores())::Int, Threads.nthreads()::Int)
319+
function rdiv!(A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
320+
div_dispatch!(A, A, parent(U), _nthreads(), Val(false))
320321
return A
321322
end
322-
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
323-
div_dispatch!(C, A, parent(U), Val(false), Val(THREAD))
323+
function rdiv!(A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
324+
div_dispatch!(A, A, parent(U), static(0), Val(false))
325+
return A
326+
end
327+
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
328+
div_dispatch!(C, A, parent(U), _nthreads(), Val(false))
324329
return C
325330
end
326-
function rdiv!(A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
327-
div_dispatch!(A, A, parent(U), Val(true), Val(THREAD))
331+
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UpperTriangular{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
332+
div_dispatch!(C, A, parent(U), static(0), Val(false))
333+
return C
334+
end
335+
function rdiv!(A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
336+
div_dispatch!(A, A, parent(U), _nthreads(), Val(true))
328337
return A
329338
end
330-
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
331-
div_dispatch!(C, A, parent(U), Val(true), Val(THREAD))
339+
function rdiv!(A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
340+
div_dispatch!(A, A, parent(U), static(0), Val(true))
341+
return A
342+
end
343+
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
344+
div_dispatch!(C, A, parent(U), _nthreads(), Val(true))
345+
return C
346+
end
347+
function rdiv!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, U::UnitUpperTriangular{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
348+
div_dispatch!(C, A, parent(U), static(0), Val(true))
332349
return C
333350
end
334-
function ldiv!(U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
335-
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), Val(false), Val(THREAD))
351+
function ldiv!(U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
352+
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), _nthreads(), Val(false))
336353
return A
337354
end
338-
function ldiv!(C::AbstractMatrix{T}, U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
339-
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), Val(false), Val(THREAD))
355+
function ldiv!(U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
356+
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), static(0), Val(false))
357+
return A
358+
end
359+
function ldiv!(C::AbstractMatrix{T}, U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
360+
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), _nthreads(), Val(false))
361+
return C
362+
end
363+
function ldiv!(C::AbstractMatrix{T}, U::LowerTriangular{T}, A::AbstractMatrix{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
364+
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), static(0), Val(false))
340365
return C
341366
end
342-
function ldiv!(U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
343-
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), Val(true), Val(THREAD))
367+
function ldiv!(U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
368+
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), _nthreads(), Val(true))
369+
return A
370+
end
371+
function ldiv!(U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
372+
div_dispatch!(transpose(A), transpose(A), transpose(parent(U)), static(0), Val(true))
344373
return A
345374
end
346-
function ldiv!(C::AbstractMatrix{T}, U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{THREAD} = Val(true)) where {T<:Union{Float32,Float64},THREAD}
347-
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), Val(true), Val(THREAD))
375+
function ldiv!(C::AbstractMatrix{T}, U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{true} = Val(true)) where {T<:Union{Float32,Float64}}
376+
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), _nthreads(), Val(true))
377+
return C
378+
end
379+
function ldiv!(C::AbstractMatrix{T}, U::UnitLowerTriangular{T}, A::AbstractMatrix{T}, ::Val{false}) where {T<:Union{Float32,Float64}}
380+
div_dispatch!(transpose(C), transpose(A), transpose(parent(U)), static(0), Val(true))
348381
return C
349382
end
350383

@@ -423,15 +456,10 @@ function rdiv_block_MandN!(
423456
end
424457
nothing
425458
end
426-
function _nthreads()
427-
nc = VectorizationBase.num_cores()
428-
nt = VectorizationBase.num_threads()
429-
ifelse(Static.lt(nc,nt),nc,nt)
430-
end
431-
function m_thread_block_size(M, N, ::Val{T}) where {T}
459+
function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
432460
W = VectorizationBase.pick_vector_width(T)
433461
WUF = W * unroll_factor(W)
434-
nb = clamp(VectorizationBase.vdiv(M * N, StaticInt{256}() * W), 1, Int(_nthreads()))
462+
nb = clamp(VectorizationBase.vdiv(M * N, StaticInt{256}() * W), 1, nthreads)
435463
min(M, VectorizationBase.vcld(M, nb*W)*W)
436464
end
437465

0 commit comments

Comments
 (0)