
Commit 9b7b5aa

Rmspe test stat (#22)
* Minor change in README to fix guidance for developers (#18)
* Noise transform (#19)
* Add noise transformation that apply perturbations on treatment
* Formatting
* Add docstring
* Fix linting
* Add tests to check perturbation impact and randomness over timepoints
* bump version (#20)
* Initial implementation of RMSPE
* Add TestResultFrame parent class for test results
* Add test for RMSPE
* Add doc string
* Fix linting
* Update src/causal_validation/validation/rmspe.py
  Co-authored-by: Thomas Pinder <[email protected]>
* Fix typo
---------
Co-authored-by: Thomas Pinder <[email protected]>
1 parent b7edd4e commit 9b7b5aa

14 files changed: +557 -42 lines

README.md (+1 -1)

@@ -76,4 +76,4 @@ in your terminal.
 1. Follow steps 1-3 from `For Users`
 2. Create a hatch environment `hatch env create`
 3. Open a hatch shell `hatch shell`
-4. Validate your installation by running `hatch run tests:test`
+4. Validate your installation by running `hatch run dev:test`

src/causal_validation/__about__.py (+1 -1)

@@ -1,3 +1,3 @@
-__version__ = "0.0.7"
+__version__ = "0.0.8"
 
 __all__ = ["__version__"]

src/causal_validation/data.py (+11 -2)

@@ -27,6 +27,7 @@ class Dataset:
     yte: Float[np.ndarray, "M 1"]
     _start_date: dt.date
     counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
+    synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
     _name: str = None
 
     def to_df(
@@ -151,7 +152,13 @@ def drop_unit(self, idx: int) -> Dataset:
         Xtr = np.delete(self.Xtr, [idx], axis=1)
         Xte = np.delete(self.Xte, [idx], axis=1)
         return Dataset(
-            Xtr, Xte, self.ytr, self.yte, self._start_date, self.counterfactual
+            Xtr,
+            Xte,
+            self.ytr,
+            self.yte,
+            self._start_date,
+            self.counterfactual,
+            self.synthetic,
         )
 
     def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -204,4 +211,6 @@ def reassign_treatment(
 ) -> Dataset:
     Xtr = data.Xtr
     Xte = data.Xte
-    return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+    return Dataset(
+        Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+    )
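For orientation, a minimal sketch of constructing a Dataset with the new optional synthetic field. Only the positional order (Xtr, Xte, ytr, yte, start date, counterfactual, synthetic) comes from the calls in drop_unit and reassign_treatment above; the array shapes, the start date, and the absence of extra validation on construction are illustrative assumptions.

import datetime as dt

import numpy as np

from causal_validation.data import Dataset

# Illustrative shapes: 60 pre-treatment timepoints, 20 post-treatment, 5 control units.
Xtr = np.random.normal(size=(60, 5))
ytr = np.random.normal(size=(60, 1))
Xte = np.random.normal(size=(20, 5))
yte = np.random.normal(size=(20, 1))

# `synthetic` is optional and sits directly after `counterfactual`,
# matching the positional calls in `drop_unit` and `reassign_treatment`.
data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), counterfactual=None, synthetic=None)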

src/causal_validation/models.py (+9 -1)

@@ -15,14 +15,15 @@
 class Result:
     effect: Effect
     counterfactual: Float[NPArray, "N 1"]
+    synthetic: Float[NPArray, "N 1"]
     observed: Float[NPArray, "N 1"]
 
 
 @dataclass
 class AZCausalWrapper:
     model: Estimator
     error_estimator: tp.Optional[Error] = None
-    _az_result: Result = None
+    _az_result: _Result = None
 
     def __post_init__(self):
         self._model_name = self.model.__class__.__name__
@@ -37,6 +38,7 @@ def __call__(self, data: Dataset, **kwargs) -> Result:
         res = Result(
             effect=result.effect,
             counterfactual=self.counterfactual,
+            synthetic=self.synthetic,
             observed=self.observed,
         )
         return res
@@ -47,6 +49,12 @@ def counterfactual(self) -> Float[NPArray, "N 1"]:
         c_factual = df.loc[:, "CF"].values.reshape(-1, 1)
         return c_factual
 
+    @property
+    def synthetic(self) -> Float[NPArray, "N 1"]:
+        df = self._az_result.effect.by_time
+        synth_control = df.loc[:, "C"].values.reshape(-1, 1)
+        return synth_control
+
     @property
     def observed(self) -> Float[NPArray, "N 1"]:
         df = self._az_result.effect.by_time
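A hedged sketch of how the new synthetic series reaches callers. The wrapper call signature and the Result fields come from the diff above; the azcausal estimator import path and the toy data are assumptions.

import datetime as dt

import numpy as np
from azcausal.estimators.panel.sdid import SDID  # assumed azcausal import path

from causal_validation.data import Dataset
from causal_validation.models import AZCausalWrapper

# Toy dataset with illustrative shapes (see the data.py sketch above).
data = Dataset(
    np.random.normal(size=(60, 5)),  # Xtr
    np.random.normal(size=(20, 5)),  # Xte
    np.random.normal(size=(60, 1)),  # ytr
    np.random.normal(size=(20, 1)),  # yte
    dt.date(2023, 1, 1),
)

wrapper = AZCausalWrapper(model=SDID())
result = wrapper(data)

# Both series are read from azcausal's `effect.by_time` frame:
# the "CF" column becomes `counterfactual`, the new "C" column becomes `synthetic`.
print(result.counterfactual.shape, result.synthetic.shape)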
src/causal_validation/transforms/__init__.py (+2 -1)

@@ -1,4 +1,5 @@
+from causal_validation.transforms.noise import Noise
 from causal_validation.transforms.periodic import Periodic
 from causal_validation.transforms.trends import Trend
 
-__all__ = ["Trend", "Periodic"]
+__all__ = ["Trend", "Periodic", "Noise"]

src/causal_validation/transforms/base.py (+6 -2)

@@ -74,7 +74,9 @@ def apply_values(
         ytr = ytr + pre_intervention_vals[:, :1]
         Xte = Xte + post_intervention_vals[:, 1:]
         yte = yte + post_intervention_vals[:, :1]
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )
 
 
 @dataclass(kw_only=True)
@@ -91,4 +93,6 @@ def apply_values(
         ytr = ytr * pre_intervention_vals
         Xte = Xte * post_intervention_vals
         yte = yte * post_intervention_vals
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )
src/causal_validation/transforms/noise.py (new file, +30 lines)

from dataclasses import dataclass
from typing import Tuple

from jaxtyping import Float
import numpy as np
from scipy.stats import norm

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveTransform
from causal_validation.transforms.parameter import TimeVaryingParameter


@dataclass(kw_only=True)
class Noise(AdditiveTransform):
    """
    Transform the treatment by adding TimeVaryingParameter noise terms sampled from
    a specified sampling distribution. By default, the sampling distribution is
    Normal with 0 loc and 0.1 scale.
    """

    noise_dist: TimeVaryingParameter = TimeVaryingParameter(sampling_dist=norm(0, 0.1))
    _slots: Tuple[str] = ("noise_dist",)

    def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
        noise = np.zeros((data.n_timepoints, data.n_units + 1))
        noise_treatment = self.noise_dist.get_value(
            n_units=1, n_timepoints=data.n_timepoints
        ).reshape(-1)
        noise[:, 0] = noise_treatment
        return noise
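A short sketch of parameterising the new transform. The default noise_dist above is a Normal(0, 0.1) TimeVaryingParameter; the wider distribution and the toy dataset below are illustrative, and only get_values, which the file defines, is called directly.

import datetime as dt

import numpy as np
from scipy.stats import norm

from causal_validation.data import Dataset
from causal_validation.transforms import Noise
from causal_validation.transforms.parameter import TimeVaryingParameter

# Toy dataset with illustrative shapes (see the data.py sketch above).
data = Dataset(
    np.random.normal(size=(60, 5)),  # Xtr
    np.random.normal(size=(20, 5)),  # Xte
    np.random.normal(size=(60, 1)),  # ytr
    np.random.normal(size=(20, 1)),  # yte
    dt.date(2023, 1, 1),
)

# Override the default Normal(0, 0.1) noise with a wider sampling distribution.
noise = Noise(noise_dist=TimeVaryingParameter(sampling_dist=norm(loc=0.0, scale=0.5)))

# `get_values` returns an (n_timepoints, n_units + 1) array: column 0 carries the
# sampled perturbation for the treated unit, the control columns stay at zero.
values = noise.get_values(data)
assert np.allclose(values[:, 1:], 0.0)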

src/causal_validation/validation/placebo.py (+2 -19)

@@ -9,13 +9,11 @@
     Column,
     DataFrameSchema,
 )
-from rich import box
 from rich.progress import (
     Progress,
     ProgressBar,
     track,
 )
-from rich.table import Table
 from scipy.stats import ttest_1samp
 from tqdm import (
     tqdm,
@@ -30,6 +28,7 @@
     AZCausalWrapper,
     Result,
 )
+from causal_validation.validation.testing import TestResultFrame
 
 PlaceboSchema = DataFrameSchema(
     {
@@ -46,13 +45,12 @@
 
 
 @dataclass
-class PlaceboTestResult:
+class PlaceboTestResult(TestResultFrame):
     effects: tp.Dict[tp.Tuple[str, str], tp.List[Result]]
 
     def _model_to_df(
         self, model_name: str, dataset_name: str, effects: tp.List[Result]
     ) -> pd.DataFrame:
-        breakpoint()
         _effects = [e.effect.percentage().value for e in effects]
         _n_effects = len(_effects)
         expected_effect = np.mean(_effects)
@@ -79,21 +77,6 @@ def to_df(self) -> pd.DataFrame:
         PlaceboSchema.validate(df)
         return df
 
-    def summary(self, precision: int = 4) -> Table:
-        table = Table(show_header=True, box=box.MARKDOWN)
-        df = self.to_df()
-        numeric_cols = df.select_dtypes(include=[np.number])
-        df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
-
-        for column in df.columns:
-            table.add_column(str(column), style="magenta")
-
-        for _, value_list in enumerate(df.values.tolist()):
-            row = [str(x) for x in value_list]
-            table.add_row(*row)
-
-        return table
-
 
 @dataclass
 class PlaceboTest:
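The summary() table builder removed here is presumably superseded by the new TestResultFrame parent added in this PR ("Add TestResultFrame parent class for test results"). Assuming the method was moved rather than dropped, a hypothetical call site would be unchanged:

# Hypothetical: `placebo_result` is a PlaceboTestResult, and `summary` is assumed to be
# inherited from TestResultFrame rather than defined locally, as it was before this change.
table = placebo_result.summary(precision=2)  # rich Table with numeric columns rounded to 2 d.p.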
src/causal_validation/validation/rmspe.py (new file, +133 lines)

from dataclasses import dataclass
import typing as tp

from jaxtyping import Float
import numpy as np
import pandas as pd
from pandera import (
    Check,
    Column,
    DataFrameSchema,
)
from rich import box
from rich.progress import (
    Progress,
    ProgressBar,
    track,
)

from causal_validation.validation.placebo import PlaceboTest
from causal_validation.validation.testing import (
    RMSPETestStatistic,
    TestResult,
    TestResultFrame,
)

RMSPESchema = DataFrameSchema(
    {
        "Model": Column(str),
        "Dataset": Column(str),
        "Test statistic": Column(float, coerce=True),
        "p-value": Column(
            float,
            checks=[
                Check.greater_than_or_equal_to(0.0),
                Check.less_than_or_equal_to(1.0),
            ],
            coerce=True,
        ),
    }
)


@dataclass
class RMSPETestResult(TestResultFrame):
    """
    A subclass of TestResultFrame, RMSPETestResult stores test statistics and p-value
    for the treated unit. Test statistics for pseudo treatment units are also stored.
    """

    treatment_test_results: tp.Dict[tp.Tuple[str, str], TestResult]
    pseudo_treatment_test_statistics: tp.Dict[tp.Tuple[str, str], tp.List[Float]]

    def to_df(self) -> pd.DataFrame:
        dfs = []
        for (model, dataset), test_results in self.treatment_test_results.items():
            result = {
                "Model": model,
                "Dataset": dataset,
                "Test statistic": test_results.test_statistic,
                "p-value": test_results.p_value,
            }
            df = pd.DataFrame([result])
            dfs.append(df)
        df = pd.concat(dfs)
        RMSPESchema.validate(df)
        return df


@dataclass
class RMSPETest(PlaceboTest):
    """
    A subclass of PlaceboTest calculates RMSPE as test statistic for all units.
    Given the RMSPE test stats, p-value for actual treatment is calculated.
    """

    def execute(self, verbose: bool = True) -> RMSPETestResult:
        treatment_results, pseudo_treatment_results = {}, {}
        datasets = self.dataset_dict
        n_datasets = len(datasets)
        n_control = sum([d.n_units for d in datasets.values()])
        rmspe = RMSPETestStatistic()
        with Progress(disable=not verbose) as progress:
            model_task = progress.add_task(
                "[red]Models", total=len(self.models), visible=verbose
            )
            data_task = progress.add_task(
                "[blue]Datasets", total=n_datasets, visible=verbose
            )
            unit_task = progress.add_task(
                f"[green]Treatment and Control Units",
                total=n_control + 1,
                visible=verbose,
            )
            for data_name, dataset in datasets.items():
                progress.update(data_task, advance=1)
                for model in self.models:
                    progress.update(unit_task, advance=1)
                    treatment_result = model(dataset)
                    treatment_idx = dataset.ytr.shape[0]
                    treatment_test_stat = rmspe(
                        dataset,
                        treatment_result.counterfactual,
                        treatment_result.synthetic,
                        treatment_idx,
                    )
                    progress.update(model_task, advance=1)
                    placebo_test_stats = []
                    for i in range(dataset.n_units):
                        progress.update(unit_task, advance=1)
                        placebo_data = dataset.to_placebo_data(i)
                        result = model(placebo_data)
                        placebo_test_stats.append(
                            rmspe(
                                placebo_data,
                                result.counterfactual,
                                result.synthetic,
                                treatment_idx,
                            )
                        )
                    pval_idx = 1
                    for p_stat in placebo_test_stats:
                        pval_idx += 1 if treatment_test_stat < p_stat else 0
                    pval = pval_idx / (n_control + 1)
                    treatment_results[(model._model_name, data_name)] = TestResult(
                        p_value=pval, test_statistic=treatment_test_stat
                    )
                    pseudo_treatment_results[(model._model_name, data_name)] = (
                        placebo_test_stats
                    )
        return RMSPETestResult(
            treatment_test_results=treatment_results,
            pseudo_treatment_test_statistics=pseudo_treatment_results,
        )
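A sketch of running the new test end to end. execute() and to_df() come from the file above, and the p-value is the treated unit's rank among all units: (1 + number of placebo RMSPE statistics exceeding the treated statistic) / (number of control units + 1). The constructor keywords and the azcausal import path below are assumptions, inferred only from the self.models and self.dataset_dict attributes that execute() reads via PlaceboTest.

from azcausal.estimators.panel.sdid import SDID  # assumed azcausal import path

from causal_validation.models import AZCausalWrapper
from causal_validation.validation.rmspe import RMSPETest

# `data` stands in for a causal_validation Dataset with a treated unit; the keyword
# names below are hypothetical and only mirror the `self.models` / `self.dataset_dict`
# attributes that execute() relies on from the PlaceboTest parent.
test = RMSPETest(models=[AZCausalWrapper(model=SDID())], datasets=[data])

result = test.execute(verbose=False)  # returns an RMSPETestResult
print(result.to_df())                 # columns: Model, Dataset, Test statistic, p-value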
