From eb5934a443ccea5b9052a08149f1df0990abcc23 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 4 Jan 2025 00:59:32 -0500 Subject: [PATCH] =?UTF-8?q?More=20unthunking=20in=20`=E2=88=87chunk`=20(#1?= =?UTF-8?q?80)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update utils.jl * version 0.4.5 --- Project.toml | 2 +- src/utils.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 74a4ad7..6b3292e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.4.4" +version = "0.4.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/utils.jl b/src/utils.jl index f7a4afa..b9c25b3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -237,7 +237,8 @@ end @non_differentiable _partition_idxs(::Any...) # Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77 -function ∇chunk(dys, x, idxs, vd::Val{dim}) where {dim} +function ∇chunk(dys_raw, x, idxs, vd::Val{dim}) where {dim} + dys = unthunk.(unthunk(dys_raw)) # https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-2569227272 i1 = findfirst(dy -> !(dy isa AbstractZero), dys) if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x))))