Currently, the `MultiHeadDotProductAttention` layer's call method signature is `MultiHeadDotProductAttention.__call__(inputs_q, inputs_kv, mask=None, deterministic=None)`. As discussed in #1737, there are cases where passing separate tensors for the keys and values is desirable, which isn't possible with the current API. PR #3379 adds two more arguments, `inputs_k` and `inputs_v`, to the call method and changes the signature to `MultiHeadDotProductAttention.__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None)`. Note that `inputs_kv`, `mask` and `deterministic` are now keyword-only arguments. The new arguments default as follows (see the sketch after the list):
- If `inputs_k` and `inputs_v` are both `None`, they copy the value of `inputs_q` (i.e. self-attention).
- If only `inputs_v` is `None`, it copies the value of `inputs_k` (same behavior as the previous API, i.e. `module.apply(inputs_q=query, inputs_k=key_value, ...)` is equivalent to `module.apply(inputs_q=query, inputs_kv=key_value, ...)`).
- If `inputs_kv` is not `None`, both `inputs_k` and `inputs_v` copy the value of `inputs_kv`.
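As a rough illustration of these defaults, here is a minimal sketch; the shapes, `num_heads=2`, and variable names are placeholders I chose for the example, not values from the original post:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical toy inputs; shapes are illustrative only.
query = jnp.ones((1, 4, 8))      # (batch, q_length, features)
key_value = jnp.ones((1, 6, 8))  # (batch, kv_length, features)

attn = nn.MultiHeadDotProductAttention(num_heads=2)
variables = attn.init(jax.random.PRNGKey(0), query, key_value)

# inputs_k and inputs_v omitted -> both default to inputs_q (self-attention).
self_out = attn.apply(variables, query)

# inputs_v omitted -> it defaults to inputs_k (keys and values share one tensor).
shared_kv_out = attn.apply(variables, query, key_value)

# inputs_kv passed -> inputs_k and inputs_v both take its value (deprecated path).
legacy_out = attn.apply(variables, query, inputs_kv=key_value)
```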
Users can still pass `inputs_kv`, but a `DeprecationWarning` will be raised, and `inputs_kv` will be removed in the future.
Since self-attention can be expressed with this new API, the `SelfAttention` layer will also raise a `DeprecationWarning` and will be removed in the future.
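For migrating away from the deprecated layer, a minimal sketch of the swap (assuming a Flax version where `SelfAttention` is still available; shapes and `num_heads` are made up for illustration):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 4, 8))  # hypothetical (batch, length, features)

# Deprecated path: the dedicated SelfAttention layer.
old_layer = nn.SelfAttention(num_heads=2)
old_vars = old_layer.init(jax.random.PRNGKey(0), x)
old_out = old_layer.apply(old_vars, x)

# Replacement: MultiHeadDotProductAttention called with only inputs_q;
# inputs_k and inputs_v default to inputs_q, i.e. self-attention.
new_layer = nn.MultiHeadDotProductAttention(num_heads=2)
new_vars = new_layer.init(jax.random.PRNGKey(0), x)
new_out = new_layer.apply(new_vars, x)
```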
Some examples of porting your code over to the new method signature:
| Old call | New call |
| --- | --- |
| `module.apply(query, key_value, mask, deterministic)` | `module.apply(query, key_value, mask=mask, deterministic=deterministic)` |
| `module.apply(inputs_q=query, inputs_kv=key_value, mask=mask, deterministic=deterministic)` | `module.apply(inputs_q=query, inputs_k=key_value, mask=mask, deterministic=deterministic)` |
| `sa_module.apply(query, mask, deterministic)` | `module.apply(query, mask=mask, deterministic=deterministic)` |
| `sa_module.apply(inputs_q=query, mask=mask, deterministic=deterministic)` | `module.apply(inputs_q=query, mask=mask, deterministic=deterministic)` |

For additional context, check out PR #3379 and the discussion thread #1737.
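The calls in the table above elide the variables collection that `module.apply` takes as its first argument. Below is a self-contained sketch of the main new capability, passing separate key and value tensors; the shapes and `num_heads` are made up for illustration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical shapes: (batch, length, features).
query = jnp.ones((2, 4, 16))
key = jnp.ones((2, 6, 16))
value = jnp.ones((2, 6, 16))

attn = nn.MultiHeadDotProductAttention(num_heads=4)
variables = attn.init(jax.random.PRNGKey(0), query, key, value)

# Keys and values now come from separate tensors, which the old
# inputs_kv-only API could not express.
out = attn.apply(variables, query, key, value)
print(out.shape)  # (2, 4, 16)
```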