Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_copy_temp() to replace Base.copy() usage. #529

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
499064f
_copy_temp() to replace Base.copy()
AstitvaAggarwal Mar 21, 2025
b197109
function declaration after _BuiltinArrays const.
AstitvaAggarwal Mar 21, 2025
24d4ffe
additional tests, more method dispatches
AstitvaAggarwal Mar 21, 2025
2eafc8f
Faster, better, stronger
AstitvaAggarwal Mar 22, 2025
70afdbc
tests
AstitvaAggarwal Mar 22, 2025
1378553
minor fix
AstitvaAggarwal Mar 22, 2025
8d7ee73
Arrays case
AstitvaAggarwal Mar 22, 2025
57228de
closer performance to Base.copy()
AstitvaAggarwal Mar 22, 2025
89cb6ec
mutable composite types
AstitvaAggarwal Mar 23, 2025
d9fbde2
mutable composite types
AstitvaAggarwal Mar 23, 2025
b42d3d5
reduced allocations to copy call, improved testing
AstitvaAggarwal Mar 23, 2025
622c394
format
AstitvaAggarwal Mar 23, 2025
078c9d7
tests.
AstitvaAggarwal Mar 23, 2025
c4abc6f
why does CI fail?
AstitvaAggarwal Mar 23, 2025
97137cf
use ccall with _BuiltinArrays copy logic
AstitvaAggarwal Mar 23, 2025
ddc9461
using Base.deepcopy() handling of structs
AstitvaAggarwal Mar 23, 2025
3638492
format
AstitvaAggarwal Mar 23, 2025
1d59810
modify tests
AstitvaAggarwal Mar 23, 2025
f0bc196
modify testing
AstitvaAggarwal Mar 23, 2025
144d2f4
nfields zero case
AstitvaAggarwal Mar 23, 2025
a940fb9
see why CI is failing
AstitvaAggarwal Mar 23, 2025
711fb47
test why CI fails
AstitvaAggarwal Mar 23, 2025
444aa8c
.
AstitvaAggarwal Mar 23, 2025
eca35a0
testing over closures
AstitvaAggarwal Mar 23, 2025
134691c
NaN handling
AstitvaAggarwal Mar 23, 2025
502ea8d
NaN Handling
AstitvaAggarwal Mar 23, 2025
c4487ed
See why CI fails
AstitvaAggarwal Mar 24, 2025
2a968ec
Unitialized mutable structs...
AstitvaAggarwal Mar 24, 2025
02a5809
fix tests
AstitvaAggarwal Mar 24, 2025
724d841
nospecialize
AstitvaAggarwal Mar 24, 2025
3b70634
type dispatch over uninitialized Float64, Int64
AstitvaAggarwal Mar 24, 2025
8dfe0e6
Update interface.jl
AstitvaAggarwal Mar 24, 2025
c5809ff
Update interface.jl
AstitvaAggarwal Mar 24, 2025
6cae723
Update interface.jl
AstitvaAggarwal Mar 25, 2025
b796192
check CI failures
AstitvaAggarwal Mar 25, 2025
ee58a00
.
AstitvaAggarwal Mar 25, 2025
0321bda
.
AstitvaAggarwal Mar 25, 2025
1c863c9
Attention to detail!
AstitvaAggarwal Mar 25, 2025
c950e03
Prediciton : it passes all tests.
AstitvaAggarwal Mar 25, 2025
0a3fd0e
remove call from value_and_pullback!!
AstitvaAggarwal Mar 25, 2025
e2d1789
remove usage completely for once
AstitvaAggarwal Mar 25, 2025
ffddc96
Merge branch 'main' into develop
AstitvaAggarwal Mar 25, 2025
307509f
_copy_temp for value_and_pullback!! & prepare_pullback_cache
AstitvaAggarwal Mar 26, 2025
c4ce5c0
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Mooncake…
AstitvaAggarwal Mar 26, 2025
0eb390d
Merge branch 'main' into develop
AstitvaAggarwal Mar 28, 2025
d6e9f94
Merge branch 'main' into develop
yebai Mar 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
out, pb!! = rule(fx_fwds...)
@assert _typeof(tangent(out)) == fdata_type(T)
increment!!(tangent(out), fdata(ȳ))
v = y_cache === nothing ? copy(primal(out)) : _copy!!(y_cache, primal(out))
v = y_cache === nothing ? _copy_temp(primal(out)) : _copy!!(y_cache, primal(out))
return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
end

Expand Down Expand Up @@ -238,6 +238,63 @@

