Skip to content

Commit 7d00e19

Browse files
authored
Update interface (#128)
* Update interface * Bump patch version
1 parent 0f7976c commit 7d00e19

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

.github/workflows/CI.yml

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
- 'integration_testing/array'
3232
- 'integration_testing/turing'
3333
- 'integration_testing/temporalgps'
34+
- 'interface'
3435
steps:
3536
- uses: actions/checkout@v4
3637
- uses: julia-actions/setup-julia@v1

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Tapir"
22
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/interface.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ ensure that you zero-out the tangent fields of `x` each time.
88
"""
99
function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T}
1010
out, pb!! = rule(fx...)
11-
@assert _typeof(tangent(out)) == T
12-
ty = increment!!(tangent(out), )
11+
@assert _typeof(tangent(out)) == fdata_type(T)
12+
increment!!(tangent(out), fdata(ȳ))
1313
v = copy(primal(out))
14-
return v, pb!!(ty, map(tangent, fx)...)
14+
return v, pb!!(rdata(ȳ))
1515
end
1616

1717
"""
@@ -48,7 +48,7 @@ use-case, consider pre-allocating the `CoDual`s and calling the other method of
4848
function.
4949
"""
5050
function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N}
51-
return value_and_pullback!!(rule, ȳ, map(zero_codual, fx)...)
51+
return value_and_pullback!!(rule, ȳ, map(zero_fcodual, fx)...)
5252
end
5353

5454
"""
@@ -57,5 +57,5 @@ end
5757
Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`.
5858
"""
5959
function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N}
60-
return value_and_gradient!!(rule, map(zero_codual, fx)...)
60+
return value_and_gradient!!(rule, map(zero_fcodual, fx)...)
6161
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ include("front_matter.jl")
5858
include(joinpath("integration_testing", "turing.jl"))
5959
elseif test_group == "integration_testing/temporalgps"
6060
include(joinpath("integration_testing", "temporalgps.jl"))
61+
elseif test_group == "interface"
62+
include("interface.jl")
6163
else
6264
throw(error("test_group=$(test_group) is not recognised"))
6365
end

0 commit comments

Comments
 (0)