Skip to content

Commit c6dc209

Browse files
authored
WIP vector_mode_dual_eval -> vector_mode_dual_eval! (#528)
1 parent 80a13d9 commit c6dc209

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

src/apiutils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ end
3131

3232
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
3333

34-
function vector_mode_dual_eval(f::F, x, cfg::Union{JacobianConfig,GradientConfig}) where {F}
34+
function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig}, x) where {F}
3535
xdual = cfg.duals
3636
seed!(xdual, x, cfg.seeds)
3737
return f(xdual)
3838
end
3939

40-
function vector_mode_dual_eval(f!::F, y, x, cfg::JacobianConfig) where {F}
40+
function vector_mode_dual_eval!(f!::F, cfg::JacobianConfig, y, x) where {F}
4141
ydual, xdual = cfg.duals
4242
seed!(xdual, x, cfg.seeds)
4343
seed!(ydual, y)

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,14 @@ const GRAD_ERROR = DimensionMismatch("gradient(f, x) expects that f(x) is a real
103103
###############
104104

105105
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
106-
ydual = vector_mode_dual_eval(f, x, cfg)
106+
ydual = vector_mode_dual_eval!(f, cfg, x)
107107
ydual isa Real || throw(GRAD_ERROR)
108108
result = similar(x, valtype(ydual))
109109
return extract_gradient!(T, result, ydual)
110110
end
111111

112112
function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T, F}
113-
ydual = vector_mode_dual_eval(f, x, cfg)
113+
ydual = vector_mode_dual_eval!(f, cfg, x)
114114
result = extract_gradient!(T, result, ydual)
115115
return result
116116
end

src/jacobian.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ reshape_jacobian(result::DiffResult, ydual, xdual) = reshape_jacobian(DiffResult
144144
###############
145145

146146
function vector_mode_jacobian(f::F, x, cfg::JacobianConfig{T,V,N}) where {F,T,V,N}
147-
ydual = vector_mode_dual_eval(f, x, cfg)
147+
ydual = vector_mode_dual_eval!(f, cfg, x)
148148
ydual isa AbstractArray || throw(JACOBIAN_ERROR)
149149
result = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
150150
extract_jacobian!(T, result, ydual, N)
@@ -153,7 +153,7 @@ function vector_mode_jacobian(f::F, x, cfg::JacobianConfig{T,V,N}) where {F,T,V,
153153
end
154154

155155
function vector_mode_jacobian(f!::F, y, x, cfg::JacobianConfig{T,V,N}) where {F,T,V,N}
156-
ydual = vector_mode_dual_eval(f!, y, x, cfg)
156+
ydual = vector_mode_dual_eval!(f!, cfg, y, x)
157157
map!(d -> value(T,d), y, ydual)
158158
result = similar(y, length(y), N)
159159
extract_jacobian!(T, result, ydual, N)
@@ -162,14 +162,14 @@ function vector_mode_jacobian(f!::F, y, x, cfg::JacobianConfig{T,V,N}) where {F,
162162
end
163163

164164
function vector_mode_jacobian!(result, f::F, x, cfg::JacobianConfig{T,V,N}) where {F,T,V,N}
165-
ydual = vector_mode_dual_eval(f, x, cfg)
165+
ydual = vector_mode_dual_eval!(f, cfg, x)
166166
extract_jacobian!(T, result, ydual, N)
167167
extract_value!(T, result, ydual)
168168
return result
169169
end
170170

171171
function vector_mode_jacobian!(result, f!::F, y, x, cfg::JacobianConfig{T,V,N}) where {F,T,V,N}
172-
ydual = vector_mode_dual_eval(f!, y, x, cfg)
172+
ydual = vector_mode_dual_eval!(f!, cfg, y, x)
173173
map!(d -> value(T,d), y, ydual)
174174
extract_jacobian!(T, result, ydual, N)
175175
extract_value!(T, result, y, ydual)

0 commit comments

Comments
 (0)