Skip to content

Commit 3016825

Browse files
mcr229facebook-github-bot
authored andcommitted
Collect Named Data Store at construction
Summary: EdgeProgramManager should never take in a NamedDataStore, it should always be collected from the EdgePrograms that it recieves. Differential Revision: D71407899
1 parent 022a946 commit 3016825

File tree

1 file changed

+32
-21
lines changed

1 file changed

+32
-21
lines changed

exir/program/_program.py

+32-21
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,33 @@ def _gen_edge_manager_for_partitioners(
10981098
return edge_manager
10991099

11001100

1101+
def collect_named_data_store_from_exported_program(
1102+
exported_program: ExportedProgram,
1103+
named_data_store: NamedDataStore,
1104+
) -> None:
1105+
"""
1106+
Collects all the named data store outputs found within the exported program
1107+
and adds them to named_data_store.
1108+
"""
1109+
1110+
# collected all the named data into the named data store for deduplication
1111+
def collect_named_data_store_outputs(
1112+
graph_module: torch.fx.GraphModule,
1113+
) -> None:
1114+
for node in graph_module.graph.nodes:
1115+
if node.target == executorch_call_delegate:
1116+
lbm = getattr(graph_module, node.args[0].name)
1117+
assert is_lowered_module(lbm)
1118+
data_store_output = lbm.named_data_store_output
1119+
if data_store_output is not None:
1120+
named_data_store.merge_named_data_store(data_store_output)
1121+
1122+
for _, submod, _ in get_control_flow_submodules(graph_module):
1123+
collect_named_data_store_outputs(submod)
1124+
1125+
collect_named_data_store_outputs(exported_program.graph_module)
1126+
1127+
11011128
@et_logger("to_edge_transform_and_lower")
11021129
def to_edge_transform_and_lower(
11031130
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
@@ -1307,7 +1334,6 @@ def __init__(
13071334
constant_methods: Optional[Dict[str, Any]] = None,
13081335
compile_config: Optional[EdgeCompileConfig] = None,
13091336
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
1310-
named_data_store: Optional[NamedDataStore] = None,
13111337
):
13121338
"""
13131339
Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1331,7 +1357,11 @@ def __init__(
13311357
self._edge_programs: Dict[str, ExportedProgram] = edge_programs
13321358
self._config_methods = constant_methods
13331359

1334-
self._named_data_store = named_data_store or NamedDataStore()
1360+
self._named_data_store = NamedDataStore()
1361+
for _, program in self._edge_programs.items():
1362+
collect_named_data_store_from_exported_program(
1363+
program, self._named_data_store
1364+
)
13351365

13361366
@property
13371367
def methods(self) -> Set[str]:
@@ -1441,30 +1471,11 @@ def to_backend(
14411471
for name, program in self._edge_programs.items():
14421472
new_edge_programs[name] = to_backend(program, partitioner)
14431473

1444-
# collected all the named data into the named data store for deduplication
1445-
def collect_named_data_store_outputs(
1446-
graph_module: torch.fx.GraphModule,
1447-
) -> None:
1448-
for node in graph_module.graph.nodes:
1449-
if node.target == executorch_call_delegate:
1450-
lbm = getattr(graph_module, node.args[0].name)
1451-
assert is_lowered_module(lbm)
1452-
data_store_output = lbm.named_data_store_output
1453-
if data_store_output is not None:
1454-
self._named_data_store.merge_named_data_store(data_store_output)
1455-
1456-
for _, submod, _ in get_control_flow_submodules(graph_module):
1457-
collect_named_data_store_outputs(submod)
1458-
1459-
for _, program in new_edge_programs.items():
1460-
collect_named_data_store_outputs(program.graph_module)
1461-
14621474
config = EdgeCompileConfig(_check_ir_validity=False)
14631475
return EdgeProgramManager(
14641476
new_edge_programs,
14651477
copy.deepcopy(self._config_methods),
14661478
config,
1467-
named_data_store=self._named_data_store,
14681479
)
14691480

14701481
@et_logger("to_executorch")

0 commit comments

Comments
 (0)