1
1
2
- export derivative, derivative!
2
+ export derivative, derivative!, derivatives, make_seed
3
3
4
4
"""
5
- derivative(f, x, order::Int64)
6
- derivative(f, x, l, order::Int64)
7
-
8
- Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
9
-
10
- derivative(f, x::T, ::Val{N})
11
-
12
- Computes `order`-th derivative of `f` w.r.t. scalar `x`.
13
-
14
- derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
5
+ derivative(f, x, l, ::Val{N})
6
+ derivative(f!, y, x, l, ::Val{N})
15
7
16
8
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
17
-
18
- derivative(f, x::AbstractMatrix{T}, ::Val{N})
19
- derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
20
-
21
- Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
22
- - For a M-by-N matrix, calculate the directional derivative for each column.
23
- - For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
24
9
"""
25
10
function derivative end
26
11
@@ -32,54 +17,58 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
32
17
"""
33
18
function derivative! end
34
19
20
+ """
21
+ derivatives(f, x, l, ::Val{N})
22
+ derivatives(f!, y, x, l, ::Val{N})
23
+
24
+ Computes all derivatives of `f` at `x` up to order `N - 1`.
25
+ """
26
+ function derivatives end
27
+
28
+ # Convenience wrapper for adding unit seed to the input
29
+
30
+ @inline derivative (f, x, order:: Int64 ) = derivative (f, x, one (eltype (x)), order)
31
+
35
32
# Convenience wrappers for converting orders to value types
36
33
# and forward work to core APIs
37
34
38
- @inline derivative (f, x, order:: Int64 ) = derivative (f, x, one (eltype (x)), order)
39
35
@inline derivative (f, x, l, order:: Int64 ) = derivative (f, x, l, Val {order + 1} ())
36
+ @inline derivative (f!, y, x, l, order:: Int64 ) = derivative (f!, y, x, l, Val {order + 1} ())
37
+ @inline derivative! (result, f, x, l, order:: Int64 ) = derivative! (
38
+ result, f, x, l, Val {order + 1} ())
39
+ @inline derivative! (result, f!, y, x, l, order:: Int64 ) = derivative! (
40
+ result, f!, y, x, l, Val {order + 1} ())
40
41
41
42
# Core APIs
42
43
43
44
# Added to help Zygote infer types
44
- @inline function make_taylor (x:: T , l:: S , :: Val{N} ) where {T <: TN , S <: TN , N}
45
+ @inline function make_seed (x:: T , l:: S , :: Val{N} ) where {T <: TN , S <: TN , N}
45
46
TaylorScalar {T, N} (x, convert (T, l))
46
47
end
47
48
48
- @inline function make_taylor (x:: AbstractArray{T} , l, vN:: Val{N} ) where {T <: TN , N}
49
- broadcast (make_taylor, x, l, vN)
50
- end
51
-
52
- # Out-of-place function, out-of-place derivative
53
- @inline function derivative (f, x, l, vN:: Val{N} ) where {N}
54
- t = make_taylor (x, l, vN)
55
- return extract_derivative (f (t), N)
56
- end
57
-
58
- # Below three advanced APIs do not have convenience wrappers
59
-
60
- # In-place function, out-of-place derivative
61
- @inline function derivative (f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
62
- s = similar (y, TaylorScalar{T, N})
63
- t = make_taylor (x, l, vN)
64
- f! (s, t)
65
- map! (primal, y, s)
66
- return extract_derivative (s, N)
67
- end
68
-
69
- # Out-of-place function, in-place derivative
70
- @inline function derivative! (result, f, x, l, vN:: Val{N} ) where {N}
71
- t = make_taylor (x, l, vN)
72
- s = f (t)
73
- extract_derivative! (result, s, N)
74
- return result
49
+ @inline function make_seed (x:: AbstractArray{T} , l, vN:: Val{N} ) where {T <: TN , N}
50
+ broadcast (make_seed, x, l, vN)
75
51
end
76
52
77
- # In-place function, in-place derivative
78
- @inline function derivative! (result, f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
79
- s = similar (y, TaylorScalar{T, N})
80
- t = make_taylor (x, l, vN)
81
- f! (s, t)
82
- map! (primal, y, s)
83
- extract_derivative! (result, s, N)
84
- return result
53
+ # `derivative` API: computes the `N - 1`-th derivative of `f` at `x`
54
+ @inline derivative (f, x, l, vN:: Val{N} ) where {N} = extract_derivative (
55
+ derivatives (f, x, l, vN), N)
56
+ @inline derivative (f!, y, x, l, vN:: Val{N} ) where {N} = extract_derivative (
57
+ derivatives (f!, y, x, l, vN), N)
58
+ @inline derivative! (result, f, x, l, vN:: Val{N} ) where {N} = extract_derivative! (
59
+ result, derivatives (f, x, l, vN), N)
60
+ @inline derivative! (result, f!, y, x, l, vN:: Val{N} ) where {N} = extract_derivative! (
61
+ result, derivatives (f!, y, x, l, vN), N)
62
+
63
+ # `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1`
64
+
65
+ # Out-of-place function
66
+ @inline derivatives (f, x, l, vN:: Val{N} ) where {N} = f (make_seed (x, l, vN))
67
+
68
+ # In-place function
69
+ @inline function derivatives (f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
70
+ buffer = similar (y, TaylorScalar{T, N})
71
+ f! (buffer, make_seed (x, l, vN))
72
+ map! (primal, y, buffer)
73
+ return buffer
85
74
end
0 commit comments