Skip to content

Commit d473a72

Browse files
committed
Merge branch 'main' into rmspe-test-stat
2 parents b7edd4e + 8b44c71 commit d473a72

File tree

5 files changed

+161
-3
lines changed

5 files changed

+161
-3
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,4 @@ in your terminal.
7676
1. Follow steps 1-3 from `For Users`
7777
2. Create a hatch environment `hatch env create`
7878
3. Open a hatch shell `hatch shell`
79-
4. Validate your installation by running `hatch run tests:test`
79+
4. Validate your installation by running `hatch run dev:test`

src/causal_validation/__about__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.0.7"
1+
__version__ = "0.0.8"
22

33
__all__ = ["__version__"]
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from causal_validation.transforms.noise import Noise
12
from causal_validation.transforms.periodic import Periodic
23
from causal_validation.transforms.trends import Trend
34

4-
__all__ = ["Trend", "Periodic"]
5+
__all__ = ["Trend", "Periodic", "Noise"]
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from dataclasses import (
    dataclass,
    field,
)
from typing import Tuple

from jaxtyping import Float
import numpy as np
from scipy.stats import norm

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveTransform
from causal_validation.transforms.parameter import TimeVaryingParameter
11+
12+
13+
@dataclass(kw_only=True)
class Noise(AdditiveTransform):
    """
    Transform the treatment by adding TimeVaryingParameter noise terms sampled from
    a specified sampling distribution. By default, the sampling distribution is
    Normal with 0 loc and 0.1 scale.

    Attributes:
        noise_dist: Parameter whose ``get_value`` supplies one noise term per
            time point.
    """

    # Built per instance through ``default_factory``: a plain default would be
    # evaluated once at class-definition time and shared by every ``Noise``
    # object, and dataclasses reject unhashable defaults on newer Pythons.
    noise_dist: TimeVaryingParameter = field(
        default_factory=lambda: TimeVaryingParameter(sampling_dist=norm(0, 0.1))
    )
    _slots: Tuple[str] = ("noise_dist",)

    def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
        """
        Build the additive noise array for ``data``.

        Args:
            data: Dataset providing ``n_timepoints`` and ``n_units``.

        Returns:
            An array of shape ``(data.n_timepoints, data.n_units + 1)`` that is
            zero everywhere except column 0, which holds the sampled noise.
        """
        noise = np.zeros((data.n_timepoints, data.n_units + 1))
        # A single unit is requested and flattened so that the draw fits one
        # column of the output.
        noise_treatment = self.noise_dist.get_value(
            n_units=1, n_timepoints=data.n_timepoints
        ).reshape(-1)
        # Only column 0 is perturbed (the treatment column, per the class
        # docstring); every other column keeps a zero offset.
        noise[:, 0] = noise_treatment
        return noise
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from hypothesis import (
2+
given,
3+
settings,
4+
strategies as st,
5+
)
6+
import numpy as np
7+
from scipy.stats import norm
8+
9+
from causal_validation.testing import (
10+
TestConstants,
11+
simulate_data,
12+
)
13+
from causal_validation.transforms import (
14+
Noise,
15+
Trend,
16+
)
17+
from causal_validation.transforms.parameter import TimeVaryingParameter
18+
19+
# Shared fixtures for the Noise transform tests.
CONSTANTS = TestConstants()  # NOTE(review): not referenced by the tests visible here
20+
DEFAULT_SEED = 123  # second argument of every simulate_data call below
21+
GLOBAL_MEAN = 20  # first argument of every simulate_data call below
22+
STATES = [42, 123]  # NOTE(review): not referenced by the tests visible here
23+
24+
25+
def test_slot_type():
    """A default-constructed Noise exposes a TimeVaryingParameter as noise_dist."""
    default_dist = Noise().noise_dist
    assert isinstance(default_dist, TimeVaryingParameter)
