Skip to content

Commit 142f687

Browse files
New - the backward rule for BLAS.nrm2 (#496)
* new - the backward rule for BLAS.nrm2 * fix tests * resolve issues * Update src/rrules/blas.jl Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Jinguo Liu <[email protected]> * bump version --------- Signed-off-by: Jinguo Liu <[email protected]> Co-authored-by: Will Tebbutt <[email protected]>
1 parent 6c99c65 commit 142f687

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.95"
4+
version = "0.4.96"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/Mooncake.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using Core:
4343
compilerbarrier
4444
using Core.Compiler: IRCode, NewInstruction
4545
using Core.Intrinsics: pointerref, pointerset
46-
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!
46+
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!, BlasFloat
4747
using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!
4848
using FunctionWrappers: FunctionWrapper
4949

src/rrules/blas.jl

+61-5
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,22 @@ end
2424
const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,<:Array{T}}}
2525
const VecOrView{T} = Union{Vector{T},SubArray{T,1,<:Array{T}}}
2626
const BlasRealFloat = Union{Float32,Float64}
27+
const BlasComplexFloat = Union{ComplexF32,ComplexF64}
2728

2829
"""
29-
arrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat}})
30+
arrayify(x::CoDual{<:AbstractArray{<:BlasFloat}})
3031
3132
Return the primal field of `x`, and convert its fdata into an array with the same element type as the
3233
primal. This operation is not guaranteed to be possible for all array types, but seems to be
3334
possible for all array types of interest so far.
3435
"""
35-
function arrayify(x::CoDual{A}) where {A<:AbstractArray{<:BlasRealFloat}}
36-
return arrayify(primal(x), tangent(x))::Tuple{A,A}
36+
function arrayify(x::CoDual{A}) where {A<:AbstractArray{<:BlasFloat}}
37+
return arrayify(primal(x), tangent(x)) # NOTE: for complex element types the tangent array is reinterpreted to the primal's element type (see the `BlasComplexFloat` method), so the result is not asserted to be `Tuple{A,A}`
3738
end
3839
arrayify(x::Array{P}, dx::Array{P}) where {P<:BlasRealFloat} = (x, dx)
40+
# Complex-valued dense arrays: the primal is returned unchanged, and the fdata array (whose
# elements are `Tangent`s) is exposed as an array with element type `P` via `reinterpret`.
# NOTE(review): relies on each `Tangent` element having the same bit layout as `P` — confirm.
function arrayify(x::Array{P}, dx::Array{<:Tangent}) where {P<:BlasComplexFloat}
41+
return x, reinterpret(P, dx)
42+
end
3943
function arrayify(x::A, dx::FData) where {A<:SubArray{<:BlasRealFloat}}
4044
_, _dx = arrayify(x.parent, dx.data.parent)
4145
return x, A(_dx, x.indices, x.offset1, x.stride1)
@@ -299,6 +303,45 @@ function rrule!!(
299303
return y_dy, symv!_adjoint
300304
end
301305

306+
# Mark the strided call `BLAS.nrm2(n, X, incx)` as a primitive in `MinimalCtx`; presumably so
# that the hand-written `rrule!!` for this signature is used — confirm against Mooncake docs.
# NOTE(review): `n` and `incx` are restricted to `Int` here, while the matching `rrule!!`
# method accepts any `Integer`.
@is_primitive(
307+
MinimalCtx,
308+
Tuple{
309+
typeof(BLAS.nrm2),Int,X,Int
310+
} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
311+
)
312+
# Reverse-mode rule for the strided call `BLAS.nrm2(n, X, incx)`.
#
# Forward pass: `arrayify` splits `X_dX` into the primal `X` and an array form of its fdata
# `dX`; the norm `y` is computed on the primal and returned with `NoFData()`.
#
# Pullback `nrm2_pb!!(dy)`: increments `dX` in place, at indices `1:incx:(incx * n)`, by the
# corresponding entries of `X` scaled by `dy / y` (for real `X` this is `dy` times the
# gradient `x / norm(x)` of the 2-norm), then returns `NoRData()` for each of the four
# arguments.
#
# NOTE(review): `dy / y` is not guarded; if the norm `y` is zero the increment is NaN/Inf.
# NOTE(review): the signature admits `Ptr{T}`, but no `arrayify` method for pointers is
# visible here — confirm that pointer inputs are actually supported.
# NOTE(review): the index range assumes a positive `incx` — confirm behaviour for
# non-positive strides.
function rrule!!(
313+
::CoDual{typeof(BLAS.nrm2)},
314+
n::CoDual{<:Integer},
315+
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
316+
incx::CoDual{<:Integer},
317+
)
318+
X, dX = arrayify(X_dX)
319+
y = BLAS.nrm2(n.x, X, incx.x)
320+
function nrm2_pb!!(dy)
321+
view(dX, 1:(incx.x):(incx.x * n.x)) .+=
322+
view(X, 1:(incx.x):(incx.x * n.x)) .* (dy / y)
323+
return NoRData(), NoRData(), NoRData(), NoRData()
324+
end
325+
return CoDual(y, NoFData()), nrm2_pb!!
326+
end
327+
328+
# Mark the single-argument call `BLAS.nrm2(X)` as a primitive in `MinimalCtx`; presumably so
# that the hand-written `rrule!!` for this signature is used — confirm against Mooncake docs.
@is_primitive(
329+
MinimalCtx,
330+
Tuple{typeof(BLAS.nrm2),X} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
331+
)
332+
# Reverse-mode rule for the whole-array call `BLAS.nrm2(X)`.
#
# Forward pass: `arrayify` splits `X_dX` into the primal `X` and an array form of its fdata
# `dX`; the norm `y` is computed on the primal and returned with `NoFData()`.
#
# Pullback `nrm2_pb!!(dy)`: increments every entry of `dX` in place by the matching entry of
# `X` scaled by `dy / y`, then returns `NoRData()` for both arguments.
#
# NOTE(review): `dy / y` is not guarded; if the norm `y` is zero the increment is NaN/Inf.
# NOTE(review): the signature admits `Ptr{T}`, but no `arrayify` method for pointers is
# visible here — confirm that pointer inputs are actually supported.
function rrule!!(
333+
::CoDual{typeof(BLAS.nrm2)},
334+
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
335+
)
336+
X, dX = arrayify(X_dX)
337+
y = BLAS.nrm2(X)
338+
function nrm2_pb!!(dy)
339+
dX .+= X .* (dy / y) # TODO: verify for complex numbers
340+
return NoRData(), NoRData()
341+
end
342+
return CoDual(y, NoFData()), nrm2_pb!!
343+
end
344+
302345
@is_primitive(
303346
MinimalCtx,
304347
Tuple{
@@ -755,7 +798,7 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32))
755798
end
756799
end
757800

758-
function blas_matrices(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int, q::Int)
801+
function blas_matrices(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int, q::Int)
759802
Xs = Any[
760803
randn(rng, P, p, q),
761804
view(randn(rng, P, p + 5, 2q), 3:(p + 2), 1:2:(2q)),
@@ -767,7 +810,7 @@ function blas_matrices(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int, q::In
767810
return Xs
768811
end
769812

770-
function blas_vectors(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int)
813+
function blas_vectors(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int)
771814
xs = Any[
772815
randn(rng, P, p),
773816
view(randn(rng, P, p + 5), 3:(p + 2)),
@@ -789,6 +832,19 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
789832
rng = rng_ctor(123456)
790833

791834
test_cases = vcat(
835+
# nrm2(x)
836+
map_prod([Ps..., ComplexF64, ComplexF32]) do (P,)
837+
return map([randn(rng, P, 105)]) do x
838+
(false, :none, nothing, BLAS.nrm2, x)
839+
end
840+
end...,
841+
842+
# nrm2(n, x, incx)
843+
map_prod([Ps..., ComplexF64, ComplexF32], [5, 3], [1, 2]) do (P, n, incx)
844+
return map([randn(rng, P, 105)]) do x
845+
(false, :none, nothing, BLAS.nrm2, n, x, incx)
846+
end
847+
end...,
792848

793849
# gemv!
794850
map_prod(t_flags, [1, 3], [1, 2], Ps) do (tA, M, N, P)

0 commit comments

Comments
 (0)