From 371d46566ee69f7bdb9683ee2a7e69da58a50bc2 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Thu, 6 Feb 2025 00:01:10 -0500 Subject: [PATCH] Refactor and add comments to the code. --- sdks/python/apache_beam/ml/anomaly/base.py | 103 +++++++- .../apache_beam/ml/anomaly/base_test.py | 7 +- .../apache_beam/ml/anomaly/specifiable.py | 245 ++++++++++++------ .../ml/anomaly/specifiable_test.py | 185 +++++++------ 4 files changed, 382 insertions(+), 158 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py index fdbb88548d55..dfe29ee55ee9 100644 --- a/sdks/python/apache_beam/ml/anomaly/base.py +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -15,7 +15,9 @@ # limitations under the License. # -"""Base classes for anomaly detection""" +""" +Base classes for anomaly detection +""" from __future__ import annotations import abc @@ -26,24 +28,51 @@ import apache_beam as beam +__all__ = [ + "AnomalyPrediction", + "AnomalyResult", + "ThresholdFn", + "AggregationFn", + "AnomalyDetector", + "EnsembleAnomalyDetector" +] + @dataclass(frozen=True) class AnomalyPrediction(): + """A dataclass for anomaly detection predictions.""" + #: The ID of detector (model) that generates the prediction. model_id: Optional[str] = None + #: The outlier score resulting from applying the detector to the input data. score: Optional[float] = None + #: The outlier label (normal or outlier) derived from the outlier score. label: Optional[int] = None + #: The threshold used to determine the label. threshold: Optional[float] = None + #: Additional information about the prediction. info: str = "" + #: If enabled, a list of `AnomalyPrediction` objects used to derive the + #: aggregated prediction. agg_history: Optional[Iterable[AnomalyPrediction]] = None @dataclass(frozen=True) class AnomalyResult(): + """A dataclass for the anomaly detection results""" + #: The original input data. 
example: beam.Row + #: The `AnomalyPrediction` object containing the prediction. prediction: AnomalyPrediction class ThresholdFn(abc.ABC): + """An abstract base class for threshold functions. + + Args: + normal_label: The integer label used to identify normal data. Defaults to 0. + outlier_label: The integer label used to identify outlier data. Defaults to + 1. + """ def __init__(self, normal_label: int = 0, outlier_label: int = 1): self._normal_label = normal_label self._outlier_label = outlier_label @@ -51,26 +80,59 @@ def __init__(self, normal_label: int = 0, outlier_label: int = 1): @property @abc.abstractmethod def is_stateful(self) -> bool: + """Indicates whether the threshold function is stateful or not.""" raise NotImplementedError @property @abc.abstractmethod def threshold(self) -> Optional[float]: + """Retrieves the current threshold value, or None if not set.""" raise NotImplementedError @abc.abstractmethod def apply(self, score: Optional[float]) -> int: + """Applies the threshold function to a given score to classify it as + normal or outlier. + + Args: + score: The outlier score generated from the detector (model). + + Returns: + The label assigned to the score, either `self._normal_label` + or `self._outlier_label` + """ raise NotImplementedError class AggregationFn(abc.ABC): + """An abstract base class for aggregation functions.""" @abc.abstractmethod def apply( self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction: + """Applies the aggregation function to an iterable of predictions, either on + their outlier scores or labels. + + Args: + predictions: An Iterable of `AnomalyPrediction` objects to aggregate. + + Returns: + An `AnomalyPrediction` object containing the aggregated result. + """ raise NotImplementedError class AnomalyDetector(abc.ABC): + """An abstract base class for anomaly detectors. + + Args: + model_id: The ID of detector (model). Defaults to the value of the + `spec_type` attribute, or 'unknown' if not set. 
+ features: An Iterable of strings representing the names of the input + features in the `beam.Row` + target: The name of the target field in the `beam.Row`. + threshold_criterion: An optional `ThresholdFn` to apply to the outlier score + and yield a label. + """ def __init__( self, model_id: Optional[str] = None, @@ -79,28 +141,52 @@ def __init__( threshold_criterion: Optional[ThresholdFn] = None, **kwargs): self._model_id = model_id if model_id is not None else getattr( - self, '_key', 'unknown') + self, 'spec_type', 'unknown') self._features = features self._target = target self._threshold_criterion = threshold_criterion @abc.abstractmethod def learn_one(self, x: beam.Row) -> None: + """Trains the detector on a single data instance. + + Args: + x: A `beam.Row` representing the data instance. + """ raise NotImplementedError @abc.abstractmethod def score_one(self, x: beam.Row) -> float: + """Scores a single data instance for anomalies. + + Args: + x: A `beam.Row` representing the data instance. + + Returns: + The outlier score as a float. + """ raise NotImplementedError class EnsembleAnomalyDetector(AnomalyDetector): + """An abstract base class for an ensemble of anomaly (sub-)detectors. + + Args: + sub_detectors: A List of `AnomalyDetector` used in this ensemble model. + aggregation_strategy: An optional `AggregationFn` to apply to the + predictions from all sub-detectors and yield an aggregated result. + model_id: Inherited from `AnomalyDetector`. + features: Inherited from `AnomalyDetector`. + target: Inherited from `AnomalyDetector`. + threshold_criterion: Inherited from `AnomalyDetector`. 
+ """ def __init__( self, sub_detectors: Optional[List[AnomalyDetector]] = None, aggregation_strategy: Optional[AggregationFn] = None, **kwargs): if "model_id" not in kwargs or kwargs["model_id"] is None: - kwargs["model_id"] = getattr(self, '_key', 'custom') + kwargs["model_id"] = getattr(self, 'spec_type', 'custom') super().__init__(**kwargs) @@ -108,7 +194,18 @@ def __init__( self._sub_detectors = sub_detectors def learn_one(self, x: beam.Row) -> None: + """Inherited from `AnomalyDetector.learn_one`. + + This method is never called during ensemble detector training. The training + process is done on each sub-detector independently and in parallel. + """ raise NotImplementedError def score_one(self, x: beam.Row) -> float: + """Inherited from `AnomalyDetector.score_one`. + + This method is never called during ensemble detector scoring. The scoring + process is done on sub-detector independently and in parallel, and then + the results are aggregated in the pipeline. + """ raise NotImplementedError diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py index f425525a250a..74bd5b8e5f57 100644 --- a/sdks/python/apache_beam/ml/anomaly/base_test.py +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -60,9 +60,6 @@ def __eq__(self, value) -> bool: return isinstance(value, TestAnomalyDetector.Dummy) and \ self._my_arg == value._my_arg - def test_unknown_detector(self): - self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) - def test_model_id_on_known_detector(self): a = self.Dummy( my_arg="abc", @@ -75,7 +72,7 @@ def test_model_id_on_known_detector(self): assert isinstance(a, Specifiable) self.assertEqual( - a._init_params, { + a.init_kwargs, { "my_arg": "abc", "target": "ABC", "threshold_criterion": t1, @@ -92,7 +89,7 @@ def test_model_id_on_known_detector(self): assert isinstance(b, Specifiable) self.assertEqual( - b._init_params, + b.init_kwargs, { "model_id": "my_dummy", "my_arg": 
"efg", diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 40b99691bbfc..ef53e8430ac8 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -15,6 +15,10 @@ # limitations under the License. # +""" +A module that provides utilities to turn a class into a Specifiable subclass. +""" + from __future__ import annotations import dataclasses @@ -30,6 +34,8 @@ from typing_extensions import Self +__all__ = ["KNOWN_SPECIFIABLE", "Spec", "Specifiable", "specifiable"] + ACCEPTED_SPECIFIABLE_SUBSPACES = [ "EnsembleAnomalyDetector", "AnomalyDetector", @@ -37,36 +43,79 @@ "AggregationFn", "*" ] + +#: A nested dictionary for efficient lookup of Specifiable subclasses. +#: Structure: KNOWN_SPECIFIABLE[subspace][spec_type], where "subspace" is one of +#: the accepted subspaces that the class belongs to and "spec_type" is the class +#: name by default. Users can also specify a different value for "spec_type" +#: when applying the `specifiable` decorator to an existing class. KNOWN_SPECIFIABLE = {"*": {}} SpecT = TypeVar('SpecT', bound='Specifiable') -def get_subspace(cls, type=None): - if type is None: - subspace = "*" - for c in cls.mro(): - if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: - subspace = c.__name__ - break - return subspace - else: - for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: - if subspace in KNOWN_SPECIFIABLE and type in KNOWN_SPECIFIABLE[subspace]: - return subspace +def _class_to_subspace(cls: Type, default="*") -> str: + """ + Search the class hierarchy to find the subspace: the closest ancestor class in + the class's method resolution order (MRO) whose name is found in the accepted + subspace list. This is usually called when registering a new Specifiable + class. 
+ """ + for c in cls.mro(): + # The first class in MRO order whose name is an accepted subspace wins. + if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: + return c.__name__ + + if default is None: raise ValueError(f"subspace for {cls.__name__} not found.") + return default + + +def _spec_type_to_subspace(type: str, default="*") -> str: + """ + Look for the subspace for a spec type. This is usually called to retrieve + the subspace of a registered Specifiable class. + """ + for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: + if type in KNOWN_SPECIFIABLE.get(subspace, {}): + return subspace + + if default is None: + raise ValueError(f"subspace for {type} not found.") + + return default + @dataclasses.dataclass(frozen=True) class Spec(): + """ + Dataclass for storing specifications of specifiable objects. + Objects can be initialized using the data in their corresponding spec. + The `type` field indicates the concrete Specifiable class, while `config` holds the keyword arguments for its `__init__` method. + """ + #: A string indicating the concrete Specifiable class type: str + #: A dictionary of keyword arguments for the `__init__` method of the class. config: dict[str, Any] = dataclasses.field(default_factory=dict) @runtime_checkable class Specifiable(Protocol): - _key: ClassVar[str] - _init_params: dict[str, Any] + """Protocol that a Specifiable subclass needs to implement. + + Attributes: + spec_type: The value of the `type` field in the object's Spec for this + class. + init_kwargs: The raw keyword arguments passed to `__init__` during object + initialization. 
+ """ + spec_type: ClassVar[str] + init_kwargs: dict[str, Any] + # a boolean to tell whether the original __init__ is called + _initialized: bool + # a boolean used by new_getattr to tell whether it is in an __init__ call + _in_init: bool @staticmethod def _from_spec_helper(v): @@ -80,17 +129,21 @@ def _from_spec_helper(v): @classmethod def from_spec(cls, spec: Spec) -> Self: + """Generate a Specifiable subclass object based on a spec.""" if spec.type is None: raise ValueError(f"Spec type not found in {spec}") - subspace = get_subspace(cls, spec.type) + subspace = _spec_type_to_subspace(spec.type) subclass: Type[Self] = KNOWN_SPECIFIABLE[subspace].get(spec.type, None) if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") - args = {k: Specifiable._from_spec_helper(v) for k, v in spec.config.items()} + kwargs = { + k: Specifiable._from_spec_helper(v) + for k, v in spec.config.items() + } - return subclass(**args) + return subclass(**kwargs) @staticmethod def _to_spec_helper(v): @@ -103,121 +156,161 @@ def _to_spec_helper(v): return v def to_spec(self) -> Spec: - if getattr(type(self), '_key', None) is None: + """ + Generate a spec from a Specifiable subclass object. + """ + if getattr(type(self), 'spec_type', None) is None: raise ValueError( f"'{type(self).__name__}' not registered as Specifiable. 
" f"Decorate ({type(self).__name__}) with @specifiable") - args = {k: self._to_spec_helper(v) for k, v in self._init_params.items()} - - return Spec(type=self.__class__._key, config=args) + args = {k: self._to_spec_helper(v) for k, v in self.init_kwargs.items()} + return Spec(type=self.__class__.spec_type, config=args) -def register(cls, key, error_if_exists) -> None: - if key is None: - key = cls.__name__ - subspace = get_subspace(cls) - if subspace in KNOWN_SPECIFIABLE and key in KNOWN_SPECIFIABLE[ - subspace] and error_if_exists: - raise ValueError(f"{key} is already registered for specifiable") +# Register a Specifiable subclass in KNOWN_SPECIFIABLE +def _register(cls, spec_type=None, error_if_exists=True) -> None: + if spec_type is None: + # By default, spec type is the class name. Users can override this with + # other unique identifier. + spec_type = cls.__name__ - if subspace not in KNOWN_SPECIFIABLE: + subspace = _class_to_subspace(cls) + if subspace in KNOWN_SPECIFIABLE: + if spec_type in KNOWN_SPECIFIABLE[subspace] and error_if_exists: + raise ValueError(f"{spec_type} is already registered for specifiable") + else: KNOWN_SPECIFIABLE[subspace] = {} - KNOWN_SPECIFIABLE[subspace][key] = cls + KNOWN_SPECIFIABLE[subspace][spec_type] = cls - cls._key = key + cls.spec_type = spec_type -def track_init_params(inst, init_method, *args, **kwargs): +# Keep a copy of arguments that are used to call __init__ method, when the +# object is initialized. 
+def _get_init_kwargs(inst, init_method, *args, **kwargs): params = dict( zip(inspect.signature(init_method).parameters.keys(), (None, ) + args)) del params['self'] params.update(**kwargs) - inst._init_params = params + return params def specifiable( my_cls=None, /, *, - key=None, + spec_type=None, error_if_exists=True, on_demand_init=True, just_in_time_init=True): - - # register a specifiable, track init params for each instance, lazy init + """A decorator that turns a class into a Specifiable subclass by implementing + the Specifiable protocol. + + To use the decorator, simply place `@specifiable` before the class definition. + For finer control, the decorator accepts arguments + (e.g., `@specifiable(arg1=..., arg2=...)`). + + Args: + spec_type: The value of the `type` field in the Spec of a Specifiable + subclass. If not provided, the class name is used. + error_if_exists: If True, raise an exception if `spec_type` is already + registered. + on_demand_init: If True, allow on-demand object initialization. The original + `__init__` method will be called when `_run_init=True` is passed to the + object's initialization function. + just_in_time_init: If True, allow just-in-time object initialization. The + original `__init__` method will be called when an attribute is first + accessed. 
+ """ def _wrapper(cls): - register(cls, key, error_if_exists) - - original_init = cls.__init__ - class_name = cls.__name__ - - def new_init(self, *args, **kwargs): + def new_init(self: Specifiable, *args, **kwargs): self._initialized = False - #self._nested_getattr = False - - if kwargs.get("_run_init", False): - run_init = True - del kwargs['_run_init'] - else: - run_init = False + self._in_init = False - if '_init_params' not in self.__dict__: - track_init_params(self, original_init, *args, **kwargs) + run_init_request = False + if "_run_init" in kwargs: + run_init_request = kwargs["_run_init"] + del kwargs["_run_init"] + if 'init_kwargs' not in self.__dict__: + self.init_kwargs = _get_init_kwargs( + self, original_init, *args, **kwargs) + logging.debug("Record init params in %s.new_init", class_name) # If it is not a nested specifiable, we choose whether to skip original # init call based on options. Otherwise, we always call original init # for inner (parent/grandparent/etc) specifiable. 
- if (on_demand_init and not run_init) or \ + if (on_demand_init and not run_init_request) or \ (not on_demand_init and just_in_time_init): + logging.debug("Skip original %s.__init__", class_name) return - logging.debug("call original %s.__init__ in new_init", class_name) + logging.debug("Call original %s.__init__ in new_init", class_name) + original_init(self, *args, **kwargs) self._initialized = True - def run_init(self): - original_init(self, **self._init_params) + def run_original_init(self): + self._in_init = True + original_init(self, **self.init_kwargs) + self._in_init = False + self._initialized = True + # __getattr__ is only called when an attribute is not found in the object def new_getattr(self, name): - if name == '_nested_getattr' or \ - ('_nested_getattr' in self.__dict__ and self._nested_getattr): - #self._nested_getattr = False - delattr(self, "_nested_getattr") + logging.debug( + "Trying to access %s.%s, but it is not found.", class_name, name) + + # Fix the infinite loop issue when pickling a Specifiable + if name in ["_in_init", "__getstate__"] and name not in self.__dict__: raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'") - # set it before original init, in case getattr is called in original init - self._nested_getattr = True + # If the attribute is not found during or after initialization, then + # it is a missing attribute. + if self._in_init or self._initialized: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") - if not self._initialized and name != "__getstate__": - logging.debug("call original %s.__init__ in new_getattr", class_name) - original_init(self, **self._init_params) - self._initialized = True + # Here, we know the object is not initialized, then we will call original + # init method. 
+ logging.debug("Call original %s.__init__ in new_getattr", class_name) + run_original_init(self) - try: - logging.debug("call original %s.getattr in new_getattr", class_name) - ret = getattr(self, name) - finally: - # self._nested_getattr = False - delattr(self, "_nested_getattr") - return ret + # __getattribute__ is called for every attribute regardless of whether it is + # present in the object. In this case, we don't cause an infinite loop + # if the attribute does not exist. + logging.debug( + "Call original %s.__getattribute__(%s) in new_getattr", + class_name, + name) + return self.__getattribute__(name) + # start of the function body of _wrapper + _register(cls, spec_type, error_if_exists) + + class_name = cls.__name__ + original_init = cls.__init__ + cls.__init__ = new_init if just_in_time_init: cls.__getattr__ = new_getattr - cls.__init__ = new_init - cls._run_init = run_init + cls.run_original_init = run_original_init cls.to_spec = Specifiable.to_spec cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) cls.from_spec = classmethod(Specifiable.from_spec) cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) return cls + # end of the function body of _wrapper + # When this decorator is called with arguments, i.e. + # "@specifiable(arg1=...,arg2=...)", it is equivalent to assigning + # specifiable(arg1=..., arg2=...) to a variable, say decor_func, and then + # calling "@decor_func". if my_cls is None: - # support @specifiable(...) return _wrapper - # support @specifiable without arguments + # When this decorator is called without an argument, i.e. "@specifiable", + # we return the augmented class. 
return _wrapper(my_cls) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 71910b03ed5b..202f6e763044 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -30,69 +30,83 @@ class TestSpecifiable(unittest.TestCase): - def test_register_specifiable(self): - class MyClass(): + def test_decorator_in_function_form(self): + class A(): pass - # class is not decorated/registered - self.assertRaises(AttributeError, lambda: MyClass().to_spec()) - - self.assertNotIn("MyKey", KNOWN_SPECIFIABLE["*"]) - - MyClass = specifiable(key="MyKey")(MyClass) - - self.assertIn("MyKey", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyKey"], MyClass) - - # By default, an error is raised if the key is duplicated - self.assertRaises(ValueError, specifiable(key="MyKey"), MyClass) - - # But it is ok if a different key is used for the same class - _ = specifiable(key="MyOtherKey")(MyClass) - self.assertIn("MyOtherKey", KNOWN_SPECIFIABLE["*"]) - - # Or, use a parameter to suppress the error - specifiable(key="MyKey", error_if_exists=False)(MyClass) - - def test_decorator_key(self): - # use decorator without parameter + # class is not decorated and thus not registered + self.assertNotIn("A", KNOWN_SPECIFIABLE["*"]) + + # apply the decorator function to an existing class + A = specifiable(A) + self.assertEqual(A.spec_type, "A") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["A"], A) + + # an error is raised if the specified spec_type already exists. 
+ self.assertRaises(ValueError, specifiable, A) + + # apply the decorator function to an existing class with a different + # spec_type + A = specifiable(spec_type="A_DUP")(A) + self.assertEqual(A.spec_type, "A_DUP") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A_DUP", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["A_DUP"], A) + + # an error is raised if the specified spec_type already exists. + self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) + + # but the error can be suppressed by setting error_if_exists=False. + try: + specifiable(spec_type="A_DUP", error_if_exists=False)(A) + except ValueError: + self.fail("The ValueError should be suppressed but instead it is raised.") + + def test_decorator_in_syntactic_sugar_form(self): + # call decorator without parameters @specifiable - class MySecondClass(): + class B(): pass - self.assertIn("MySecondClass", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MySecondClass"], MySecondClass) - self.assertTrue(isinstance(MySecondClass(), Specifiable)) + self.assertTrue(isinstance(B(), Specifiable)) + self.assertIn("B", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["B"], B) - # use decorator with key parameter - @specifiable(key="MyThirdKey") - class MyThirdClass(): + # call decorator with parameters + @specifiable(spec_type="C_TYPE") + class C(): pass - self.assertIn("MyThirdKey", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyThirdKey"], MyThirdClass) + self.assertTrue(isinstance(C(), Specifiable)) + self.assertIn("C_TYPE", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["C_TYPE"], C) def test_init_params_in_specifiable(self): @specifiable - class MyClassWithInitParams(): + class ParentWithInitParams(): def __init__(self, arg_1, arg_2=2, arg_3="3", **kwargs): pass - a = MyClassWithInitParams(10, arg_3="30", arg_4=40) - assert isinstance(a, Specifiable) - self.assertEqual(a._init_params, 
{'arg_1': 10, 'arg_3': '30', 'arg_4': 40}) + parent = ParentWithInitParams(10, arg_3="30", arg_4=40) + assert isinstance(parent, Specifiable) + self.assertEqual( + parent.init_kwargs, { + 'arg_1': 10, 'arg_3': '30', 'arg_4': 40 + }) - # inheritance of specifiable + # inheritance of a Specifiable subclass @specifiable - class MyDerivedClassWithInitParams(MyClassWithInitParams): + class ChildWithInitParams(ParentWithInitParams): def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): super().__init__(**kwargs) - b = MyDerivedClassWithInitParams( + child = ChildWithInitParams( 1000, arg_1=11, arg_2=20, new_arg_2=2000, arg_4=4000) - assert isinstance(b, Specifiable) + assert isinstance(child, Specifiable) self.assertEqual( - b._init_params, + child.init_kwargs, { 'new_arg_1': 1000, 'arg_1': 11, @@ -101,25 +115,41 @@ def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): 'arg_4': 4000 }) - # composite of specifiable + # composite of Specifiable subclasses @specifiable - class MyCompositeClassWithInitParams(): - def __init__(self, my_class: Optional[MyClassWithInitParams] = None): + class CompositeWithInitParams(): + def __init__( + self, + my_parent: Optional[ParentWithInitParams] = None, + my_child: Optional[ChildWithInitParams] = None): pass - c = MyCompositeClassWithInitParams(a) - assert isinstance(c, Specifiable) - self.assertEqual(c._init_params, {'my_class': a}) + composite = CompositeWithInitParams(parent, child) + assert isinstance(composite, Specifiable) + self.assertEqual( + composite.init_kwargs, { + 'my_parent': parent, 'my_child': child + }) + + def test_from_spec_on_unknown_spec_type(self): + self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) - def test_from_and_to_specifiable(self): - @specifiable(on_demand_init=False, just_in_time_init=False) + # To test from_spec and to_spec with/without just_in_time_init. 
+ @parameterized.expand([False, True]) + def test_from_spec_and_to_spec(self, just_in_time_init): + @specifiable( + spec_type=f"product_{just_in_time_init}", + on_demand_init=False, + just_in_time_init=just_in_time_init) @dataclasses.dataclass class Product(): name: str price: float @specifiable( - key="shopping_entry", on_demand_init=False, just_in_time_init=False) + spec_type=f"shopping_entry_{just_in_time_init}", + on_demand_init=False, + just_in_time_init=just_in_time_init) class Entry(): def __init__(self, product: Product, quantity: int = 1): self._product = product @@ -131,7 +161,9 @@ def __eq__(self, value) -> bool: self._quantity == value._quantity @specifiable( - key="shopping_cart", on_demand_init=False, just_in_time_init=False) + spec_type=f"shopping_cart_{just_in_time_init}", + on_demand_init=False, + just_in_time_init=just_in_time_init) @dataclasses.dataclass class ShoppingCart(): user_id: str @@ -140,7 +172,7 @@ class ShoppingCart(): orange = Product("orange", 1.0) expected_orange_spec = Spec( - "Product", config={ + f"product_{just_in_time_init}", config={ 'name': 'orange', 'price': 1.0 }) assert isinstance(orange, Specifiable) @@ -150,7 +182,8 @@ class ShoppingCart(): entry_1 = Entry(product=orange) expected_entry_spec_1 = Spec( - "shopping_entry", config={ + f"shopping_entry_{just_in_time_init}", + config={ 'product': expected_orange_spec, }) @@ -160,19 +193,19 @@ class ShoppingCart(): banana = Product("banana", 0.5) expected_banana_spec = Spec( - "Product", config={ + f"product_{just_in_time_init}", config={ 'name': 'banana', 'price': 0.5 }) entry_2 = Entry(product=banana, quantity=5) expected_entry_spec_2 = Spec( - "shopping_entry", + f"shopping_entry_{just_in_time_init}", config={ 'product': expected_banana_spec, 'quantity': 5 }) shopping_cart = ShoppingCart(user_id="test", entries=[entry_1, entry_2]) expected_shopping_cart_spec = Spec( - "shopping_cart", + f"shopping_cart_{just_in_time_init}", config={ "user_id": "test", "entries": 
[expected_entry_spec_1, expected_entry_spec_2] @@ -190,12 +223,12 @@ class FooOnDemand(): def __init__(self, arg): self.my_arg = arg * 10 - FooOnDemand.counter += 1 + FooOnDemand.counter += 1 # increment it when __init__ is called foo = FooOnDemand(123) self.assertEqual(FooOnDemand.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 123}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 123}) self.assertNotIn("my_arg", foo.__dict__) self.assertRaises(AttributeError, getattr, foo, "my_arg") @@ -204,10 +237,11 @@ def __init__(self, arg): self.assertRaises(AttributeError, lambda: foo.unknown_arg) self.assertEqual(FooOnDemand.counter, 0) + # __init__ is called when _run_init=True is used foo_2 = FooOnDemand(456, _run_init=True) self.assertEqual(FooOnDemand.counter, 1) - self.assertIn("_init_params", foo_2.__dict__) - self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 456}) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 456}) self.assertIn("my_arg", foo_2.__dict__) self.assertEqual(foo_2.my_arg, 4560) @@ -220,17 +254,18 @@ class FooJustInTime(): def __init__(self, arg): self.my_arg = arg * 10 - FooJustInTime.counter += 1 + FooJustInTime.counter += 1 # increment it when __init__ is called foo = FooJustInTime(321) self.assertEqual(FooJustInTime.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 321}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 321}) - self.assertNotIn("my_arg", foo.__dict__) # __init__ hasn't been called + # __init__ hasn't been called yet + self.assertNotIn("my_arg", foo.__dict__) self.assertEqual(FooJustInTime.counter, 0) - # __init__ is called when trying to accessing an attribute + # __init__ is called when trying to access an instance attribute 
self.assertEqual(foo.my_arg, 3210) self.assertEqual(FooJustInTime.counter, 1) self.assertRaises(AttributeError, lambda: foo.unknown_arg) @@ -247,23 +282,23 @@ def __init__(self, arg): foo = FooOnDemandAndJustInTime(987) self.assertEqual(FooOnDemandAndJustInTime.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 987}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 987}) self.assertNotIn("my_arg", foo.__dict__) self.assertEqual(FooOnDemandAndJustInTime.counter, 0) - # __init__ is called + # __init__ is called when trying to access an instance attribute self.assertEqual(foo.my_arg, 9870) self.assertEqual(FooOnDemandAndJustInTime.counter, 1) - # __init__ is called + # __init__ is called when _run_init=True is used foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) - self.assertIn("_init_params", foo_2.__dict__) - self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 789}) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 789}) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) - # __init__ is NOT called + # __init__ is NOT called after it is initialized self.assertEqual(foo_2.my_arg, 7890) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) @@ -357,6 +392,7 @@ class Child_Error_1(Parent): child_class_var = 2001 def __init__(self, c): + # read an instance var in child that doesn't exist self.child_inst_var += 1 super().__init__(c) Child_2.counter += 1 @@ -368,6 +404,7 @@ class Child_Error_2(Parent): child_class_var = 2001 def __init__(self, c): + # read an instance var in parent without calling parent's __init__. self.parent_inst_var += 1 Child_2.counter += 1