Skip to content

Commit f3e57ac

Browse files
ConchylicultorThe gemma Authors
authored andcommitted
Fix dtype overwrite when used with LoRA
PiperOrigin-RevId: 726423279
1 parent 10ce77b commit f3e57ac

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

gemma/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616

1717
# A new PyPI release will be pushed every time `__version__` is increased.
1818
# When changing this, also update the CHANGELOG.md.
19-
__version__ = '2.0.7'
19+
__version__ = '2.0.8'

gemma/gm/utils/_dtype_params.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@
3232
class _Context:
3333
"""Context for the dtype stack."""
3434

35-
dtypes: edc.ContextVar[list[_DType]] = dataclasses.field(default_factory=list)
35+
dtypes: edc.ContextVar[list[_DType | None]] = dataclasses.field(
36+
default_factory=list
37+
)
3638

3739

3840
_context = _Context()
3941

4042

4143
@contextlib.contextmanager
42-
def initialize_param_with_dtype(dtype: _DType) -> Iterator[None]:
44+
def initialize_param_with_dtype(dtype: _DType | None) -> Iterator[None]:
4345
"""Set the params dtype to the given value.
4446
4547
Inside the contextmanager, `self.param()` will use the given dtype.
@@ -64,7 +66,14 @@ def _mock_flax_module_param() -> None:
6466

6567
@_internal.wraps_with_reload(param)
6668
def decorated(self: nn.Module, *args, **kwargs):
67-
if self.is_initializing() and _context.dtypes:
69+
if (
70+
self.is_initializing()
71+
and _context.dtypes
72+
# LoRA modules provide the dtype as kwargs
73+
and 'dtype' not in kwargs
74+
# If `None` is provided, use the default dtype
75+
and _context.dtypes[-1] is not None
76+
):
6877
return param(self, *args, **kwargs, dtype=_context.dtypes[-1])
6978
else:
7079
return param(self, *args, **kwargs)

0 commit comments

Comments
 (0)