Skip to content

Commit 75785aa

Browse files
authored
Shuffle Code Around (#497)
* Move nrm2 to level 1 and remove redundant comment * Bump patch again
1 parent 142f687 commit 75785aa

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.96"
4+
version = "0.4.97"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/rrules/blas.jl

+39-39
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,45 @@ for (fname, elty) in ((:cblas_ddot, :Float64), (:cblas_sdot, :Float32))
115115
end
116116
end
117117

118+
@is_primitive(
119+
MinimalCtx,
120+
Tuple{
121+
typeof(BLAS.nrm2),Int,X,Int
122+
} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
123+
)
124+
function rrule!!(
125+
::CoDual{typeof(BLAS.nrm2)},
126+
n::CoDual{<:Integer},
127+
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
128+
incx::CoDual{<:Integer},
129+
)
130+
X, dX = arrayify(X_dX)
131+
y = BLAS.nrm2(n.x, X, incx.x)
132+
function nrm2_pb!!(dy)
133+
view(dX, 1:(incx.x):(incx.x * n.x)) .+=
134+
view(X, 1:(incx.x):(incx.x * n.x)) .* (dy / y)
135+
return NoRData(), NoRData(), NoRData(), NoRData()
136+
end
137+
return CoDual(y, NoFData()), nrm2_pb!!
138+
end
139+
140+
@is_primitive(
141+
MinimalCtx,
142+
Tuple{typeof(BLAS.nrm2),X} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
143+
)
144+
function rrule!!(
145+
::CoDual{typeof(BLAS.nrm2)},
146+
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
147+
)
148+
X, dX = arrayify(X_dX)
149+
y = BLAS.nrm2(X)
150+
function nrm2_pb!!(dy)
151+
dX .+= X .* (dy / y)
152+
return NoRData(), NoRData()
153+
end
154+
return CoDual(y, NoFData()), nrm2_pb!!
155+
end
156+
118157
for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32))
119158
@eval @inline function Mooncake.rrule!!(
120159
::CoDual{typeof(_foreigncall_)},
@@ -303,45 +342,6 @@ function rrule!!(
303342
return y_dy, symv!_adjoint
304343
end
305344

306-
@is_primitive(
307-
MinimalCtx,
308-
Tuple{
309-
typeof(BLAS.nrm2),Int,X,Int
310-
} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
311-
)
312-
function rrule!!(
313-
::CoDual{typeof(BLAS.nrm2)},
314-
n::CoDual{<:Integer},
315-
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
316-
incx::CoDual{<:Integer},
317-
)
318-
X, dX = arrayify(X_dX)
319-
y = BLAS.nrm2(n.x, X, incx.x)
320-
function nrm2_pb!!(dy)
321-
view(dX, 1:(incx.x):(incx.x * n.x)) .+=
322-
view(X, 1:(incx.x):(incx.x * n.x)) .* (dy / y)
323-
return NoRData(), NoRData(), NoRData(), NoRData()
324-
end
325-
return CoDual(y, NoFData()), nrm2_pb!!
326-
end
327-
328-
@is_primitive(
329-
MinimalCtx,
330-
Tuple{typeof(BLAS.nrm2),X} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}},
331-
)
332-
function rrule!!(
333-
::CoDual{typeof(BLAS.nrm2)},
334-
X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}},
335-
)
336-
X, dX = arrayify(X_dX)
337-
y = BLAS.nrm2(X)
338-
function nrm2_pb!!(dy)
339-
dX .+= X .* (dy / y) # TODO: verify for complex numbers
340-
return NoRData(), NoRData()
341-
end
342-
return CoDual(y, NoFData()), nrm2_pb!!
343-
end
344-
345345
@is_primitive(
346346
MinimalCtx,
347347
Tuple{

0 commit comments

Comments
 (0)