Skip to content

Commit 979eb20

Browse files
bobyangyffacebook-github-bot
authored andcommitted
Componenets to get fn returning Any (#1018)
Summary: Allow `Any` so that we can have different function signatures for pipeline definition Reviewed By: lgarg26 Differential Revision: D70936712
1 parent 64635a8 commit 979eb20

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

torchx/specs/builders.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _merge_config_values_with_args(
101101

102102

103103
def parse_args(
104-
cmpnt_fn: Callable[..., AppDef],
104+
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
105105
cmpnt_args: List[str],
106106
cmpnt_defaults: Optional[Dict[str, Any]] = None,
107107
config: Optional[Dict[str, Any]] = None,
@@ -130,7 +130,7 @@ def parse_args(
130130

131131

132132
def materialize_appdef(
133-
cmpnt_fn: Callable[..., AppDef],
133+
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
134134
cmpnt_args: List[str],
135135
cmpnt_defaults: Optional[Dict[str, Any]] = None,
136136
config: Optional[Dict[str, Any]] = None,
@@ -187,6 +187,10 @@ def materialize_appdef(
187187
var_arg = var_arg[1:]
188188

189189
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
190+
if not isinstance(appdef, AppDef):
191+
raise TypeError(
192+
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
193+
)
190194

191195
return appdef
192196

torchx/specs/finder.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
from inspect import getmembers, isfunction
1717
from pathlib import Path
1818
from types import ModuleType
19-
from typing import Callable, Dict, Generator, List, Optional, Union
19+
from typing import Any, Callable, Dict, Generator, List, Optional, Union
2020

2121
from torchx.specs import AppDef
2222
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
2323
from torchx.util import entrypoints
2424
from torchx.util.io import read_conf_file
2525
from torchx.util.types import none_throws
2626

27+
2728
logger: logging.Logger = logging.getLogger(__name__)
2829

2930

@@ -53,7 +54,10 @@ class _Component:
5354
name: str
5455
description: str
5556
fn_name: str
56-
fn: Callable[..., AppDef]
57+
58+
# pyre-ignore[4] TODO temporary until PipelineDef is decoupled and can be exposed as type to OSS
59+
fn: Callable[..., Any]
60+
5761
validation_errors: List[str]
5862

5963

0 commit comments

Comments
 (0)