@@ -115,6 +115,45 @@ for (fname, elty) in ((:cblas_ddot, :Float64), (:cblas_sdot, :Float32))
115
115
end
116
116
end
117
117
118
+ @is_primitive (
119
+ MinimalCtx,
120
+ Tuple{
121
+ typeof (BLAS. nrm2),Int,X,Int
122
+ } where {T<: BlasFloat ,X<: Union{Ptr{T},AbstractArray{T}} },
123
+ )
124
+ function rrule!! (
125
+ :: CoDual{typeof(BLAS.nrm2)} ,
126
+ n:: CoDual{<:Integer} ,
127
+ X_dX:: CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}} ,
128
+ incx:: CoDual{<:Integer} ,
129
+ )
130
+ X, dX = arrayify (X_dX)
131
+ y = BLAS. nrm2 (n. x, X, incx. x)
132
+ function nrm2_pb!! (dy)
133
+ view (dX, 1 : (incx. x): (incx. x * n. x)) .+ =
134
+ view (X, 1 : (incx. x): (incx. x * n. x)) .* (dy / y)
135
+ return NoRData (), NoRData (), NoRData (), NoRData ()
136
+ end
137
+ return CoDual (y, NoFData ()), nrm2_pb!!
138
+ end
139
+
140
+ @is_primitive (
141
+ MinimalCtx,
142
+ Tuple{typeof (BLAS. nrm2),X} where {T<: BlasFloat ,X<: Union{Ptr{T},AbstractArray{T}} },
143
+ )
144
+ function rrule!! (
145
+ :: CoDual{typeof(BLAS.nrm2)} ,
146
+ X_dX:: CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}} ,
147
+ )
148
+ X, dX = arrayify (X_dX)
149
+ y = BLAS. nrm2 (X)
150
+ function nrm2_pb!! (dy)
151
+ dX .+ = X .* (dy / y)
152
+ return NoRData (), NoRData ()
153
+ end
154
+ return CoDual (y, NoFData ()), nrm2_pb!!
155
+ end
156
+
118
157
for (fname, elty) in ((:dscal_ , :Float64 ), (:sscal_ , :Float32 ))
119
158
@eval @inline function Mooncake. rrule!! (
120
159
:: CoDual{typeof(_foreigncall_)} ,
@@ -303,45 +342,6 @@ function rrule!!(
303
342
return y_dy, symv!_adjoint
304
343
end
305
344
306
- @is_primitive (
307
- MinimalCtx,
308
- Tuple{
309
- typeof (BLAS. nrm2),Int,X,Int
310
- } where {T<: BlasFloat ,X<: Union{Ptr{T},AbstractArray{T}} },
311
- )
312
- function rrule!! (
313
- :: CoDual{typeof(BLAS.nrm2)} ,
314
- n:: CoDual{<:Integer} ,
315
- X_dX:: CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}} ,
316
- incx:: CoDual{<:Integer} ,
317
- )
318
- X, dX = arrayify (X_dX)
319
- y = BLAS. nrm2 (n. x, X, incx. x)
320
- function nrm2_pb!! (dy)
321
- view (dX, 1 : (incx. x): (incx. x * n. x)) .+ =
322
- view (X, 1 : (incx. x): (incx. x * n. x)) .* (dy / y)
323
- return NoRData (), NoRData (), NoRData (), NoRData ()
324
- end
325
- return CoDual (y, NoFData ()), nrm2_pb!!
326
- end
327
-
328
- @is_primitive (
329
- MinimalCtx,
330
- Tuple{typeof (BLAS. nrm2),X} where {T<: BlasFloat ,X<: Union{Ptr{T},AbstractArray{T}} },
331
- )
332
- function rrule!! (
333
- :: CoDual{typeof(BLAS.nrm2)} ,
334
- X_dX:: CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}} ,
335
- )
336
- X, dX = arrayify (X_dX)
337
- y = BLAS. nrm2 (X)
338
- function nrm2_pb!! (dy)
339
- dX .+ = X .* (dy / y) # TODO : verify for complex numbers
340
- return NoRData (), NoRData ()
341
- end
342
- return CoDual (y, NoFData ()), nrm2_pb!!
343
- end
344
-
345
345
@is_primitive (
346
346
MinimalCtx,
347
347
Tuple{
0 commit comments