Skip to content

Commit 8bb8de1

Browse files
authored
Revert "Rmspe test stat (#22)"
This reverts commit 9b7b5aa.
1 parent 9b7b5aa commit 8bb8de1

File tree

14 files changed

+42
-557
lines changed

14 files changed

+42
-557
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,4 @@ in your terminal.
7676
1. Follow steps 1-3 from `For Users`
7777
2. Create a hatch environment `hatch env create`
7878
3. Open a hatch shell `hatch shell`
79-
4. Validate your installation by running `hatch run dev:test`
79+
4. Validate your installation by running `hatch run tests:test`

src/causal_validation/__about__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.0.8"
1+
__version__ = "0.0.7"
22

33
__all__ = ["__version__"]

src/causal_validation/data.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class Dataset:
2727
yte: Float[np.ndarray, "M 1"]
2828
_start_date: dt.date
2929
counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
30-
synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
3130
_name: str = None
3231

3332
def to_df(
@@ -152,13 +151,7 @@ def drop_unit(self, idx: int) -> Dataset:
152151
Xtr = np.delete(self.Xtr, [idx], axis=1)
153152
Xte = np.delete(self.Xte, [idx], axis=1)
154153
return Dataset(
155-
Xtr,
156-
Xte,
157-
self.ytr,
158-
self.yte,
159-
self._start_date,
160-
self.counterfactual,
161-
self.synthetic,
154+
Xtr, Xte, self.ytr, self.yte, self._start_date, self.counterfactual
162155
)
163156

164157
def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -211,6 +204,4 @@ def reassign_treatment(
211204
) -> Dataset:
212205
Xtr = data.Xtr
213206
Xte = data.Xte
214-
return Dataset(
215-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
216-
)
207+
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)

src/causal_validation/models.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
class Result:
1616
effect: Effect
1717
counterfactual: Float[NPArray, "N 1"]
18-
synthetic: Float[NPArray, "N 1"]
1918
observed: Float[NPArray, "N 1"]
2019

2120

2221
@dataclass
2322
class AZCausalWrapper:
2423
model: Estimator
2524
error_estimator: tp.Optional[Error] = None
26-
_az_result: _Result = None
25+
_az_result: Result = None
2726

2827
def __post_init__(self):
2928
self._model_name = self.model.__class__.__name__
@@ -38,7 +37,6 @@ def __call__(self, data: Dataset, **kwargs) -> Result:
3837
res = Result(
3938
effect=result.effect,
4039
counterfactual=self.counterfactual,
41-
synthetic=self.synthetic,
4240
observed=self.observed,
4341
)
4442
return res
@@ -49,12 +47,6 @@ def counterfactual(self) -> Float[NPArray, "N 1"]:
4947
c_factual = df.loc[:, "CF"].values.reshape(-1, 1)
5048
return c_factual
5149

52-
@property
53-
def synthetic(self) -> Float[NPArray, "N 1"]:
54-
df = self._az_result.effect.by_time
55-
synth_control = df.loc[:, "C"].values.reshape(-1, 1)
56-
return synth_control
57-
5850
@property
5951
def observed(self) -> Float[NPArray, "N 1"]:
6052
df = self._az_result.effect.by_time
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from causal_validation.transforms.noise import Noise
21
from causal_validation.transforms.periodic import Periodic
32
from causal_validation.transforms.trends import Trend
43

5-
__all__ = ["Trend", "Periodic", "Noise"]
4+
__all__ = ["Trend", "Periodic"]

src/causal_validation/transforms/base.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ def apply_values(
7474
ytr = ytr + pre_intervention_vals[:, :1]
7575
Xte = Xte + post_intervention_vals[:, 1:]
7676
yte = yte + post_intervention_vals[:, :1]
77-
return Dataset(
78-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
79-
)
77+
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
8078

8179

8280
@dataclass(kw_only=True)
@@ -93,6 +91,4 @@ def apply_values(
9391
ytr = ytr * pre_intervention_vals
9492
Xte = Xte * post_intervention_vals
9593
yte = yte * post_intervention_vals
96-
return Dataset(
97-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
98-
)
94+
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)

src/causal_validation/transforms/noise.py

-30
This file was deleted.

src/causal_validation/validation/placebo.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
Column,
1010
DataFrameSchema,
1111
)
12+
from rich import box
1213
from rich.progress import (
1314
Progress,
1415
ProgressBar,
1516
track,
1617
)
18+
from rich.table import Table
1719
from scipy.stats import ttest_1samp
1820
from tqdm import (
1921
tqdm,
@@ -28,7 +30,6 @@
2830
AZCausalWrapper,
2931
Result,
3032
)
31-
from causal_validation.validation.testing import TestResultFrame
3233

3334
PlaceboSchema = DataFrameSchema(
3435
{
@@ -45,12 +46,13 @@
4546

4647

4748
@dataclass
48-
class PlaceboTestResult(TestResultFrame):
49+
class PlaceboTestResult:
4950
effects: tp.Dict[tp.Tuple[str, str], tp.List[Result]]
5051

5152
def _model_to_df(
5253
self, model_name: str, dataset_name: str, effects: tp.List[Result]
5354
) -> pd.DataFrame:
55+
breakpoint()
5456
_effects = [e.effect.percentage().value for e in effects]
5557
_n_effects = len(_effects)
5658
expected_effect = np.mean(_effects)
@@ -77,6 +79,21 @@ def to_df(self) -> pd.DataFrame:
7779
PlaceboSchema.validate(df)
7880
return df
7981

82+
def summary(self, precision: int = 4) -> Table:
83+
table = Table(show_header=True, box=box.MARKDOWN)
84+
df = self.to_df()
85+
numeric_cols = df.select_dtypes(include=[np.number])
86+
df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
87+
88+
for column in df.columns:
89+
table.add_column(str(column), style="magenta")
90+
91+
for _, value_list in enumerate(df.values.tolist()):
92+
row = [str(x) for x in value_list]
93+
table.add_row(*row)
94+
95+
return table
96+
8097

8198
@dataclass
8299
class PlaceboTest:

src/causal_validation/validation/rmspe.py

-133
This file was deleted.

0 commit comments

Comments
 (0)