Add support for (vision) transformers, add options to set last-layer relevance #15
Conversation
add canonization for SkipConnection layers; fix model splitting edge …
Hotfix/canonize
Codecov Report
Attention: Patch coverage is …

```
@@            Coverage Diff             @@
##              main      #15       +/-  ##
===========================================
- Coverage   96.66%    0.00%    -96.67%
===========================================
  Files          14       15         +1
  Lines         660      698        +38
===========================================
- Hits          638        0       -638
- Misses         22      698       +676
```
Maybe move prepare_vit to canonize.
src/extensions.jl (outdated)
```julia
struct SelectClassToken end
Flux.@functor SelectClassToken
```
XAIBase exports generic feature selectors.
Maybe these could be used here and extended for transformers?
We can do that, but I don't get how these feature selectors are supposed to be used in a model, or why there are no rules for them?
Okay, as we discussed, it does not really make sense to use the feature selectors. I think the remaining question is where you want to define the new layers in the codebase - maybe an extra file src/layers.jl?
> Okay, as we discussed, it does not really make sense to use the feature selectors.
Sorry, it's been a while... Can you remind me what the exact issue was? 😅
I can vaguely remember it was something that should go in XAIBase.jl.
Similar to this: https://github.com/Julia-XAI/XAIBase.jl/blob/main/src/feature_selection.jl
Yes, so vision transformers have a special token that is selected near the output, and all other tokens are discarded. They implement this in Metalhead through an anonymous function, so we can't use it for computing LRP. What I did was implement this simple Flux layer (and an associated rule) that is swapped in for the anonymous function. The problem with the feature selector is that we need an actual layer, so I think we decided that this is probably not the right place ^^
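For reference, a minimal sketch of how such a layer could select the class token. The dimension convention (features × tokens × batch, class token first) is an assumption here; the PR's actual forward pass may differ:

```julia
using Flux

struct SelectClassToken end
Flux.@functor SelectClassToken

# Keep only the class token, assumed to be the first entry along the
# token dimension of a (features, tokens, batch) array; all other
# tokens are discarded for the final prediction.
(::SelectClassToken)(x::AbstractArray{T,3}) where {T} = x[:, 1, :]
```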
Vision Transformers

Add explanations for (vision) transformer (ViT) models by adding the package extension VisionTransformerExt that depends on Metalhead.jl.
Adds the rules:
- SelfAttentionRule for MultiHeadSelfAttention layers
- PositionalEmbeddingRule for ViPosEmbedding layers
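As an illustration, these rules could be attached to the corresponding layer types via a composite. This is a hedged sketch: the package name, the Composite/GlobalTypeMap API, and Metalhead's Layers module path are assumptions:

```julia
using Flux, Metalhead
using RelevancePropagation  # package name assumed

# Map the new rules onto the transformer layer types
# (Composite/GlobalTypeMap API and module paths are assumptions):
composite = Composite(
    GlobalTypeMap(
        Metalhead.Layers.MultiHeadSelfAttention => SelfAttentionRule(),
        Metalhead.Layers.ViPosEmbedding => PositionalEmbeddingRule(),
    ),
)
```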
Also adds support for some special layers of vision transformers by adding a method for ZeroRule:
- _flatten_spatial is a reshaping layer near the input
- ClassTokens adds a class token to the model
- SelectClassToken only retains the class token for the model prediction (this layer was added by me because Metalhead uses an anonymous function for this purpose, which we have to swap for a "real" layer before explaining the model)

So far, no support for Flux.jl's built-in MultiHeadAttention layer was added, because this layer does not work nicely with Chains. (Metalhead's MultiHeadSelfAttention layer is not limited to vision transformers, but can also be used to build "regular" transformer models as long as they use only self-attention, e.g. encoder-only models; something like BERT should be doable.)
In addition, the function prepare_vit can be used to prepare Metalhead's ViT (convert it to a Chain and add the SelectClassToken layer).
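Putting the pieces together, end-to-end usage might look like the sketch below; the ViT constructor arguments, the input shape, and the analyze entry point from XAIBase are assumptions:

```julia
using Flux, Metalhead

model = prepare_vit(Metalhead.ViT(:tiny))  # convert to a Chain, swap in SelectClassToken
analyzer = LRP(model, composite)           # composite from the sketch above
input = rand(Float32, 224, 224, 3, 1)      # dummy image batch (W, H, C, N)
expl = analyze(input, analyzer)            # XAIBase's analyze entry point
```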
Last layer relevance

Adds the keyword arguments normalize_output=true and R=nothing to LRP. If R is supplied, the relevances in the last layer are set to R. If normalize_output is false, the target neuron activation is not set to one, but remains the "raw" activation from the forward pass.
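For illustration, the new keyword arguments might be used as follows; the exact call site (whether they are forwarded through analyze) and the output shape are assumptions:

```julia
# Supply a custom last-layer relevance instead of the default one-hot
# target, and keep the raw forward-pass activation (call site assumed):
R = zeros(Float32, 1000, 1)  # hypothetical (classes, batch) output shape
R[42, 1] = 1.0f0
expl = analyze(input, analyzer; R=R, normalize_output=false)
```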
Canonization

This PR already contains the changes of PR #14, because otherwise canonization of ViT models does not work properly, so PR #14 should be merged first.
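For context, canonization would typically be applied before constructing the analyzer. A brief sketch, assuming the package exports a canonize function as its canonization entry point:

```julia
# Canonize the model (e.g. handle SkipConnection layers) before LRP;
# `canonize` as the entry point is an assumption based on the PR context.
model = canonize(model)
analyzer = LRP(model, composite)
```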
ToDo

For documentation, I guess it would be nice to have an extra "Extensions" section in the docs, and have a small tutorial under "Extensions => Vision Transformer".