Skip to content

Commit 3770f8b

Browse files
committed
Add derivatives API
1 parent 39d0c76 commit 3770f8b

File tree

1 file changed

+44
-55
lines changed

1 file changed

+44
-55
lines changed

src/derivative.jl

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,11 @@
11

2-
export derivative, derivative!
2+
export derivative, derivative!, derivatives, make_seed
33

44
"""
5-
derivative(f, x, order::Int64)
6-
derivative(f, x, l, order::Int64)
7-
8-
Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
9-
10-
derivative(f, x::T, ::Val{N})
11-
12-
Computes `order`-th derivative of `f` w.r.t. scalar `x`.
13-
14-
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
5+
derivative(f, x, l, ::Val{N})
6+
derivative(f!, y, x, l, ::Val{N})
157
168
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
17-
18-
derivative(f, x::AbstractMatrix{T}, ::Val{N})
19-
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
20-
21-
Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
22-
- For a M-by-N matrix, calculate the directional derivative for each column.
23-
- For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
249
"""
2510
function derivative end
2611

@@ -32,54 +17,58 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
3217
"""
3318
function derivative! end
3419

20+
"""
21+
derivatives(f, x, l, ::Val{N})
22+
derivatives(f!, y, x, l, ::Val{N})
23+
24+
Computes all derivatives of `f` at `x` up to order `N - 1`.
25+
"""
26+
function derivatives end
27+
28+
# Convenience wrapper for adding unit seed to the input
29+
30+
@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)
31+
3532
# Convenience wrappers for converting orders to value types
3633
# and forward work to core APIs
3734

38-
@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)
3935
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}())
36+
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order + 1}())
37+
@inline derivative!(result, f, x, l, order::Int64) = derivative!(
38+
result, f, x, l, Val{order + 1}())
39+
@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!(
40+
result, f!, y, x, l, Val{order + 1}())
4041

4142
# Core APIs
4243

4344
# Added to help Zygote infer types
44-
@inline function make_taylor(x::T, l::S, ::Val{N}) where {T <: TN, S <: TN, N}
45+
@inline function make_seed(x::T, l::S, ::Val{N}) where {T <: TN, S <: TN, N}
4546
TaylorScalar{T, N}(x, convert(T, l))
4647
end
4748

48-
@inline function make_taylor(x::AbstractArray{T}, l, vN::Val{N}) where {T <: TN, N}
49-
broadcast(make_taylor, x, l, vN)
50-
end
51-
52-
# Out-of-place function, out-of-place derivative
53-
@inline function derivative(f, x, l, vN::Val{N}) where {N}
54-
t = make_taylor(x, l, vN)
55-
return extract_derivative(f(t), N)
56-
end
57-
58-
# Below three advanced APIs do not have convenience wrappers
59-
60-
# In-place function, out-of-place derivative
61-
@inline function derivative(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
62-
s = similar(y, TaylorScalar{T, N})
63-
t = make_taylor(x, l, vN)
64-
f!(s, t)
65-
map!(primal, y, s)
66-
return extract_derivative(s, N)
67-
end
68-
69-
# Out-of-place function, in-place derivative
70-
@inline function derivative!(result, f, x, l, vN::Val{N}) where {N}
71-
t = make_taylor(x, l, vN)
72-
s = f(t)
73-
extract_derivative!(result, s, N)
74-
return result
49+
@inline function make_seed(x::AbstractArray{T}, l, vN::Val{N}) where {T <: TN, N}
50+
broadcast(make_seed, x, l, vN)
7551
end
7652

77-
# In-place function, in-place derivative
78-
@inline function derivative!(result, f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
79-
s = similar(y, TaylorScalar{T, N})
80-
t = make_taylor(x, l, vN)
81-
f!(s, t)
82-
map!(primal, y, s)
83-
extract_derivative!(result, s, N)
84-
return result
53+
# `derivative` API: computes the `N - 1`-th derivative of `f` at `x`
54+
@inline derivative(f, x, l, vN::Val{N}) where {N} = extract_derivative(
55+
derivatives(f, x, l, vN), N)
56+
@inline derivative(f!, y, x, l, vN::Val{N}) where {N} = extract_derivative(
57+
derivatives(f!, y, x, l, vN), N)
58+
@inline derivative!(result, f, x, l, vN::Val{N}) where {N} = extract_derivative!(
59+
result, derivatives(f, x, l, vN), N)
60+
@inline derivative!(result, f!, y, x, l, vN::Val{N}) where {N} = extract_derivative!(
61+
result, derivatives(f!, y, x, l, vN), N)
62+
63+
# `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1`
64+
65+
# Out-of-place function
66+
@inline derivatives(f, x, l, vN::Val{N}) where {N} = f(make_seed(x, l, vN))
67+
68+
# In-place function
69+
@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
70+
buffer = similar(y, TaylorScalar{T, N})
71+
f!(buffer, make_seed(x, l, vN))
72+
map!(primal, y, buffer)
73+
return buffer
8574
end

0 commit comments

Comments
 (0)