|
| 1 | +from hypothesis import ( |
| 2 | + given, |
| 3 | + settings, |
| 4 | + strategies as st, |
| 5 | +) |
| 6 | +import numpy as np |
| 7 | +from scipy.stats import norm |
| 8 | + |
| 9 | +from causal_validation.testing import ( |
| 10 | + TestConstants, |
| 11 | + simulate_data, |
| 12 | +) |
| 13 | +from causal_validation.transforms import ( |
| 14 | + Noise, |
| 15 | + Trend, |
| 16 | +) |
| 17 | +from causal_validation.transforms.parameter import TimeVaryingParameter |
| 18 | + |
# Shared test constants object; not referenced by the tests visible in this
# module section.
CONSTANTS = TestConstants()
# Second positional argument handed to `simulate_data` in every test below.
DEFAULT_SEED = 123
# First positional argument handed to `simulate_data` in every test below.
GLOBAL_MEAN = 20
# NOTE(review): not referenced by the tests visible here — presumably
# alternative random states; confirm it is still needed.
STATES = [42, 123]
| 23 | + |
| 24 | + |
def test_slot_type():
    """A default-constructed `Noise` exposes `noise_dist` as a `TimeVaryingParameter`."""
    default_noise = Noise()
    slot_value = default_noise.noise_dist
    assert isinstance(slot_value, TimeVaryingParameter)
| 28 | + |
| 29 | + |
def test_timepoints_randomness():
    """Check that a default `Noise` perturbs `ytr` and `yte` non-uniformly.

    The perturbation added to `ytr` must differ element-wise from the one
    added to `yte`, and shuffling either perturbation vector must change it
    (which cannot happen if every entry received the same offset).
    """
    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    noise_transform = Noise()
    noisy_data = noise_transform(base_data)

    # Flatten the added noise so both vectors are compared as 1-D arrays.
    diff_tr = (noisy_data.ytr - base_data.ytr).reshape(-1)
    diff_te = (noisy_data.yte - base_data.yte).reshape(-1)

    # NOTE(review): this element-wise comparison presumably relies on `ytr`
    # and `yte` having compatible shapes — confirm against `simulate_data`.
    assert np.all(diff_tr != diff_te)

    # Use a locally seeded generator instead of the global NumPy RNG so the
    # shuffles (and therefore any failure) are reproducible run to run.
    rng = np.random.default_rng(DEFAULT_SEED)
    diff_tr_permute = rng.permutation(diff_tr)
    diff_te_permute = rng.permutation(diff_te)

    assert not np.all(diff_tr == diff_tr_permute)
    assert not np.all(diff_te == diff_te_permute)
| 46 | + |
| 47 | + |
@given(
    loc=st.floats(min_value=-5.0, max_value=5.0),
    scale=st.floats(min_value=0.1, max_value=1.0),
)
@settings(max_examples=5)
def test_base_transform(loc: float, scale: float):
    """Check that `Noise` leaves `Xtr`/`Xte` untouched and changes `ytr`/`yte`.

    Args:
        loc: First argument of the `scipy.stats.norm` sampling distribution,
            drawn by hypothesis from [-5, 5].
        scale: Second argument of the sampling distribution, drawn by
            hypothesis from [0.1, 1].
    """
    reference = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    gaussian = norm(loc, scale)
    add_noise = Noise(noise_dist=TimeVaryingParameter(sampling_dist=gaussian))
    perturbed = add_noise(reference)

    # The X arrays must come through unchanged ...
    assert np.all(perturbed.Xtr == reference.Xtr)
    assert np.all(perturbed.Xte == reference.Xte)
    # ... while every entry of the y arrays must have moved.
    assert np.all(perturbed.ytr != reference.ytr)
    assert np.all(perturbed.yte != reference.yte)
| 64 | + |
| 65 | + |
@given(
    degree=st.integers(min_value=1, max_value=3),
    coefficient=st.floats(min_value=-1.0, max_value=1.0),
    intercept=st.floats(min_value=-1.0, max_value=1.0),
)
@settings(max_examples=5)
def test_composite_transform(degree: int, coefficient: float, intercept: float):
    """Check that default `Noise` applied after `Trend` only alters `ytr`/`yte`.

    Args:
        degree: `degree` passed to `Trend`, drawn by hypothesis from 1..3.
        coefficient: `coefficient` passed to `Trend`, drawn from [-1, 1].
        intercept: `intercept` passed to `Trend`, drawn from [-1, 1].
    """
    add_trend = Trend(degree=degree, coefficient=coefficient, intercept=intercept)
    reference = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
    with_trend = add_trend(reference)

    add_noise = Noise()
    with_trend_and_noise = add_noise(with_trend)

    # Noise is compared against the trended data, not the raw simulation.
    assert np.all(with_trend_and_noise.Xtr == with_trend.Xtr)
    assert np.all(with_trend_and_noise.Xte == with_trend.Xte)
    assert np.all(with_trend_and_noise.ytr != with_trend.ytr)
    assert np.all(with_trend_and_noise.yte != with_trend.yte)
| 84 | + |
| 85 | + |
@given(
    loc_large=st.floats(min_value=10.0, max_value=15.0),
    loc_small=st.floats(min_value=-2.5, max_value=2.5),
    scale_large=st.floats(min_value=10.0, max_value=15.0),
    scale_small=st.floats(min_value=0.1, max_value=1.0),
)
@settings(max_examples=5)
def test_perturbation_impact(
    loc_large: float, loc_small: float, scale_large: float, scale_small: float
):
    """Check that the noise distribution's parameters drive the perturbation.

    Three `Noise` transforms are applied to the same base data: a baseline
    (`loc_small`, `scale_small`), a wider one (`loc_small`, `scale_large`)
    and a shifted one (`loc_large`, `scale_small`). Relative to the baseline,
    the wider transform must produce a larger maximum and a smaller minimum
    perturbation, and the shifted transform must produce a larger maximum and
    a larger minimum. The same holds for both `ytr` and `yte`.

    Args:
        loc_large: Large first `norm` argument, drawn from [10, 15].
        loc_small: Small first `norm` argument, drawn from [-2.5, 2.5].
        scale_large: Large second `norm` argument, drawn from [10, 15].
        scale_small: Small second `norm` argument, drawn from [0.1, 1].
    """
    reference = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)

    # Order matters: baseline, wider spread, shifted location.
    dist_params = [
        (loc_small, scale_small),
        (loc_small, scale_large),
        (loc_large, scale_small),
    ]
    transforms = [
        Noise(noise_dist=TimeVaryingParameter(sampling_dist=norm(loc, scale)))
        for loc, scale in dist_params
    ]

    # Apply each transform exactly once, in the order listed above.
    perturbed = [transform(reference) for transform in transforms]
    deltas_tr = [noisy.ytr - reference.ytr for noisy in perturbed]
    deltas_te = [noisy.yte - reference.yte for noisy in perturbed]

    for baseline, wider, shifted in (deltas_tr, deltas_te):
        # A larger scale stretches the perturbation in both directions.
        assert np.max(baseline) < np.max(wider)
        assert np.min(baseline) > np.min(wider)
        # A larger loc moves the whole perturbation range upwards.
        assert np.max(baseline) < np.max(shifted)
        assert np.min(baseline) < np.min(shifted)
0 commit comments