28+
29+
30+
def test_timepoints_randomness():
    """
    Check that the default Noise transform adds offsets that vary over time.

    The offsets added to the training and test targets must differ element-wise,
    and neither offset vector may be invariant under a shuffle (which it would
    be if the same value were added at every time point).
    """
    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    noise_transform = Noise()
    noisy_data = noise_transform(base_data)

    diff_tr = (noisy_data.ytr - base_data.ytr).reshape(-1)
    diff_te = (noisy_data.yte - base_data.yte).reshape(-1)

    assert np.all(diff_tr != diff_te)

    # A locally seeded generator keeps the shuffle reproducible and leaves the
    # global NumPy random state untouched.
    rng = np.random.default_rng(DEFAULT_SEED)
    diff_tr_permute = rng.permutation(diff_tr)
    diff_te_permute = rng.permutation(diff_te)

    assert not np.all(diff_tr == diff_tr_permute)
    assert not np.all(diff_te == diff_te_permute)
46+
47+
48+
@given(
    loc=st.floats(min_value=-5.0, max_value=5.0),
    scale=st.floats(min_value=0.1, max_value=1.0),
)
@settings(max_examples=5)
def test_base_transform(loc: float, scale: float):
    """Noise with a custom Normal leaves the X arrays intact and changes every y entry."""
    original = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    sampler = TimeVaryingParameter(sampling_dist=norm(loc, scale))
    perturbed = Noise(noise_dist=sampler)(original)

    # The X arrays must come back unchanged.
    assert np.all(perturbed.Xtr == original.Xtr)
    assert np.all(perturbed.Xte == original.Xte)
    # Every y entry must have been moved by the noise.
    assert np.all(perturbed.ytr != original.ytr)
    assert np.all(perturbed.yte != original.yte)
64+
65+
66+
@given(
    degree=st.integers(min_value=1, max_value=3),
    coefficient=st.floats(min_value=-1.0, max_value=1.0),
    intercept=st.floats(min_value=-1.0, max_value=1.0),
)
@settings(max_examples=5)
def test_composite_transform(degree: int, coefficient: float, intercept: float):
    """Applying Noise after Trend keeps the X arrays and changes every y entry."""
    add_trend = Trend(degree=degree, coefficient=coefficient, intercept=intercept)
    original = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
    with_trend = add_trend(original)

    add_noise = Noise()
    with_trend_and_noise = add_noise(with_trend)

    # Compared against the trended data, the X arrays are unchanged...
    assert np.all(with_trend_and_noise.Xtr == with_trend.Xtr)
    assert np.all(with_trend_and_noise.Xte == with_trend.Xte)
    # ...while every y entry differs.
    assert np.all(with_trend_and_noise.ytr != with_trend.ytr)
    assert np.all(with_trend_and_noise.yte != with_trend.yte)
84+
85+
86+
@given(
    loc_large=st.floats(min_value=10.0, max_value=15.0),
    loc_small=st.floats(min_value=-2.5, max_value=2.5),
    scale_large=st.floats(min_value=10.0, max_value=15.0),
    scale_small=st.floats(min_value=0.1, max_value=1.0),
)
@settings(max_examples=5)
def test_perturbation_impact(
    loc_large: float, loc_small: float, scale_large: float, scale_small: float
):
    """
    Compare the offsets produced by three Normal noise settings.

    Relative to the (loc_small, scale_small) baseline, a larger scale must
    widen the range of offsets in both directions, and a larger loc must raise
    both the maximum and the minimum offset.
    """
    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    def _normal_noise(loc: float, scale: float):
        # Noise transform driven by a Normal(loc, scale) sampling distribution.
        return Noise(
            noise_dist=TimeVaryingParameter(sampling_dist=norm(loc, scale))
        )

    # Order matters: baseline, larger scale, larger loc.
    noise_transforms = [
        _normal_noise(loc_small, scale_small),
        _normal_noise(loc_small, scale_large),
        _normal_noise(loc_large, scale_small),
    ]

    noisy_versions = [transform(base_data) for transform in noise_transforms]
    diff_tr_list = [noisy.ytr - base_data.ytr for noisy in noisy_versions]
    diff_te_list = [noisy.yte - base_data.yte for noisy in noisy_versions]

    # Same four checks on the training offsets first, then on the test offsets.
    for baseline, wider, shifted in (diff_tr_list, diff_te_list):
        assert np.max(baseline) < np.max(wider)
        assert np.min(baseline) > np.min(wider)
        assert np.max(baseline) < np.max(shifted)
        assert np.min(baseline) < np.min(shifted)

0 commit comments

Comments
 (0)