Skip to content

Commit

Permalink
Raise error if user uses 'name' in bridge module setup
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Feb 20, 2025
1 parent f9f6885 commit ee377f4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
43 changes: 23 additions & 20 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,31 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
parent = None
module: M

if parent_ctx is not None and parent_ctx.in_compact:
if 'parent' in kwargs:
parent = kwargs.pop('parent')
if parent is not None:
raise ValueError(
f"'parent' can only be set to None, got {type(parent).__name__}"
)
name = None
else:
type_index = parent_ctx.type_counter[cls]
parent_ctx.type_counter[cls] += 1
name = None
if parent_ctx is not None:
if not parent_ctx.in_compact and 'name' in kwargs:
raise ValueError(
f"'name' can only be set in @compact functions. If in setup(), "
"use parent's `self.<attr_name> to set the submodule name.")

if 'name' in kwargs:
name = kwargs.pop('name')
if not isinstance(name, str):
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
if parent_ctx.in_compact:
if 'parent' in kwargs:
parent = kwargs.pop('parent')
if parent is not None:
raise ValueError(
f"'parent' can only be set to None, got {type(parent).__name__}"
)
else:
name = f'{cls.__name__}_{type_index}'

parent = parent_ctx.module
else:
name = None
type_index = parent_ctx.type_counter[cls]
parent_ctx.type_counter[cls] += 1

if 'name' in kwargs:
name = kwargs.pop('name')
if not isinstance(name, str):
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
else:
name = f'{cls.__name__}_{type_index}'
parent = parent_ctx.module

module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
module.scope = None
Expand Down
8 changes: 8 additions & 0 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ def __call__(self, x):
y = bar.apply(variables, x)
self.assertEqual(y.shape, (1, 5))

with self.assertRaises(ValueError):
class SetupBar(bridge.Module):
def setup(self):
self.xyz = Foo(5, name='xyz')
def __call__(self, x):
return self.xyz(x)
SetupBar().init(0, x)

def test_dense_port(self):
class Dense(bridge.Module):
features: int
Expand Down

0 comments on commit ee377f4

Please sign in to comment.