Skip to content

_copy_output() to replace Base.copy() usage. #529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 52 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
499064f
_copy_temp() to replace Base.copy()
AstitvaAggarwal Mar 21, 2025
b197109
function declaration after _BuiltinArrays const.
AstitvaAggarwal Mar 21, 2025
24d4ffe
additional tests, more method dispatches
AstitvaAggarwal Mar 21, 2025
2eafc8f
Faster, better, stronger
AstitvaAggarwal Mar 22, 2025
70afdbc
tests
AstitvaAggarwal Mar 22, 2025
1378553
minor fix
AstitvaAggarwal Mar 22, 2025
8d7ee73
Arrays case
AstitvaAggarwal Mar 22, 2025
57228de
closer performance to Base.copy()
AstitvaAggarwal Mar 22, 2025
89cb6ec
mutable composite types
AstitvaAggarwal Mar 23, 2025
d9fbde2
mutable composite types
AstitvaAggarwal Mar 23, 2025
b42d3d5
reduced allocations to copy call, improved testing
AstitvaAggarwal Mar 23, 2025
622c394
format
AstitvaAggarwal Mar 23, 2025
078c9d7
tests.
AstitvaAggarwal Mar 23, 2025
c4abc6f
why does CI fail?
AstitvaAggarwal Mar 23, 2025
97137cf
use ccall with _BuiltinArrays copy logic
AstitvaAggarwal Mar 23, 2025
ddc9461
using Base.deepcopy() handling of structs
AstitvaAggarwal Mar 23, 2025
3638492
format
AstitvaAggarwal Mar 23, 2025
1d59810
modify tests
AstitvaAggarwal Mar 23, 2025
f0bc196
modify testing
AstitvaAggarwal Mar 23, 2025
144d2f4
nfields zero case
AstitvaAggarwal Mar 23, 2025
a940fb9
see why CI is failing
AstitvaAggarwal Mar 23, 2025
711fb47
test why CI fails
AstitvaAggarwal Mar 23, 2025
444aa8c
.
AstitvaAggarwal Mar 23, 2025
eca35a0
testing over closures
AstitvaAggarwal Mar 23, 2025
134691c
NaN handling
AstitvaAggarwal Mar 23, 2025
502ea8d
NaN Handling
AstitvaAggarwal Mar 23, 2025
c4487ed
See why CI fails
AstitvaAggarwal Mar 24, 2025
2a968ec
Unitialized mutable structs...
AstitvaAggarwal Mar 24, 2025
02a5809
fix tests
AstitvaAggarwal Mar 24, 2025
724d841
nospecialize
AstitvaAggarwal Mar 24, 2025
3b70634
type dispatch over uninitialized Float64, Int64
AstitvaAggarwal Mar 24, 2025
8dfe0e6
Update interface.jl
AstitvaAggarwal Mar 24, 2025
c5809ff
Update interface.jl
AstitvaAggarwal Mar 24, 2025
6cae723
Update interface.jl
AstitvaAggarwal Mar 25, 2025
b796192
check CI failures
AstitvaAggarwal Mar 25, 2025
ee58a00
.
AstitvaAggarwal Mar 25, 2025
0321bda
.
AstitvaAggarwal Mar 25, 2025
1c863c9
Attention to detail!
AstitvaAggarwal Mar 25, 2025
c950e03
Prediciton : it passes all tests.
AstitvaAggarwal Mar 25, 2025
0a3fd0e
remove call from value_and_pullback!!
AstitvaAggarwal Mar 25, 2025
e2d1789
remove usage completely for once
AstitvaAggarwal Mar 25, 2025
ffddc96
Merge branch 'main' into develop
AstitvaAggarwal Mar 25, 2025
307509f
_copy_temp for value_and_pullback!! & prepare_pullback_cache
AstitvaAggarwal Mar 26, 2025
c4ce5c0
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Mooncake…
AstitvaAggarwal Mar 26, 2025
0eb390d
Merge branch 'main' into develop
AstitvaAggarwal Mar 28, 2025
d6e9f94
Merge branch 'main' into develop
yebai Mar 30, 2025
d2df49f
changes from reviews
AstitvaAggarwal Apr 1, 2025
a7a482f
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Mooncake…
AstitvaAggarwal Apr 1, 2025
dff9c21
version update
AstitvaAggarwal Apr 1, 2025
94c261f
rename to _copy_output
AstitvaAggarwal Apr 2, 2025
edccbed
Merge branch 'main' into develop
AstitvaAggarwal Apr 2, 2025
6b682f5
version change
AstitvaAggarwal Apr 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.114"
version = "0.4.115"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
57 changes: 55 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
out, pb!! = rule(fx_fwds...)
@assert _typeof(tangent(out)) == fdata_type(T)
increment!!(tangent(out), fdata(ȳ))
v = y_cache === nothing ? copy(primal(out)) : _copy!!(y_cache, primal(out))
v = y_cache === nothing ? _copy_output(primal(out)) : _copy!!(y_cache, primal(out))
return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
end