const _BuiltinArrays = @static VERSION >= v"1.11" ? Union{Array,Memory} : Array

# explicit for svec
function _copy_temp(x::P) where {P<:SimpleVector}
return Core.svec([map(_copy_temp, x_sub) for x_sub in x]...)

Check warning on line 243 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L242-L243

Added lines #L242 - L243 were not covered by tests
end

# Array, Memory
function _copy_temp(x::P) where {P<:_BuiltinArrays}
temp = P(undef, size(x)...)
@inbounds for i in eachindex(temp)
isassigned(x, i) && (temp[i] = _copy_temp(x[i]))
end
return temp
end

# Tuple, NamedTuple
function _copy_temp(x::P) where {P<:Union{Tuple,NamedTuple}}
return map(_copy_temp, x)

Check warning on line 257 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L256-L257

Added lines #L256 - L257 were not covered by tests
end

# mutable composite types, bitstype
function _copy_temp(x::P) where {P}
isbitstype(P) && return x
nf = nfields(P)

Check warning on line 263 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L263

Added line #L263 was not covered by tests

if ismutable(x)
temp = ccall(:jl_new_struct_uninit, Any, (Any,), P)
for x_sub in 1:nf
if isdefined(x, x_sub)
ccall(

Check warning on line 269 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L265-L269

Added lines #L265 - L269 were not covered by tests
:jl_set_nth_field,
Cvoid,
(Any, Csize_t, Any),
temp,
x_sub - 1,
_copy_temp(getfield(x, x_sub)),
)
end
end

Check warning on line 278 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L278

Added line #L278 was not covered by tests

return temp

Check warning on line 280 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L280

Added line #L280 was not covered by tests
else
flds = Vector{Any}(undef, nf)
for x_sub in 1:nf
if isdefined(x, x_sub)
flds[x_sub] = _copy_temp(getfield(x, x_sub))

Check warning on line 285 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L282-L285

Added lines #L282 - L285 were not covered by tests
else
nf = x_sub - 1 # Assumes if a undefined field is found, all subsequent fields are undefined.
break

Check warning on line 288 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L287-L288

Added lines #L287 - L288 were not covered by tests
end
end

Check warning on line 290 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L290

Added line #L290 was not covered by tests

# when immutable struct object created by non initializing inner constructor. (Base.deepcopy misses this out)
!isassigned(flds, 1) && return x
return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), P, flds, nf)

Check warning on line 294 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L293-L294

Added lines #L293 - L294 were not covered by tests
end
end

function __exclude_unsupported_output_internal!(
y::T, address_set::Set{UInt}
) where {T<:_BuiltinArrays}
Expand Down Expand Up @@ -291,7 +348,7 @@
__exclude_unsupported_output(y)

# Construct cache for output. Check that `copy!`ing appears to work.
y_cache = copy(primal(y))
y_cache = _copy_temp(primal(y))
return Cache(rule, _copy!!(y_cache, primal(y)), tangents)
end

Expand Down
56 changes: 56 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,61 @@ end
@test isa(err, Mooncake.ValueAndPullbackReturnTypeError)
end
end

# Test for new copy function
@testset for i in eachindex(additional_test_set)
original = additional_test_set[i][2]
try
if isnothing(Mooncake.__exclude_unsupported_output(original))
test_copy = Mooncake._copy_temp(original)

function comparisons(
original::P, test_copy::P
) where {P<:Mooncake._BuiltinArrays}
@test original !== test_copy
@test size(original) == size(test_copy)

# Value caching for pure immutable Types!
for i in eachindex(test_copy)
if !isassigned(test_copy, i)
@test !isassigned(original, i)
else
comparisons(original[i], test_copy[i])
end
end
end

function comparisons(original::P, test_copy::P) where {P}
(isbitstype(P) && !isnothing(original) && isnan(original)) &&
return @test isnan(test_copy)
isbitstype(P) && return @test test_copy == original

fields_copy = [
if !isdefined(test_copy, name)
nothing
else
getfield(test_copy, name)
end for name in fieldnames(typeof(test_copy))
]
fields_orig = [
!isdefined(original, name) ? nothing : getfield(original, name)
for name in fieldnames(typeof(original))
]

return comparisons(fields_orig, fields_copy)
end

@test typeof(test_copy) == typeof(original)
# isbitstypes with same values are stored in the same address (Value Caching).
if isbitstype(typeof(original))
@test test_copy == original
else
comparisons(original, test_copy)
end
end
catch err
@test isa(err, Mooncake.ValueAndPullbackReturnTypeError)
end
end
end
end
Loading