15
15
# limitations under the License.
16
16
#
17
17
18
- """Base classes for anomaly detection"""
18
+ """
19
+ Base classes for anomaly detection
20
+ """
19
21
from __future__ import annotations
20
22
21
23
import abc
26
28
27
29
import apache_beam as beam
28
30
31
+ __all__ = [
32
+ "AnomalyPrediction" ,
33
+ "AnomalyResult" ,
34
+ "ThresholdFn" ,
35
+ "AggregationFn" ,
36
+ "AnomalyDetector" ,
37
+ "EnsembleAnomalyDetector"
38
+ ]
39
+
29
40
30
41
@dataclass (frozen = True )
31
42
class AnomalyPrediction ():
43
+ """A dataclass for anomaly detection predictions."""
44
+ #: The ID of detector (model) that generates the prediction.
32
45
model_id : Optional [str ] = None
46
+ #: The outlier score resulting from applying the detector to the input data.
33
47
score : Optional [float ] = None
48
+ #: The outlier label (normal or outlier) derived from the outlier score.
34
49
label : Optional [int ] = None
50
+ #: The threshold used to determine the label.
35
51
threshold : Optional [float ] = None
52
+ #: Additional information about the prediction.
36
53
info : str = ""
54
+ #: If enabled, a list of `AnomalyPrediction` objects used to derive the
55
+ #: aggregated prediction.
37
56
agg_history : Optional [Iterable [AnomalyPrediction ]] = None
38
57
39
58
40
59
@dataclass (frozen = True )
41
60
class AnomalyResult ():
61
+ """A dataclass for the anomaly detection results"""
62
+ #: The original input data.
42
63
example : beam .Row
64
+ #: The `AnomalyPrediction` object containing the prediction.
43
65
prediction : AnomalyPrediction
44
66
45
67
46
68
class ThresholdFn (abc .ABC ):
69
+ """An abstract base class for threshold functions.
70
+
71
+ Args:
72
+ normal_label: The integer label used to identify normal data. Defaults to 0.
73
+ outlier_label: The integer label used to identify outlier data. Defaults to
74
+ 1.
75
+ """
47
76
def __init__ (self , normal_label : int = 0 , outlier_label : int = 1 ):
48
77
self ._normal_label = normal_label
49
78
self ._outlier_label = outlier_label
50
79
51
80
@property
52
81
@abc .abstractmethod
53
82
def is_stateful (self ) -> bool :
83
+ """Indicates whether the threshold function is stateful or not."""
54
84
raise NotImplementedError
55
85
56
86
@property
57
87
@abc .abstractmethod
58
88
def threshold (self ) -> Optional [float ]:
89
+ """Retrieves the current threshold value, or None if not set."""
59
90
raise NotImplementedError
60
91
61
92
@abc .abstractmethod
62
93
def apply (self , score : Optional [float ]) -> int :
94
+ """Applies the threshold function to a given score to classify it as
95
+ normal or outlier.
96
+
97
+ Args:
98
+ score: The outlier score generated from the detector (model).
99
+
100
+ Returns:
101
+ The label assigned to the score, either `self._normal_label`
102
+ or `self._outlier_label`
103
+ """
63
104
raise NotImplementedError
64
105
65
106
66
107
class AggregationFn (abc .ABC ):
108
+ """An abstract base class for aggregation functions."""
67
109
@abc .abstractmethod
68
110
def apply (
69
111
self , predictions : Iterable [AnomalyPrediction ]) -> AnomalyPrediction :
112
+ """Applies the aggregation function to an iterable of predictions, either on
113
+ their outlier scores or labels.
114
+
115
+ Args:
116
+ predictions: An Iterable of `AnomalyPrediction` objects to aggregate.
117
+
118
+ Returns:
119
+ An `AnomalyPrediction` object containing the aggregated result.
120
+ """
70
121
raise NotImplementedError
71
122
72
123
73
124
class AnomalyDetector (abc .ABC ):
125
+ """An abstract base class for anomaly detectors.
126
+
127
+ Args:
128
+ model_id: The ID of detector (model). Defaults to the value of the
129
+ `spec_type` attribute, or 'unknown' if not set.
130
+ features: An Iterable of strings representing the names of the input
131
+ features in the `beam.Row`
132
+ target: The name of the target field in the `beam.Row`.
133
+ threshold_criterion: An optional `ThresholdFn` to apply to the outlier score
134
+ and yield a label.
135
+ """
74
136
def __init__ (
75
137
self ,
76
138
model_id : Optional [str ] = None ,
@@ -79,36 +141,71 @@ def __init__(
79
141
threshold_criterion : Optional [ThresholdFn ] = None ,
80
142
** kwargs ):
81
143
self ._model_id = model_id if model_id is not None else getattr (
82
- self , '_key ' , 'unknown' )
144
+ self , 'spec_type ' , 'unknown' )
83
145
self ._features = features
84
146
self ._target = target
85
147
self ._threshold_criterion = threshold_criterion
86
148
87
149
@abc .abstractmethod
88
150
def learn_one (self , x : beam .Row ) -> None :
151
+ """Trains the detector on a single data instance.
152
+
153
+ Args:
154
+ x: A `beam.Row` representing the data instance.
155
+ """
89
156
raise NotImplementedError
90
157
91
158
@abc .abstractmethod
92
159
def score_one (self , x : beam .Row ) -> float :
160
+ """Scores a single data instance for anomalies.
161
+
162
+ Args:
163
+ x: A `beam.Row` representing the data instance.
164
+
165
+ Returns:
166
+ The outlier score as a float.
167
+ """
93
168
raise NotImplementedError
94
169
95
170
96
171
class EnsembleAnomalyDetector (AnomalyDetector ):
172
+ """An abstract base class for an ensemble of anomaly (sub-)detectors.
173
+
174
+ Args:
175
+ sub_detectors: A List of `AnomalyDetector` used in this ensemble model.
176
+ aggregation_strategy: An optional `AggregationFn` to apply to the
177
+ predictions from all sub-detectors and yield an aggregated result.
178
+ model_id: Inherited from `AnomalyDetector`.
179
+ features: Inherited from `AnomalyDetector`.
180
+ target: Inherited from `AnomalyDetector`.
181
+ threshold_criterion: Inherited from `AnomalyDetector`.
182
+ """
97
183
def __init__ (
98
184
self ,
99
185
sub_detectors : Optional [List [AnomalyDetector ]] = None ,
100
186
aggregation_strategy : Optional [AggregationFn ] = None ,
101
187
** kwargs ):
102
188
if "model_id" not in kwargs or kwargs ["model_id" ] is None :
103
- kwargs ["model_id" ] = getattr (self , '_key ' , 'custom' )
189
+ kwargs ["model_id" ] = getattr (self , 'spec_type ' , 'custom' )
104
190
105
191
super ().__init__ (** kwargs )
106
192
107
193
self ._aggregation_strategy = aggregation_strategy
108
194
self ._sub_detectors = sub_detectors
109
195
110
196
def learn_one (self , x : beam .Row ) -> None :
197
+ """Inherited from `AnomalyDetector.learn_one`.
198
+
199
+ This method is never called during ensemble detector training. The training
200
+ process is done on each sub-detector independently and in parallel.
201
+ """
111
202
raise NotImplementedError
112
203
113
204
def score_one (self , x : beam .Row ) -> float :
205
+ """Inherited from `AnomalyDetector.score_one`.
206
+
207
+ This method is never called during ensemble detector scoring. The scoring
208
+ process is done on sub-detector independently and in parallel, and then
209
+ the results are aggregated in the pipeline.
210
+ """
114
211
raise NotImplementedError
0 commit comments