diff --git a/README.md b/README.md
index 6ba43fc..0e26876 100644
--- a/README.md
+++ b/README.md
@@ -76,4 +76,4 @@ in your terminal.
 1. Follow steps 1-3 from `For Users`
 2. Create a hatch environment `hatch env create`
 3. Open a hatch shell `hatch shell`
-4. Validate your installation by running `hatch run tests:test`
+4. Validate your installation by running `hatch run dev:test`
diff --git a/src/causal_validation/__about__.py b/src/causal_validation/__about__.py
index 30d4839..56844dc 100644
--- a/src/causal_validation/__about__.py
+++ b/src/causal_validation/__about__.py
@@ -1,3 +1,3 @@
-__version__ = "0.0.7"
+__version__ = "0.0.8"
 __all__ = ["__version__"]
diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py
index a25a94c..057acfa 100644
--- a/src/causal_validation/data.py
+++ b/src/causal_validation/data.py
@@ -27,6 +27,7 @@ class Dataset:
     yte: Float[np.ndarray, "M 1"]
     _start_date: dt.date
     counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
+    synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
     _name: str = None
 
     def to_df(
@@ -151,7 +152,13 @@ def drop_unit(self, idx: int) -> Dataset:
         Xtr = np.delete(self.Xtr, [idx], axis=1)
         Xte = np.delete(self.Xte, [idx], axis=1)
         return Dataset(
-            Xtr, Xte, self.ytr, self.yte, self._start_date, self.counterfactual
+            Xtr,
+            Xte,
+            self.ytr,
+            self.yte,
+            self._start_date,
+            self.counterfactual,
+            self.synthetic,
         )
 
     def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -204,4 +211,6 @@ def reassign_treatment(
 ) -> Dataset:
     Xtr = data.Xtr
     Xte = data.Xte
-    return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+    return Dataset(
+        Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+    )
diff --git a/src/causal_validation/models.py b/src/causal_validation/models.py
index 7015ae1..499ffe1 100644
--- a/src/causal_validation/models.py
+++ b/src/causal_validation/models.py
@@ -15,6 +15,7 @@
 class Result:
     effect: Effect
     counterfactual: Float[NPArray, "N 1"]
+    synthetic: Float[NPArray, "N 1"]
     observed: Float[NPArray, "N 1"]
 
 
@@ -22,7 +23,7 @@
 class AZCausalWrapper:
     model: Estimator
     error_estimator: tp.Optional[Error] = None
-    _az_result: Result = None
+    _az_result: _Result = None
 
     def __post_init__(self):
         self._model_name = self.model.__class__.__name__
@@ -37,6 +38,7 @@ def __call__(self, data: Dataset, **kwargs) -> Result:
         res = Result(
             effect=result.effect,
             counterfactual=self.counterfactual,
+            synthetic=self.synthetic,
             observed=self.observed,
         )
         return res
@@ -47,6 +49,12 @@ def counterfactual(self) -> Float[NPArray, "N 1"]:
         c_factual = df.loc[:, "CF"].values.reshape(-1, 1)
         return c_factual
 
+    @property
+    def synthetic(self) -> Float[NPArray, "N 1"]:
+        df = self._az_result.effect.by_time
+        synth_control = df.loc[:, "C"].values.reshape(-1, 1)
+        return synth_control
+
     @property
     def observed(self) -> Float[NPArray, "N 1"]:
         df = self._az_result.effect.by_time
diff --git a/src/causal_validation/transforms/__init__.py b/src/causal_validation/transforms/__init__.py
index 2707bfc..c5cf07f 100644
--- a/src/causal_validation/transforms/__init__.py
+++ b/src/causal_validation/transforms/__init__.py
@@ -1,4 +1,5 @@
+from causal_validation.transforms.noise import Noise
 from causal_validation.transforms.periodic import Periodic
 from causal_validation.transforms.trends import Trend
 
-__all__ = ["Trend", "Periodic"]
+__all__ = ["Trend", "Periodic", "Noise"]
diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py
index ea15109..6ef7a97 100644
--- a/src/causal_validation/transforms/base.py
+++ b/src/causal_validation/transforms/base.py
@@ -74,7 +74,9 @@ def apply_values(
         ytr = ytr + pre_intervention_vals[:, :1]
         Xte = Xte + post_intervention_vals[:, 1:]
         yte = yte + post_intervention_vals[:, :1]
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )
 
 
 @dataclass(kw_only=True)
@@ -91,4 +93,6 @@ def apply_values(
         ytr = ytr * pre_intervention_vals
         Xte = Xte * post_intervention_vals
         yte = yte * post_intervention_vals
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )
diff --git a/src/causal_validation/transforms/noise.py b/src/causal_validation/transforms/noise.py
new file mode 100644
index 0000000..cc3fe7c
--- /dev/null
+++ b/src/causal_validation/transforms/noise.py
@@ -0,0 +1,30 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+from jaxtyping import Float
+import numpy as np
+from scipy.stats import norm
+
+from causal_validation.data import Dataset
+from causal_validation.transforms.base import AdditiveTransform
+from causal_validation.transforms.parameter import TimeVaryingParameter
+
+
+@dataclass(kw_only=True)
+class Noise(AdditiveTransform):
+    """
+    Transform the treated unit by adding TimeVaryingParameter noise terms sampled
+    from a specified sampling distribution. By default, the sampling distribution
+    is a Normal with loc 0 and scale 0.1.
+    """
+
+    noise_dist: TimeVaryingParameter = TimeVaryingParameter(sampling_dist=norm(0, 0.1))
+    _slots: Tuple[str] = ("noise_dist",)
+
+    def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
+        noise = np.zeros((data.n_timepoints, data.n_units + 1))
+        noise_treatment = self.noise_dist.get_value(
+            n_units=1, n_timepoints=data.n_timepoints
+        ).reshape(-1)
+        noise[:, 0] = noise_treatment
+        return noise
diff --git a/src/causal_validation/validation/placebo.py b/src/causal_validation/validation/placebo.py
index 5e334d0..b8f7c36 100644
--- a/src/causal_validation/validation/placebo.py
+++ b/src/causal_validation/validation/placebo.py
@@ -9,13 +9,11 @@
     Column,
     DataFrameSchema,
 )
-from rich import box
 from rich.progress import (
     Progress,
     ProgressBar,
     track,
 )
-from rich.table import Table
 from scipy.stats import ttest_1samp
 from tqdm import (
     tqdm,
@@ -30,6 +28,7 @@
     AZCausalWrapper,
     Result,
 )
+from causal_validation.validation.testing import TestResultFrame
 
 PlaceboSchema = DataFrameSchema(
     {
@@ -46,13 +45,12 @@
 
 
 @dataclass
-class PlaceboTestResult:
+class PlaceboTestResult(TestResultFrame):
     effects: tp.Dict[tp.Tuple[str, str], tp.List[Result]]
 
     def _model_to_df(
         self, model_name: str, dataset_name: str, effects: tp.List[Result]
     ) -> pd.DataFrame:
-        breakpoint()
         _effects = [e.effect.percentage().value for e in effects]
         _n_effects = len(_effects)
         expected_effect = np.mean(_effects)
@@ -79,21 +77,6 @@ def to_df(self) -> pd.DataFrame:
         PlaceboSchema.validate(df)
         return df
 
-    def summary(self, precision: int = 4) -> Table:
-        table = Table(show_header=True, box=box.MARKDOWN)
-        df = self.to_df()
-        numeric_cols = df.select_dtypes(include=[np.number])
-        df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
-
-        for column in df.columns:
-            table.add_column(str(column), style="magenta")
-
-        for _, value_list in enumerate(df.values.tolist()):
-            row = [str(x) for x in value_list]
-            table.add_row(*row)
-
-        return table
-
 
 @dataclass
 class PlaceboTest:
diff --git a/src/causal_validation/validation/rmspe.py b/src/causal_validation/validation/rmspe.py
new file mode 100644
index 0000000..6b541ff
--- /dev/null
+++ b/src/causal_validation/validation/rmspe.py
@@ -0,0 +1,133 @@
+from dataclasses import dataclass
+import typing as tp
+
+from jaxtyping import Float
+import numpy as np
+import pandas as pd
+from pandera import (
+    Check,
+    Column,
+    DataFrameSchema,
+)
+from rich import box
+from rich.progress import (
+    Progress,
+    ProgressBar,
+    track,
+)
+
+from causal_validation.validation.placebo import PlaceboTest
+from causal_validation.validation.testing import (
+    RMSPETestStatistic,
+    TestResult,
+    TestResultFrame,
+)
+
+RMSPESchema = DataFrameSchema(
+    {
+        "Model": Column(str),
+        "Dataset": Column(str),
+        "Test statistic": Column(float, coerce=True),
+        "p-value": Column(
+            float,
+            checks=[
+                Check.greater_than_or_equal_to(0.0),
+                Check.less_than_or_equal_to(1.0),
+            ],
+            coerce=True,
+        ),
+    }
+)
+
+
+@dataclass
+class RMSPETestResult(TestResultFrame):
+    """
+    A subclass of TestResultFrame, RMSPETestResult stores the test statistic and p-value
+    for the treated unit. Test statistics for the pseudo-treatment units are also stored.
+    """
+
+    treatment_test_results: tp.Dict[tp.Tuple[str, str], TestResult]
+    pseudo_treatment_test_statistics: tp.Dict[tp.Tuple[str, str], tp.List[Float]]
+
+    def to_df(self) -> pd.DataFrame:
+        dfs = []
+        for (model, dataset), test_results in self.treatment_test_results.items():
+            result = {
+                "Model": model,
+                "Dataset": dataset,
+                "Test statistic": test_results.test_statistic,
+                "p-value": test_results.p_value,
+            }
+            df = pd.DataFrame([result])
+            dfs.append(df)
+        df = pd.concat(dfs)
+        RMSPESchema.validate(df)
+        return df
+
+
+@dataclass
+class RMSPETest(PlaceboTest):
+    """
+    A subclass of PlaceboTest that calculates the RMSPE test statistic for all units.
+    Given the RMSPE test statistics, the p-value for the actual treatment is calculated.
+ """ + + def execute(self, verbose: bool = True) -> RMSPETestResult: + treatment_results, pseudo_treatment_results = {}, {} + datasets = self.dataset_dict + n_datasets = len(datasets) + n_control = sum([d.n_units for d in datasets.values()]) + rmspe = RMSPETestStatistic() + with Progress(disable=not verbose) as progress: + model_task = progress.add_task( + "[red]Models", total=len(self.models), visible=verbose + ) + data_task = progress.add_task( + "[blue]Datasets", total=n_datasets, visible=verbose + ) + unit_task = progress.add_task( + f"[green]Treatment and Control Units", + total=n_control + 1, + visible=verbose, + ) + for data_name, dataset in datasets.items(): + progress.update(data_task, advance=1) + for model in self.models: + progress.update(unit_task, advance=1) + treatment_result = model(dataset) + treatment_idx = dataset.ytr.shape[0] + treatment_test_stat = rmspe( + dataset, + treatment_result.counterfactual, + treatment_result.synthetic, + treatment_idx, + ) + progress.update(model_task, advance=1) + placebo_test_stats = [] + for i in range(dataset.n_units): + progress.update(unit_task, advance=1) + placebo_data = dataset.to_placebo_data(i) + result = model(placebo_data) + placebo_test_stats.append( + rmspe( + placebo_data, + result.counterfactual, + result.synthetic, + treatment_idx, + ) + ) + pval_idx = 1 + for p_stat in placebo_test_stats: + pval_idx += 1 if treatment_test_stat < p_stat else 0 + pval = pval_idx / (n_control + 1) + treatment_results[(model._model_name, data_name)] = TestResult( + p_value=pval, test_statistic=treatment_test_stat + ) + pseudo_treatment_results[(model._model_name, data_name)] = ( + placebo_test_stats + ) + return RMSPETestResult( + treatment_test_results=treatment_results, + pseudo_treatment_test_statistics=pseudo_treatment_results, + ) diff --git a/src/causal_validation/validation/testing.py b/src/causal_validation/validation/testing.py index 0f50246..e0144a2 100644 --- a/src/causal_validation/validation/testing.py +++ b/src/causal_validation/validation/testing.py @@ -4,10 +4,37 @@ from jaxtyping import Float import numpy as np +import pandas as pd +from rich import box +from rich.table import Table from causal_validation.data import Dataset +@dataclass +class TestResultFrame: + """A parent class for test results""" + + @abc.abstractmethod + def to_df(self) -> pd.DataFrame: + raise NotImplementedError + + def summary(self, precision: int = 4) -> Table: + table = Table(show_header=True, box=box.MARKDOWN) + df = self.to_df() + numeric_cols = df.select_dtypes(include=[np.number]) + df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision) + + for column in df.columns: + table.add_column(str(column), style="magenta") + + for _, value_list in enumerate(df.values.tolist()): + row = [str(x) for x in value_list] + table.add_row(*row) + + return table + + @dataclass class TestResult: p_value: float @@ -21,41 +48,55 @@ def _compute( self, dataset: Dataset, counterfactual: Float[np.ndarray, "N 1"], + synthetic: tp.Optional[Float[np.ndarray, "M 1"]], treatment_index: int, - ) -> TestResult: + ) -> Float: raise NotImplementedError def __call__( self, observed: Float[np.ndarray, "N 1"], counterfactual: Float[np.ndarray, "N 1"], + synthetic: tp.Optional[Float[np.ndarray, "M 1"]], treatment_index: int, - ) -> TestResult: - return self._compute(observed, counterfactual, treatment_index) + ) -> Float: + return self._compute(observed, counterfactual, synthetic, treatment_index) @dataclass class RMSPETestStatistic(AbstractTestStatistic): + """ 
+    Given a dataset and treatment index, together with the counterfactual and
+    synthetic control for the treated unit, the RMSPE test statistic
+    is calculated.
+    """
+
+    @staticmethod
     def _compute(
-        self,
         dataset: Dataset,
         counterfactual: Float[np.ndarray, "N 1"],
+        synthetic: Float[np.ndarray, "N 1"],
         treatment_index: int,
-    ) -> TestResult:
+    ) -> Float:
         _, pre_observed = dataset.pre_intervention_obs
         _, post_observed = dataset.post_intervention_obs
-        pre_counterfactual, post_counterfactual = self._split_array(
+        _, post_counterfactual = RMSPETestStatistic._split_array(
             counterfactual, treatment_index
         )
-        pre_rmspe = self._rmspe(pre_observed, pre_counterfactual)
-        post_rmspe = self._rmspe(post_observed, post_counterfactual)
-        test_statistic = post_rmspe / pre_rmspe
+        pre_synthetic, _ = RMSPETestStatistic._split_array(synthetic, treatment_index)
+        pre_rmspe = RMSPETestStatistic._rmspe(pre_observed, pre_synthetic)
+        post_rmspe = RMSPETestStatistic._rmspe(post_observed, post_counterfactual)
+        if pre_rmspe == 0:
+            raise ZeroDivisionError("Error: pre intervention period MSPE is 0!")
+        else:
+            test_statistic = post_rmspe / pre_rmspe
+        return test_statistic
 
     @staticmethod
     def _rmspe(
-        observed: Float[np.ndarray, "N 1"], counterfactual: Float[np.ndarray, "N 1"]
+        observed: Float[np.ndarray, "N 1"], generated: Float[np.ndarray, "N 1"]
     ) -> float:
-        return np.sqrt(np.mean(np.square(observed - counterfactual)))
+        return np.sqrt(np.mean(np.square(observed - generated)))
 
     @staticmethod
     def _split_array(
diff --git a/tests/test_causal_validation/test_models.py b/tests/test_causal_validation/test_models.py
index feeb1ce..ca479a4 100644
--- a/tests/test_causal_validation/test_models.py
+++ b/tests/test_causal_validation/test_models.py
@@ -7,7 +7,7 @@
     JackKnife,
 )
 from azcausal.core.estimator import Estimator
-from azcausal.core.result import Result
+from azcausal.core.result import Result as _Result
 from azcausal.estimators.panel import (
     did,
     sdid,
@@ -19,7 +19,10 @@
 )
 import numpy as np
 
-from causal_validation.models import AZCausalWrapper
+from causal_validation.models import (
+    AZCausalWrapper,
+    Result,
+)
 from causal_validation.testing import (
     TestConstants,
     simulate_data,
@@ -49,15 +52,20 @@ def test_call(
     n_post_treatment: int,
     seed: int,
 ):
-    constancts = TestConstants(
+    constants = TestConstants(
         N_CONTROL=n_control,
         N_PRE_TREATMENT=n_pre_treatment,
         N_POST_TREATMENT=n_post_treatment,
     )
-    data = simulate_data(global_mean=10.0, seed=seed, constants=constancts)
+    data = simulate_data(global_mean=10.0, seed=seed, constants=constants)
     model = AZCausalWrapper(*model_error)
     result = model(data)
 
     assert isinstance(result, Result)
     assert isinstance(result.effect, Effect)
     assert not np.isnan(result.effect.value)
+    assert isinstance(model._az_result, _Result)
+    assert np.all(np.concatenate((data.ytr, data.yte), axis=0) == result.observed)
+    assert (
+        result.observed.shape == result.counterfactual.shape == result.synthetic.shape
+    )
diff --git a/tests/test_causal_validation/test_transforms/test_noise.py b/tests/test_causal_validation/test_transforms/test_noise.py
new file mode 100644
index 0000000..94608c3
--- /dev/null
+++ b/tests/test_causal_validation/test_transforms/test_noise.py
@@ -0,0 +1,127 @@
+from hypothesis import (
+    given,
+    settings,
+    strategies as st,
+)
+import numpy as np
+from scipy.stats import norm
+
+from causal_validation.testing import (
+    TestConstants,
+    simulate_data,
+)
+from causal_validation.transforms import (
+    Noise,
+    Trend,
+)
+from causal_validation.transforms.parameter import TimeVaryingParameter
+
+CONSTANTS = TestConstants()
+DEFAULT_SEED = 123
+GLOBAL_MEAN = 20
+STATES = [42, 123]
+
+
+def test_slot_type():
+    noise_transform = Noise()
+    assert isinstance(noise_transform.noise_dist, TimeVaryingParameter)
+
+
+def test_timepoints_randomness():
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
+
+    noise_transform = Noise()
+    noisy_data = noise_transform(base_data)
+
+    diff_tr = (noisy_data.ytr - base_data.ytr).reshape(-1)
+    diff_te = (noisy_data.yte - base_data.yte).reshape(-1)
+
+    assert np.all(diff_tr != diff_te)
+
+    diff_tr_permute = np.random.permutation(diff_tr)
+    diff_te_permute = np.random.permutation(diff_te)
+
+    assert not np.all(diff_tr == diff_tr_permute)
+    assert not np.all(diff_te == diff_te_permute)
+
+
+@given(
+    loc=st.floats(min_value=-5.0, max_value=5.0),
+    scale=st.floats(min_value=0.1, max_value=1.0),
+)
+@settings(max_examples=5)
+def test_base_transform(loc: float, scale: float):
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
+    noise_transform = Noise(
+        noise_dist=TimeVaryingParameter(sampling_dist=norm(loc, scale))
+    )
+    noisy_data = noise_transform(base_data)
+
+    assert np.all(noisy_data.Xtr == base_data.Xtr)
+    assert np.all(noisy_data.Xte == base_data.Xte)
+    assert np.all(noisy_data.ytr != base_data.ytr)
+    assert np.all(noisy_data.yte != base_data.yte)
+
+
+@given(
+    degree=st.integers(min_value=1, max_value=3),
+    coefficient=st.floats(min_value=-1.0, max_value=1.0),
+    intercept=st.floats(min_value=-1.0, max_value=1.0),
+)
+@settings(max_examples=5)
+def test_composite_transform(degree: int, coefficient: float, intercept: float):
+    trend_transform = Trend(degree=degree, coefficient=coefficient, intercept=intercept)
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
+    trendy_data = trend_transform(base_data)
+
+    noise_transform = Noise()
+    noisy_trendy_data = noise_transform(trendy_data)
+
+    assert np.all(noisy_trendy_data.Xtr == trendy_data.Xtr)
+    assert np.all(noisy_trendy_data.Xte == trendy_data.Xte)
+    assert np.all(noisy_trendy_data.ytr != trendy_data.ytr)
+    assert np.all(noisy_trendy_data.yte != trendy_data.yte)
+
+
+@given(
+    loc_large=st.floats(min_value=10.0, max_value=15.0),
+    loc_small=st.floats(min_value=-2.5, max_value=2.5),
+    scale_large=st.floats(min_value=10.0, max_value=15.0),
+    scale_small=st.floats(min_value=0.1, max_value=1.0),
+)
+@settings(max_examples=5)
+def test_perturbation_impact(
+    loc_large: float, loc_small: float, scale_large: float, scale_small: float
+):
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED)
+
+    noise_transform1 = Noise(
+        noise_dist=TimeVaryingParameter(sampling_dist=norm(loc_small, scale_small))
+    )
+    noise_transform2 = Noise(
+        noise_dist=TimeVaryingParameter(sampling_dist=norm(loc_small, scale_large))
+    )
+    noise_transform3 = Noise(
+        noise_dist=TimeVaryingParameter(sampling_dist=norm(loc_large, scale_small))
+    )
+
+    noise_transforms = [noise_transform1, noise_transform2, noise_transform3]
+
+    diff_tr_list, diff_te_list = [], []
+
+    for noise_transform in noise_transforms:
+        noisy_data = noise_transform(base_data)
+        diff_tr = noisy_data.ytr - base_data.ytr
+        diff_te = noisy_data.yte - base_data.yte
+        diff_tr_list.append(diff_tr)
+        diff_te_list.append(diff_te)
+
+    assert np.max(diff_tr_list[0]) < np.max(diff_tr_list[1])
+    assert np.min(diff_tr_list[0]) > np.min(diff_tr_list[1])
+    assert np.max(diff_tr_list[0]) < np.max(diff_tr_list[2])
+    assert np.min(diff_tr_list[0]) < np.min(diff_tr_list[2])
+
+    assert np.max(diff_te_list[0]) < np.max(diff_te_list[1])
+    assert np.min(diff_te_list[0]) > np.min(diff_te_list[1])
+    assert np.max(diff_te_list[0]) < np.max(diff_te_list[2])
+    assert np.min(diff_te_list[0]) < np.min(diff_te_list[2])
diff --git a/tests/test_causal_validation/test_validation/test_placebo.py b/tests/test_causal_validation/test_validation/test_placebo.py
index b48d058..858c5f3 100644
--- a/tests/test_causal_validation/test_validation/test_placebo.py
+++ b/tests/test_causal_validation/test_validation/test_placebo.py
@@ -23,6 +23,7 @@
     PlaceboTest,
     PlaceboTestResult,
 )
+from causal_validation.validation.testing import TestResultFrame
 
 
 def test_schema_coerce():
@@ -56,6 +57,7 @@ def test_placebo_test(
 
     # Check that the structure of result
     assert isinstance(result, PlaceboTestResult)
+    assert isinstance(result, TestResultFrame)
     for _, v in result.effects.items():
         assert len(v) == n_control
diff --git a/tests/test_causal_validation/test_validation/test_rmspe.py b/tests/test_causal_validation/test_validation/test_rmspe.py
new file mode 100644
index 0000000..1bc6b37
--- /dev/null
+++ b/tests/test_causal_validation/test_validation/test_rmspe.py
@@ -0,0 +1,169 @@
+import typing as tp
+
+from azcausal.estimators.panel.did import DID
+from azcausal.estimators.panel.sdid import SDID
+from hypothesis import (
+    given,
+    settings,
+    strategies as st,
+)
+import numpy as np
+import pandas as pd
+import pytest
+from rich.table import Table
+
+from causal_validation.effects import StaticEffect
+from causal_validation.models import AZCausalWrapper
+from causal_validation.testing import (
+    TestConstants,
+    simulate_data,
+)
+from causal_validation.transforms import Trend
+from causal_validation.validation.rmspe import (
+    RMSPESchema,
+    RMSPETest,
+    RMSPETestResult,
+)
+from causal_validation.validation.testing import (
+    RMSPETestStatistic,
+    TestResult,
+    TestResultFrame,
+)
+
+
+def test_schema_coerce():
+    df = RMSPESchema.example()
+    cols = df.columns
+    for col in cols:
+        if col not in ["Model", "Dataset"]:
+            df[col] = np.ceil((df[col]))
+    RMSPESchema.validate(df)
+
+
+@given(
+    global_mean=st.floats(min_value=0.0, max_value=10.0),
+    seed=st.integers(min_value=0, max_value=1000000),
+    n_control=st.integers(min_value=10, max_value=20),
+    cf_inflate=st.one_of(
+        st.floats(min_value=1e-10, max_value=2.0),
+        st.floats(min_value=-2.0, max_value=-1e-10),
+    ),
+    s_inflate=st.one_of(
+        st.floats(min_value=1e-10, max_value=2.0),
+        st.floats(min_value=-2.0, max_value=-1e-10),
+    ),
+)
+@settings(max_examples=10)
+def test_rmspe_test_stat(
+    global_mean: float, seed: int, n_control: int, cf_inflate: float, s_inflate: float
+):
+    # Simulate data
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=global_mean, seed=seed, constants=constants)
+    rmspe = RMSPETestStatistic()
+    counterfactual = np.concatenate((data.ytr, data.yte), axis=0) + cf_inflate
+    synthetic = counterfactual
+    assert rmspe(
+        data, counterfactual, synthetic, constants.N_PRE_TREATMENT
+    ) == pytest.approx(1.0)
+
+    synthetic = np.concatenate((data.ytr, data.yte), axis=0) + s_inflate
+    assert rmspe(
+        data, counterfactual, synthetic, constants.N_PRE_TREATMENT
+    ) == pytest.approx(abs(cf_inflate) / abs(s_inflate))
+
+    synthetic = np.concatenate((data.ytr, data.yte), axis=0)
+    with pytest.raises(
+        ZeroDivisionError, match="Error: pre intervention period MSPE is 0!"
+    ):
+        rmspe(data, counterfactual, synthetic, constants.N_PRE_TREATMENT)
+
+
+@given(
+    global_mean=st.floats(min_value=0.0, max_value=10.0),
+    effect=st.one_of(
+        st.floats(min_value=1.0, max_value=5.0),
+        st.floats(min_value=-5.0, max_value=-1.0),
+    ),
+    seed=st.integers(min_value=0, max_value=1000000),
+    n_control=st.integers(min_value=10, max_value=20),
+    model=st.sampled_from([DID(), SDID()]),
+)
+@settings(max_examples=10)
+def test_rmspe_test(
+    global_mean: float,
+    effect: float,
+    seed: int,
+    n_control: int,
+    model: tp.Union[DID, SDID],
+):
+    # Simulate data with a trend and effect
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=global_mean, seed=seed, constants=constants)
+    trend_term = Trend(degree=1, coefficient=0.1)
+    static_effect = StaticEffect(effect=effect)
+    data = static_effect(trend_term(data))
+
+    model = AZCausalWrapper(model)
+    result = RMSPETest(model, data).execute()
+
+    assert isinstance(result, RMSPETestResult)
+    assert isinstance(result, TestResultFrame)
+    assert set(result.treatment_test_results.keys()) == set(
+        result.pseudo_treatment_test_statistics.keys()
+    )
+
+    for k, v in result.treatment_test_results.items():
+        assert isinstance(v, TestResult)
+        assert len(result.pseudo_treatment_test_statistics[k]) == n_control
+
+    summary = result.to_df()
+    RMSPESchema.validate(summary)
+    assert isinstance(summary, pd.DataFrame)
+    assert summary.shape == (1, 4)
+    assert summary["p-value"].iloc[0] == pytest.approx(1.0 / (n_control + 1))
+
+    rich_summary = result.summary()
+    assert isinstance(rich_summary, Table)
+    n_rows = result.summary().row_count
+    assert n_rows == summary.shape[0]
+
+
+@pytest.mark.parametrize("n_control", [9, 10])
+def test_multiple_models(n_control: int):
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=20.0, seed=123, constants=constants)
+    trend_term = Trend(degree=1, coefficient=0.1)
+    data = trend_term(data)
+
+    model1 = AZCausalWrapper(DID())
+    model2 = AZCausalWrapper(SDID())
+    result = RMSPETest([model1, model2], data).execute()
+
+    result_df = result.to_df()
+    result_rich = result.summary()
+    assert result_df.shape == (2, 4)
+    assert result_df.shape[0] == result_rich.row_count
+    assert result_df["Model"].tolist() == ["DID", "SDID"]
+    for k, v in result.treatment_test_results.items():
+        assert isinstance(v, TestResult)
+        assert len(result.pseudo_treatment_test_statistics[k]) == n_control
+
+
+@given(
+    seeds=st.lists(
+        elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=5
+    )
+)
+@settings(max_examples=5)
+def test_multiple_datasets(seeds: tp.List[int]):
+    data = [simulate_data(global_mean=20.0, seed=s) for s in seeds]
+    n_data = len(data)
+
+    model = AZCausalWrapper(DID())
+    result = RMSPETest(model, data).execute()
+
+    result_df = result.to_df()
+    result_rich = result.summary()
+    assert result_df.shape == (n_data, 4)
+    assert result_df.shape[0] == result_rich.row_count
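
Usage sketch (not part of the patch): a minimal end-to-end example assembling the pieces introduced above — the Noise transform and the RMSPE-based placebo test — following the call patterns in the new tests. The distribution parameters, seed, and N_CONTROL value below are illustrative assumptions, not values taken from the diff.

    from azcausal.estimators.panel.did import DID
    from scipy.stats import norm

    from causal_validation.models import AZCausalWrapper
    from causal_validation.testing import TestConstants, simulate_data
    from causal_validation.transforms import Noise, Trend
    from causal_validation.transforms.parameter import TimeVaryingParameter
    from causal_validation.validation.rmspe import RMSPETest

    # Simulate a panel with 10 control units (illustrative) and add a linear trend.
    constants = TestConstants(N_CONTROL=10)
    data = simulate_data(global_mean=20.0, seed=123, constants=constants)
    data = Trend(degree=1, coefficient=0.1)(data)

    # Noise perturbs only the treated unit: get_values fills column 0 of the
    # additive term and leaves the control columns at zero.
    noise = Noise(noise_dist=TimeVaryingParameter(sampling_dist=norm(0.0, 0.5)))
    data = noise(data)

    # RMSPETest compares the treated unit's post/pre RMSPE ratio against every
    # placebo-treated control unit; the p-value is rank / (n_control + 1).
    model = AZCausalWrapper(DID())
    result = RMSPETest(model, data).execute()
    print(result.summary())  # rich Table via the new TestResultFrame.summary()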