Skip to content

Commit

Permalink
Populate _auto_repr before running module constructor.
Browse files Browse the repository at this point in the history
Fixes #428.

PiperOrigin-RevId: 609318482
  • Loading branch information
tomhennigan authored and copybara-github committed Feb 22, 2024
1 parent 8dc61b3 commit 5c0f1e2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
18 changes: 12 additions & 6 deletions haiku/_src/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,21 @@ def __call__(cls, *args, **kwargs) -> Any: # pylint: disable=no-self-argument
# with the new class and not the metaclass.
module = cls.__new__(cls, *args, **kwargs) # pytype: disable=wrong-arg-types

# Now attempt to initialize the object.
init = wrap_method("__init__", cls.__init__, lambda: cls)
init(module, *args, **kwargs)

# We populate _auto_repr before `__init__` to allow `repr(self)` during the
# constructor of the module.
if (config.get_config().module_auto_repr and
getattr(module, "AUTO_REPR", True)):
module._auto_repr = utils.auto_repr(cls, *args, **kwargs) # pylint: disable=protected-access
module_repr = utils.auto_repr(cls, *args, **kwargs) # pylint: disable=protected-access
else:
module._auto_repr = object.__repr__(module)
module_repr = object.__repr__(module)

# Avoid triggering user defined __setattr__ overrides since we have not yet
# run their constructor.
object.__setattr__(module, "_auto_repr", module_repr)

# Now attempt to initialize the object.
init = wrap_method("__init__", cls.__init__, lambda: cls)
init(module, *args, **kwargs)

ran_super_ctor = hasattr(module, "module_name")
if not ran_super_ctor:
Expand Down
14 changes: 14 additions & 0 deletions haiku/_src/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,20 @@ def test_auto_repr(self):
m = IdentityModule()
self.assertEqual(str(m), "IdentityModule()")

@test_utils.transform_and_run
def test_repr_during_ctor(self):
# See https://github.com/google-deepmind/dm-haiku/issues/428 for other ways
# this can get triggered.

test = self

class MyModule(module.Module):
def __init__(self):
super().__init__()
test.assertEqual(repr(self), "MyModule()")

MyModule() # Does not fail.

def test_signature(self):
captures_expected = inspect.Signature(
parameters=(
Expand Down

0 comments on commit 5c0f1e2

Please sign in to comment.