24
24
# A `Matrix{T}`, or a 2-dimensional `SubArray` whose parent is an `Array` with eltype `T`.
const MatrixOrView{T} = Union{Array{T,2},SubArray{T,2,<:Array{T}}}

# A `Vector{T}`, or a 1-dimensional `SubArray` whose parent is an `Array` with eltype `T`.
const VecOrView{T} = Union{Array{T,1},SubArray{T,1,<:Array{T}}}

# Single- and double-precision real floating-point element types.
const BlasRealFloat = Union{Float32,Float64}

# Single- and double-precision complex floating-point element types.
const BlasComplexFloat = Union{Complex{Float32},Complex{Float64}}
27
28
28
29
"""
29
- arrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat }})
30
+ arrayify(x::CoDual{<:AbstractArray{<:BlasFloat }})
30
31
31
32
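Return the primal field of `x`, together with its fdata converted into an array with the
same element type as the primal. For real element types this is an array of the same type
as the primal; for complex element types it is a `reinterpret`ed view of the fdata rather
than an array of the primal's exact type. This operation is not guaranteed to be possible
for all array types, but seems to be possible for all array types of interest so far.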
34
35
"""
35
# Unpack the `CoDual` and delegate to the two-argument `arrayify` methods, which dispatch on
# the (primal, fdata) pair. The result is deliberately not asserted to be a `Tuple{A,A}`:
# for complex `Array`s the second element is a `reinterpret`ed view of the fdata, not an `A`.
function arrayify(x::CoDual{A}) where {A<:AbstractArray{<:BlasFloat}}
    p, dp = primal(x), tangent(x)
    return arrayify(p, dp)
end
38
39
# Real element types: `dx` already is an `Array{P}`, so both arguments pass through as-is.
function arrayify(x::Array{P}, dx::Array{P}) where {P<:BlasRealFloat}
    return (x, dx)
end
40
# Complex element types: `dx` holds `Tangent` elements, so expose it with element type `P`
# through `reinterpret`. The reinterpreted array aliases `dx`, so in-place accumulation into
# it is visible in `dx`.
function arrayify(x::Array{P}, dx::Array{<:Tangent}) where {P<:BlasComplexFloat}
    dx_as_P = reinterpret(P, dx)
    return (x, dx_as_P)
end
39
43
function arrayify (x:: A , dx:: FData ) where {A<: SubArray{<:BlasRealFloat} }
40
44
_, _dx = arrayify (x. parent, dx. data. parent)
41
45
return x, A (_dx, x. indices, x. offset1, x. stride1)
@@ -299,6 +303,45 @@ function rrule!!(
299
303
return y_dy, symv!_adjoint
300
304
end
301
305
306
@is_primitive(
    MinimalCtx,
    Tuple{
        typeof(BLAS.nrm2),Int,X,Int
    } where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
)

# Reverse rule for `BLAS.nrm2(n, X, incx)`: the 2-norm of the `n` elements of `X` taken with
# stride `incx`.
#
# Forward pass: computes `y = BLAS.nrm2(n, X, incx)` and returns it with `NoFData()`.
# Pullback: accumulates `X[i] * (dy / y)` into the fdata of `X` at the strided positions
# `1:incx:(incx * n)`; the function, `n` and `incx` all receive `NoRData()`.
#
# NOTE(review): the signature also admits `Ptr{T}` primals, but no `arrayify` method for
# pointers is visible in this part of the file -- confirm one exists.
function rrule!!(
    ::CoDual{typeof(BLAS.nrm2)},
    n::CoDual{<:Integer},
    X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
    incx::CoDual{<:Integer},
)
    X, dX = arrayify(X_dX)
    y = BLAS.nrm2(n.x, X, incx.x)
    function nrm2_pb!!(dy)
        # d‖x‖/dx = x / ‖x‖ is 0/0 at x = 0. Dividing there would write NaN into `dX`, so
        # skip the update and use the zero subgradient instead.
        if !iszero(y)
            # Positions read by nrm2: 1, 1 + incx, ..., 1 + (n - 1) * incx.
            idx = 1:(incx.x):(incx.x * n.x)
            view(dX, idx) .+= view(X, idx) .* (dy / y)
        end
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return CoDual(y, NoFData()), nrm2_pb!!
end
327
+
328
@is_primitive(
    MinimalCtx,
    Tuple{typeof(BLAS.nrm2),X} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
)

# Reverse rule for the single-argument `BLAS.nrm2(X)`.
#
# Forward pass: computes `y = BLAS.nrm2(X)` and returns it with `NoFData()`.
# Pullback: accumulates `X .* (dy / y)` into the fdata of `X`; the function itself
# receives `NoRData()`.
#
# NOTE(review): the signature also admits `Ptr{T}` primals, but no `arrayify` method for
# pointers is visible in this part of the file -- confirm one exists.
function rrule!!(
    ::CoDual{typeof(BLAS.nrm2)},
    X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
)
    X, dX = arrayify(X_dX)
    y = BLAS.nrm2(X)
    function nrm2_pb!!(dy)
        # d‖x‖/dx = x / ‖x‖ is 0/0 at x = 0. Dividing there would write NaN into `dX`, so
        # skip the update and use the zero subgradient instead.
        if !iszero(y)
            # Also valid for complex eltypes: with x = a + ib, ∂y/∂a = a / y and
            # ∂y/∂b = b / y, i.e. x / y once (∂re, ∂im) is viewed as one complex number.
            # This assumes the fdata stores (∂re, ∂im) per element -- confirm.
            dX .+= X .* (dy / y)
        end
        return NoRData(), NoRData()
    end
    return CoDual(y, NoFData()), nrm2_pb!!
end
344
+
302
345
@is_primitive (
303
346
MinimalCtx,
304
347
Tuple{
@@ -755,7 +798,7 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32))
755
798
end
756
799
end
757
800
758
- function blas_matrices (rng:: AbstractRNG , P:: Type{<:BlasRealFloat } , p:: Int , q:: Int )
801
+ function blas_matrices (rng:: AbstractRNG , P:: Type{<:BlasFloat } , p:: Int , q:: Int )
759
802
Xs = Any[
760
803
randn (rng, P, p, q),
761
804
view (randn (rng, P, p + 5 , 2 q), 3 : (p + 2 ), 1 : 2 : (2 q)),
@@ -767,7 +810,7 @@ function blas_matrices(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int, q::In
767
810
return Xs
768
811
end
769
812
770
- function blas_vectors (rng:: AbstractRNG , P:: Type{<:BlasRealFloat } , p:: Int )
813
+ function blas_vectors (rng:: AbstractRNG , P:: Type{<:BlasFloat } , p:: Int )
771
814
xs = Any[
772
815
randn (rng, P, p),
773
816
view (randn (rng, P, p + 5 ), 3 : (p + 2 )),
@@ -789,6 +832,19 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
789
832
rng = rng_ctor (123456 )
790
833
791
834
test_cases = vcat (
835
+ # nrm2(x)
836
+ map_prod ([Ps... , ComplexF64, ComplexF32]) do (P,)
837
+ return map ([randn (rng, P, 105 )]) do x
838
+ (false , :none , nothing , BLAS. nrm2, x)
839
+ end
840
+ end ... ,
841
+
842
+ # nrm2(n, x, incx)
843
+ map_prod ([Ps... , ComplexF64, ComplexF32], [5 , 3 ], [1 , 2 ]) do (P, n, incx)
844
+ return map ([randn (rng, P, 105 )]) do x
845
+ (false , :none , nothing , BLAS. nrm2, n, x, incx)
846
+ end
847
+ end ... ,
792
848
793
849
# gemv!
794
850
map_prod (t_flags, [1 , 3 ], [1 , 2 ], Ps) do (tA, M, N, P)
0 commit comments