@@ -1098,6 +1098,33 @@ def _gen_edge_manager_for_partitioners(
1098
1098
return edge_manager
1099
1099
1100
1100
1101
+ def collect_named_data_store_from_exported_program (
1102
+ exported_program : ExportedProgram ,
1103
+ named_data_store : NamedDataStore ,
1104
+ ) -> None :
1105
+ """
1106
+ Collects all the named data store outputs found within the exported program
1107
+ and adds them to named_data_store.
1108
+ """
1109
+
1110
+ # collected all the named data into the named data store for deduplication
1111
+ def collect_named_data_store_outputs (
1112
+ graph_module : torch .fx .GraphModule ,
1113
+ ) -> None :
1114
+ for node in graph_module .graph .nodes :
1115
+ if node .target == executorch_call_delegate :
1116
+ lbm = getattr (graph_module , node .args [0 ].name )
1117
+ assert is_lowered_module (lbm )
1118
+ data_store_output = lbm .named_data_store_output
1119
+ if data_store_output is not None :
1120
+ named_data_store .merge_named_data_store (data_store_output )
1121
+
1122
+ for _ , submod , _ in get_control_flow_submodules (graph_module ):
1123
+ collect_named_data_store_outputs (submod )
1124
+
1125
+ collect_named_data_store_outputs (exported_program .graph_module )
1126
+
1127
+
1101
1128
@et_logger ("to_edge_transform_and_lower" )
1102
1129
def to_edge_transform_and_lower (
1103
1130
programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
@@ -1307,7 +1334,6 @@ def __init__(
1307
1334
constant_methods : Optional [Dict [str , Any ]] = None ,
1308
1335
compile_config : Optional [EdgeCompileConfig ] = None ,
1309
1336
ops_set_to_not_decompose : Optional [List [torch ._ops .OpOverload ]] = None ,
1310
- named_data_store : Optional [NamedDataStore ] = None ,
1311
1337
):
1312
1338
"""
1313
1339
Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1331,7 +1357,11 @@ def __init__(
1331
1357
self ._edge_programs : Dict [str , ExportedProgram ] = edge_programs
1332
1358
self ._config_methods = constant_methods
1333
1359
1334
- self ._named_data_store = named_data_store or NamedDataStore ()
1360
+ self ._named_data_store = NamedDataStore ()
1361
+ for _ , program in self ._edge_programs .items ():
1362
+ collect_named_data_store_from_exported_program (
1363
+ program , self ._named_data_store
1364
+ )
1335
1365
1336
1366
@property
1337
1367
def methods (self ) -> Set [str ]:
@@ -1441,30 +1471,11 @@ def to_backend(
1441
1471
for name , program in self ._edge_programs .items ():
1442
1472
new_edge_programs [name ] = to_backend (program , partitioner )
1443
1473
1444
- # collected all the named data into the named data store for deduplication
1445
- def collect_named_data_store_outputs (
1446
- graph_module : torch .fx .GraphModule ,
1447
- ) -> None :
1448
- for node in graph_module .graph .nodes :
1449
- if node .target == executorch_call_delegate :
1450
- lbm = getattr (graph_module , node .args [0 ].name )
1451
- assert is_lowered_module (lbm )
1452
- data_store_output = lbm .named_data_store_output
1453
- if data_store_output is not None :
1454
- self ._named_data_store .merge_named_data_store (data_store_output )
1455
-
1456
- for _ , submod , _ in get_control_flow_submodules (graph_module ):
1457
- collect_named_data_store_outputs (submod )
1458
-
1459
- for _ , program in new_edge_programs .items ():
1460
- collect_named_data_store_outputs (program .graph_module )
1461
-
1462
1474
config = EdgeCompileConfig (_check_ir_validity = False )
1463
1475
return EdgeProgramManager (
1464
1476
new_edge_programs ,
1465
1477
copy .deepcopy (self ._config_methods ),
1466
1478
config ,
1467
- named_data_store = self ._named_data_store ,
1468
1479
)
1469
1480
1470
1481
@et_logger ("to_executorch" )
0 commit comments