From 1440635bcf315db2a94519bfe3ffca83ff20d3a1 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Fri, 28 Jul 2023 14:04:54 -0700 Subject: [PATCH] Fix failing tests (#857) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/857 Reviewed By: frank-wei, mortzur Differential Revision: D47874377 fbshipit-source-id: 8afe4e107e9c7ba16aed3e8931de3e64d7cb11c1 --- fx2ait/fx2ait/tools/common_aten2ait.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/fx2ait/fx2ait/tools/common_aten2ait.py b/fx2ait/fx2ait/tools/common_aten2ait.py index e6feab9d2..1a603848a 100644 --- a/fx2ait/fx2ait/tools/common_aten2ait.py +++ b/fx2ait/fx2ait/tools/common_aten2ait.py @@ -102,16 +102,20 @@ def generate_graph( if customized_passes: passes_list.extend(customized_passes) - fx_module = exir.capture( - mod, - tuple(original_inputs), - CaptureConfig( - pt2_mode=True, - enable_functionalization=False, - enable_dynamic_shape=True, - _use_old_decomp_table=True, - ), - ).transform(*tuple(passes_list)) + fx_module = ( + exir.capture( + mod, + tuple(original_inputs), + CaptureConfig( + pt2_mode=True, + enable_functionalization=False, + enable_dynamic_shape=True, + _use_old_decomp_table=True, + ), + ) + .transform(*tuple(passes_list)) + .exported_program.graph_module + ) fx_module = run_const_fold(fx_module) _LOGGER.info(f"aten fx graph: {fx_module.graph}")