Skip to content

Commit

Permalink
Raise error if user uses 'name' in bridge module setup
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Feb 19, 2025
1 parent f9f6885 commit c8a465d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
parent = None
module: M

if parent_ctx is not None and parent_ctx.in_compact:
if parent_ctx is not None:
if 'parent' in kwargs:
parent = kwargs.pop('parent')
if parent is not None:
Expand All @@ -136,6 +136,10 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
parent_ctx.type_counter[cls] += 1

if 'name' in kwargs:
if not parent_ctx.in_compact:
raise ValueError(
f"'name' can only be set in @compact functions. If in setup(), "
"use parent's `self.<attr_name> to set the submodule name.")
name = kwargs.pop('name')
if not isinstance(name, str):
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
Expand Down Expand Up @@ -202,7 +206,6 @@ def param( # type: ignore[invalid-annotation]
name: str,
init_fn: tp.Callable[..., A],
*init_args,
unbox: bool = True,
**init_kwargs,
) -> variablelib.Param[A]:
# TODO(cgarciae): implement same condition as linen
Expand Down Expand Up @@ -253,7 +256,6 @@ def variable( # type: ignore[invalid-annotation]
name: str,
init_fn: tp.Callable[..., A] | None = None,
*init_args,
unbox: bool = True,
**init_kwargs,
) -> variablelib.Variable[A]:
variable_type = variablelib.variable_type_from_name(
Expand Down
8 changes: 8 additions & 0 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ def __call__(self, x):
y = bar.apply(variables, x)
self.assertEqual(y.shape, (1, 5))

with self.assertRaises(ValueError):
class SetupBar(bridge.Module):
def setup(self):
self.xyz = Foo(5, name='xyz')
def __call__(self, x):
return self.xyz(x)
SetupBar().init(0, x)

def test_dense_port(self):
class Dense(bridge.Module):
features: int
Expand Down

0 comments on commit c8a465d

Please sign in to comment.