Skip to content

Commit 0b3a7ca

Browse files
committed
Refactor code, add tests and add docstrings.
1 parent 24c77e9 commit 0b3a7ca

File tree

4 files changed

+553
-278
lines changed

4 files changed

+553
-278
lines changed

sdks/python/apache_beam/ml/anomaly/base.py

+100-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
# limitations under the License.
1616
#
1717

18-
"""Base classes for anomaly detection"""
18+
"""
19+
Base classes for anomaly detection.
20+
"""
1921
from __future__ import annotations
2022

2123
import abc
@@ -26,51 +28,111 @@
2628

2729
import apache_beam as beam
2830

31+
__all__ = [
32+
"AnomalyPrediction",
33+
"AnomalyResult",
34+
"ThresholdFn",
35+
"AggregationFn",
36+
"AnomalyDetector",
37+
"EnsembleAnomalyDetector"
38+
]
39+
2940

3041
@dataclass(frozen=True)
3142
class AnomalyPrediction():
43+
"""A dataclass for anomaly detection predictions."""
44+
#: The ID of the detector (model) that generates the prediction.
3245
model_id: Optional[str] = None
46+
#: The outlier score resulting from applying the detector to the input data.
3347
score: Optional[float] = None
48+
#: The outlier label (normal or outlier) derived from the outlier score.
3449
label: Optional[int] = None
50+
#: The threshold used to determine the label.
3551
threshold: Optional[float] = None
52+
#: Additional information about the prediction.
3653
info: str = ""
54+
#: If enabled, a list of `AnomalyPrediction` objects used to derive the
55+
#: aggregated prediction.
3756
agg_history: Optional[Iterable[AnomalyPrediction]] = None
3857

3958

4059
@dataclass(frozen=True)
4160
class AnomalyResult():
61+
"""A dataclass for the anomaly detection results."""
62+
#: The original input data.
4263
example: beam.Row
64+
#: The `AnomalyPrediction` object containing the prediction.
4365
prediction: AnomalyPrediction
4466

4567

4668
class ThresholdFn(abc.ABC):
69+
"""An abstract base class for threshold functions.
70+
71+
Args:
72+
normal_label: The integer label used to identify normal data. Defaults to 0.
73+
outlier_label: The integer label used to identify outlier data. Defaults to
74+
1.
75+
"""
4776
def __init__(self, normal_label: int = 0, outlier_label: int = 1):
4877
self._normal_label = normal_label
4978
self._outlier_label = outlier_label
5079

5180
@property
5281
@abc.abstractmethod
5382
def is_stateful(self) -> bool:
83+
"""Indicates whether the threshold function is stateful or not."""
5484
raise NotImplementedError
5585

5686
@property
5787
@abc.abstractmethod
5888
def threshold(self) -> Optional[float]:
89+
"""Retrieves the current threshold value, or None if not set."""
5990
raise NotImplementedError
6091

6192
@abc.abstractmethod
6293
def apply(self, score: Optional[float]) -> int:
94+
"""Applies the threshold function to a given score to classify it as
95+
normal or outlier.
96+
97+
Args:
98+
score: The outlier score generated from the detector (model).
99+
100+
Returns:
101+
The label assigned to the score, either `self._normal_label`
102+
or `self._outlier_label`.
103+
"""
63104
raise NotImplementedError
64105

65106

66107
class AggregationFn(abc.ABC):
108+
"""An abstract base class for aggregation functions."""
67109
@abc.abstractmethod
68110
def apply(
69111
self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction:
112+
"""Applies the aggregation function to an iterable of predictions, either on
113+
their outlier scores or labels.
114+
115+
Args:
116+
predictions: An Iterable of `AnomalyPrediction` objects to aggregate.
117+
118+
Returns:
119+
An `AnomalyPrediction` object containing the aggregated result.
120+
"""
70121
raise NotImplementedError
71122

72123

73124
class AnomalyDetector(abc.ABC):
125+
"""An abstract base class for anomaly detectors.
126+
127+
Args:
128+
model_id: The ID of the detector (model). Defaults to the value of the
129+
`spec_type` attribute, or 'unknown' if not set.
130+
features: An Iterable of strings representing the names of the input
131+
features in the `beam.Row`.
132+
target: The name of the target field in the `beam.Row`.
133+
threshold_criterion: An optional `ThresholdFn` to apply to the outlier score
134+
and yield a label.
135+
"""
74136
def __init__(
75137
self,
76138
model_id: Optional[str] = None,
@@ -79,36 +141,71 @@ def __init__(
79141
threshold_criterion: Optional[ThresholdFn] = None,
80142
**kwargs):
81143
self._model_id = model_id if model_id is not None else getattr(
82-
self, '_key', 'unknown')
144+
self, 'spec_type', 'unknown')
83145
self._features = features
84146
self._target = target
85147
self._threshold_criterion = threshold_criterion
86148

87149
@abc.abstractmethod
88150
def learn_one(self, x: beam.Row) -> None:
151+
"""Trains the detector on a single data instance.
152+
153+
Args:
154+
x: A `beam.Row` representing the data instance.
155+
"""
89156
raise NotImplementedError
90157

91158
@abc.abstractmethod
92159
def score_one(self, x: beam.Row) -> float:
160+
"""Scores a single data instance for anomalies.
161+
162+
Args:
163+
x: A `beam.Row` representing the data instance.
164+
165+
Returns:
166+
The outlier score as a float.
167+
"""
93168
raise NotImplementedError
94169

95170

96171
class EnsembleAnomalyDetector(AnomalyDetector):
172+
"""An abstract base class for an ensemble of anomaly (sub-)detectors.
173+
174+
Args:
175+
sub_detectors: A list of `AnomalyDetector` objects used in this ensemble model.
176+
aggregation_strategy: An optional `AggregationFn` to apply to the
177+
predictions from all sub-detectors and yield an aggregated result.
178+
model_id: Inherited from `AnomalyDetector`, except that it defaults to the value of the `spec_type` attribute, or 'custom' if not set.
179+
features: Inherited from `AnomalyDetector`.
180+
target: Inherited from `AnomalyDetector`.
181+
threshold_criterion: Inherited from `AnomalyDetector`.
182+
"""
97183
def __init__(
98184
self,
99185
sub_detectors: Optional[List[AnomalyDetector]] = None,
100186
aggregation_strategy: Optional[AggregationFn] = None,
101187
**kwargs):
102188
if "model_id" not in kwargs or kwargs["model_id"] is None:
103-
kwargs["model_id"] = getattr(self, '_key', 'custom')
189+
kwargs["model_id"] = getattr(self, 'spec_type', 'custom')
104190

105191
super().__init__(**kwargs)
106192

107193
self._aggregation_strategy = aggregation_strategy
108194
self._sub_detectors = sub_detectors
109195

110196
def learn_one(self, x: beam.Row) -> None:
197+
"""Inherited from `AnomalyDetector.learn_one`.
198+
199+
This method is never called during ensemble detector training. The training
200+
process is done on each sub-detector independently and in parallel.
201+
"""
111202
raise NotImplementedError
112203

113204
def score_one(self, x: beam.Row) -> float:
205+
"""Inherited from `AnomalyDetector.score_one`.
206+
207+
This method is never called during ensemble detector scoring. The scoring
208+
process is done on each sub-detector independently and in parallel, and then
209+
the results are aggregated in the pipeline.
210+
"""
114211
raise NotImplementedError

0 commit comments

Comments
 (0)