Expand Down Expand Up @@ -238,6 +238,59 @@

const _BuiltinArrays = @static VERSION >= v"1.11" ? Union{Array,Memory} : Array

# explicit for svec
_copy_output(x::SimpleVector) = Core.svec([map(_copy_output, x_sub) for x_sub in x]...)

Check warning on line 242 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L242

Added line #L242 was not covered by tests

# Array, Memory
function _copy_output(x::P) where {P<:_BuiltinArrays}
temp = P(undef, size(x)...)
@inbounds for i in eachindex(temp)
isassigned(x, i) && (temp[i] = _copy_output(x[i]))
end
return temp
end

# Tuple, NamedTuple
_copy_output(x::Union{Tuple,NamedTuple}) = map(_copy_output, x)

Check warning on line 254 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L254

Added line #L254 was not covered by tests

# mutable composite types, bitstype
function _copy_output(x::P) where {P}
isbitstype(P) && return x
nf = nfields(P)

Check warning on line 259 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L259

Added line #L259 was not covered by tests

if ismutable(x)
temp = ccall(:jl_new_struct_uninit, Any, (Any,), P)
for x_sub in 1:nf
if isdefined(x, x_sub)
ccall(

Check warning on line 265 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L261-L265

Added lines #L261 - L265 were not covered by tests
:jl_set_nth_field,
Cvoid,
(Any, Csize_t, Any),
temp,
x_sub - 1,
_copy_output(getfield(x, x_sub)),
)
end
end

Check warning on line 274 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L274

Added line #L274 was not covered by tests

return temp

Check warning on line 276 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L276

Added line #L276 was not covered by tests
else
flds = Vector{Any}(undef, nf)
for x_sub in 1:nf
if isdefined(x, x_sub)
flds[x_sub] = _copy_output(getfield(x, x_sub))

Check warning on line 281 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L278-L281

Added lines #L278 - L281 were not covered by tests
else
nf = x_sub - 1 # Assumes if a undefined field is found, all subsequent fields are undefined.
break

Check warning on line 284 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L283-L284

Added lines #L283 - L284 were not covered by tests
end
end

Check warning on line 286 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L286

Added line #L286 was not covered by tests

# when immutable struct object created by non initializing inner constructor. (Base.deepcopy misses this out)
!isassigned(flds, 1) && return x
return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), P, flds, nf)

Check warning on line 290 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L289-L290

Added lines #L289 - L290 were not covered by tests
end
end

function __exclude_unsupported_output_internal!(
y::T, address_set::Set{UInt}
) where {T<:_BuiltinArrays}
Expand Down Expand Up @@ -291,7 +344,7 @@
__exclude_unsupported_output(y)

# Construct cache for output. Check that `copy!`ing appears to work.
y_cache = copy(primal(y))
y_cache = _copy_output(primal(y))
return Cache(rule, _copy!!(y_cache, primal(y)), tangents)
end

Expand Down
19 changes: 17 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,24 @@ end

additional_test_set = Mooncake.tangent_test_cases()

@testset for i in eachindex(additional_test_set)
@testset "__exclude_unsupported_output , $(test_set)" for test_set in
additional_test_set
try
Mooncake.__exclude_unsupported_output(additional_test_set[i][2])
Mooncake.__exclude_unsupported_output(test_set[2])
catch err
@test isa(err, Mooncake.ValueAndPullbackReturnTypeError)
end
end

@testset "_copy_output , $(test_set)" for test_set in additional_test_set
original = test_set[2]
try
if isnothing(Mooncake.__exclude_unsupported_output(original))
test_copy = Mooncake._copy_output(original)

@test Mooncake.TestUtils.has_equal_data(original, test_copy)
@test typeof(test_copy) == typeof(original)
end
catch err
@test isa(err, Mooncake.ValueAndPullbackReturnTypeError)
end
Expand Down