Replies: 2 comments 1 reply
-
Hey @stefan-it, can you point to the exact line where this fusion is happening?
-
Hi @cgarciae, thanks for your fast reply. The only difference I can see here is this projection:

```python
projection = functools.partial(
    DenseGeneral,
    axis=-1,
    features=(self.num_heads, self.head_dim),
    kernel_axes=('embed', 'joined_kv'),
    dtype=self.dtype)
```

whereas the scalable architecture has:

```python
projection = functools.partial(
    DenseGeneral,
    axis=-1,
    features=(self.num_heads, self.head_dim),
    kernel_axes=('embed', 'heads', 'kv'),
    dtype=self.dtype)
```

For converting the weights we need to manually fuse the `('heads', 'kv')` axes into a single `'joined_kv'` axis.
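Here is a rough sketch of what I mean by manually fusing (my own illustration, not code from the T5X repo), assuming the unfused kernel has shape `(d_model, num_heads, head_dim)` and the fusion is just a reshape of the last two axes:

```python
# Sketch only: fuse a kernel laid out as ('embed', 'heads', 'kv')
# into the ('embed', 'joined_kv') layout for weight conversion.
import jax.numpy as jnp

d_model, num_heads, head_dim = 512, 6, 64

# Kernel as produced with kernel_axes=('embed', 'heads', 'kv')
unfused_kernel = jnp.zeros((d_model, num_heads, head_dim))

# Fused layout matching kernel_axes=('embed', 'joined_kv')
fused_kernel = unfused_kernel.reshape(d_model, num_heads * head_dim)  # (512, 384)

# Un-fusing for the opposite conversion direction is the same reshape in reverse
restored = fused_kernel.reshape(d_model, num_heads, head_dim)
assert restored.shape == (d_model, num_heads, head_dim)
```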
-
Hi,
while integrating the umT5 architecture (which is called "Scalable T5" in the T5X repository) into the 🤗 Transformers library, we saw a "joined" or fused matrix representation with a shape of heads × kv. The T5X repo describes this operation as: (-> Source)
But how could this be done in Flax?
For example we have a tensor of (512, 6, 64) and we need a "fused" representation with a shape of (512, 384).
I used:
where `config.d_model = 512`, `config.num_heads = 6` and `config.d_kv = 64`, to get the desired fused representation. But is this the correct way to do that?
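For context, here is a quick numerical check I put together (my own sketch, not from the T5X repo), which assumes the fusion is just a row-major reshape of the (heads, kv) axes:

```python
# Sketch: verify that reshaping (d_model, num_heads, d_kv) -> (d_model, num_heads * d_kv)
# yields a fused kernel whose single matmul matches the per-head contraction.
import numpy as np

d_model, num_heads, d_kv = 512, 6, 64
rng = np.random.default_rng(0)

kernel = rng.normal(size=(d_model, num_heads, d_kv))   # unfused: (512, 6, 64)
fused = kernel.reshape(d_model, num_heads * d_kv)      # fused:   (512, 384)

x = rng.normal(size=(d_model,))                        # a single "embed" vector

per_head = np.einsum('d,dhk->hk', x, kernel)           # (6, 64)
joined = x @ fused                                     # (384,)

# Row-major reshape keeps the (heads, kv) ordering, so the results agree.
np.testing.assert_allclose(per_head.reshape(-1), joined)
```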