Description
Motivation and description
Maybe this is a more general topic for MLJ, not only related to Flux. I know that autodiff has been discussed in the past and, with MLJFlux now being developed, I was wondering if this topic has come back into focus.
In an ideal world, it would be possible to differentiate through any SupervisedModel and get gradients with respect to parameters or inputs. This would, for example, greatly increase the scope of models we can explain through Counterfactual Explanations (see plans outlined here).
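To make the input-gradient use case concrete, here is a minimal sketch of a gradient-based counterfactual search in plain Julia. It assumes a hypothetical two-feature logistic classifier with made-up weights w and b, and it hand-codes the gradient of the objective -log σ(w·x' + b) + λ‖x' - x0‖² with respect to the input x'. With differentiable MLJ models, that hand-written gradient is exactly the part autodiff would supply. The names predict_prob, grad_objective and counterfactual are illustrative only, not part of MLJ or any counterfactual-explanation package.

```julia
# Logistic sigmoid.
σ(z) = 1 / (1 + exp(-z))

# Hypothetical fitted classifier: p(target class | x) = σ(w⋅x + b).
# These weights are invented for illustration; they do not come from MLJ.
w = [1.5, -2.0]
b = 0.25
predict_prob(x) = σ(w' * x + b)

# Gradient with respect to the input x′ of the counterfactual objective
#   L(x′) = -log σ(w⋅x′ + b) + λ‖x′ - x0‖²
# written out by hand; a differentiable model would let autodiff supply it.
grad_objective(xcf, x0, λ) = -(1 - predict_prob(xcf)) .* w .+ 2λ .* (xcf .- x0)

# Gradient descent on the input, stopping once the predicted class flips.
function counterfactual(x0; λ = 0.1, η = 0.1, maxiter = 1_000)
    xcf = copy(x0)
    for _ in 1:maxiter
        predict_prob(xcf) > 0.5 && break
        xcf .-= η .* grad_objective(xcf, x0, λ)
    end
    return xcf
end

x0 = [-1.0, 1.0]            # factual input; p(x0) is well below 0.5
xcf = counterfactual(x0)
println("p(x0)  = ", predict_prob(x0))
println("p(xcf) = ", predict_prob(xcf))
```

The penalty weight λ trades off how far the counterfactual moves away from x0 against how strongly the prediction is pushed towards the target class; the loop stops as soon as the predicted probability crosses 0.5.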
MLJFlux seems like a good place to start, since the underlying models are compatible with Zygote. But even here we quickly run into issues: for example, it does not seem possible to differentiate through a predict call.
An example:
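using MLJ, Flux, Random
Random.seed!(1234)
X, y = make_blobs(1000, 2, centers=2)
X = MLJ.table(Float32.(MLJ.matrix(X)))   # convert features to Float32
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier()
mach = machine(clf, X, y)
fit!(mach)
# Two different ways to return the softmax output for a 2×1 input x:
# f goes through MLJ's predict interface (table in, probabilistic predictions out) ...
f(x) = permutedims(pdf(predict(mach, MLJ.table(Float32.(x'))), levels(y)))
# ... while g calls the fitted Flux Chain stored in the machine's fitresult directly.
g(x) = mach.fitresult[1](Float32.(x))
x = rand(2,1)   # one observation as a column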
Both f and g can be used to return the softmax output for x:
julia> f(x)
2×1 Matrix{Float32}:
0.24943815
0.75056183
julia> g(x)
2×1 Matrix{Float32}:
0.24943815
0.75056183
Autodiff only works for g. Taking gradients of a cross-entropy loss succeeds with g,
loss(x, y, fun) = Flux.Losses.crossentropy(fun(x),y)
julia> gradient(loss, x, 1, g)
([2.226835250854492; 0.937971830368042;;], 6.655112266540527, nothing)
but fails with f:
julia> gradient(loss, x, 1, f)
ERROR: Zygote.CompileError
Error stack:
  (1) top-level scope
      ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
        50 quote
        51 try
      ❯ 52 $(ex)
        53 finally
        54 $task_local_state()...
  (2) top-level scope
      REPL[202]:1
  In module Core:
      [Skipped 16 frames in Zygote, Base]
  (19) (::Core.GeneratedFunctionStub)(::Any, ::Vararg)
       ./boot.jl:582
  In module Zygote:
  (20) var"#s2948#1107"(::Any, ctx::Any, f::Any, args::Any)
       ./none:0
      [Skipped 6 frames in Zygote]
  (28) error(s::String)    <- error line
       ./error.jl:35
no message for error of type Zygote.CompileError, sorry.
A simple workaround for this specific issue is to use the Chain directly to produce the softmax output, but this approach does not generalise to other MLJ models.
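The workaround succeeds because the Chain is a plain function of x, so its input gradient is well defined. For the simplest such chain, a single dense layer followed by softmax, the gradient of the cross-entropy with respect to x has the closed form Wᵀ(p - y), where p is the softmax output and y a one-hot target. The sketch below checks that formula against central finite differences in plain Julia, as a hand-checkable analogue of the input gradient Zygote returns for g. W, bias and the input are made-up values, and model, grad_x and fd_grad are hypothetical helper names. Note that it uses a one-hot target: the example above passes the scalar 1, which Flux's crossentropy broadcasts against both softmax rows rather than treating it as a class label.

```julia
# Hypothetical stand-in for a fitted chain: one dense layer followed by softmax.
# W and bias are invented values, not taken from an MLJFlux model.
W = [0.8 -0.5; -0.3 1.2]
bias = [0.1, -0.2]

# Numerically stable softmax over a vector.
function softmax(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

model(x) = softmax(W * x .+ bias)

# Cross-entropy of the softmax output against a one-hot target y.
loss(x, y) = -sum(y .* log.(model(x)))

# Closed-form input gradient: ∂loss/∂x = Wᵀ(p - y) with p = model(x).
grad_x(x, y) = W' * (model(x) .- y)

# Central finite differences, as an independent check.
function fd_grad(f, x; h = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        grad[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return grad
end

x = [0.3, 0.7]
y = [0.0, 1.0]    # one-hot target for the second class
analytic = grad_x(x, y)
numeric = fd_grad(x -> loss(x, y), x)
println("analytic: ", analytic)
println("numeric:  ", numeric)
println("max abs difference: ", maximum(abs.(analytic .- numeric)))
```

For a general MLJ model there is no such closed form to fall back on, which is why differentiating through predict itself would be valuable.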
I appreciate that this is a very ambitious idea (perhaps previous discussions have concluded that it is simply asking too much), but I would be curious to hear what others think.
It is worth mentioning that, for the plans outlined above, I will soon have support from a group of CS students. So if you have any plans or ongoing work in this space anyway, perhaps there is something we can help with.
Thanks!
Possible Implementation
No response