Incorrect gradient when diff path goes through a for loop #126

@dfdx

From FluxML/NNlib.jl#434 (comment):

using Yota, ChainRulesCore
using Yota: GradCtx      # Yota's gradient-tracing context
using Umlaut: trace      # Yota records its tape via Umlaut's tracer

# Product with an early exit: once a zero is seen, the result is known.
function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

# eltype carries no derivative information.
ChainRulesCore.@non_differentiable eltype(::Any)

function main()
    x = rand(3)
    Yota.grad(prod2, x)                       # gradient comes out wrong here
    _, tape = trace(prod2, x; ctx=GradCtx())  # inspect the recorded tape
end
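
For reference, the analytic gradient of the product is prod(xs)/xs[i] in each coordinate (for nonzero entries), so a wrong answer is easy to spot. Below is a minimal check sketch, not part of the original report; it assumes Yota's convention that grad returns (value, gradients) with gradients[2] holding the gradient w.r.t. the first argument, and the helper name check_grad is made up here:

function check_grad(xs::Vector)
    _, gs = Yota.grad(prod2, xs)           # gs[1] is for prod2 itself, gs[2] for xs
    expected = [prod(xs) / x for x in xs]  # analytic gradient of the product
    return gs[2], expected                 # per this report, these disagree
end

check_grad(rand(3))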
