diff --git a/azcausal/estimators/panel/vdid.py b/azcausal/estimators/panel/vdid.py
index 68c2418..075925f 100644
--- a/azcausal/estimators/panel/vdid.py
+++ b/azcausal/estimators/panel/vdid.py
@@ -82,19 +82,17 @@ def group_by_index(dx):
     return dx.groupby(list(dx.index.names))
 
 
-def dot_by_columns(ds, columns, name, weight=None):
+def dot_by_columns(ds, dim, name, weight=None):
     if weight is None:
         weight = dict()
 
-    counts = columns.map(len)
+    counts = dim.map(len)
 
     avg = dict()
-    for k, v in columns.items():
+    for k, v in dim.items():
         v = [e for e in v if e in ds.columns]
-        if k in weight:
+        if weight.get(k, None) is not None:
             w = np.array([weight[k].get(e, 0.0) for e in v])
-            w = (w / np.abs(w).sum()) * np.abs(weight[k]).sum()
-
             avg[k] = ds[v].values @ w
         else:
             avg[k] = np.sum(ds[v], axis=1) / counts[k]
@@ -110,7 +108,7 @@ def vdid_avg_by(dx, label, col, dim=None, weight=None):
     if weight is None:
         avg = group_by_index(dx.droplevel(col, axis='index')).sum().divide(counts, axis='index', level=label)
     else:
-        avg = dot_by_columns(dx.droplevel(label, axis='index').unstack(col), dim, label, weight=weight).stack()
+        avg = dot_by_columns(dx.droplevel(label, axis='index').unstack(col).fillna(0.0), dim, label, weight=weight).stack()
 
     return avg, counts
 
@@ -138,17 +136,34 @@ def vdid_take_by_column(dx, col):
     return dx
 
 
+def vdid_fix_weights(treatment, weights):
+    ans = dict()
+    for k, x in treatment.items():
+        weight = weights.get(k, None)
+
+        if weight is not None:
+            x = list(set(x))
+            weight = weight.loc[x]
+            ans[k] = weight / weight.sum()
+
+    return pd.Series(ans)
+
+
 def vdid_jackknife():
-    def sample(treatment: pd.Series) -> Iterator[pd.Series]:
+    def sample(treatment: pd.Series, weights: pd.Series) -> Iterator[pd.Series]:
         treat, contr = treatment[True], treatment[False]
 
         if len(contr) > 1:
             for i in range(len(contr)):
-                yield pd.Series({True: treat, False: np.delete(contr, i)})
+                treatment_mod = pd.Series({True: treat, False: np.delete(contr, i)})
+                weight_mod = vdid_fix_weights(treatment_mod, weights)
+                yield treatment_mod, weight_mod
 
         if len(treat) > 1:
             for i in range(len(treat)):
-                yield pd.Series({True: np.delete(treat, i), False: contr})
+                treatment_mod = pd.Series({True: np.delete(treat, i), False: contr})
+                weight_mod = vdid_fix_weights(treatment_mod, weights)
+                yield treatment_mod, weight_mod
 
     def fit(dse: pd.DataFrame):
         n = len(dse)
@@ -172,7 +187,7 @@ def sample(treatment: pd.Series) -> Iterator[pd.Series]:
 
 
 def vdid_bootstrap(n_samples=1000, seed=1):
-    def sample(treatment: pd.Series) -> Iterator[pd.Series]:
+    def sample(treatment: pd.Series, weights: pd.Series) -> Iterator[pd.Series]:
         H = dict()
 
         for unit in treatment[True]:
@@ -193,7 +208,7 @@ def sample(treatment: pd.Series) -> Iterator[pd.Series]:
 
             # if we have at least one treatment and one control
             if tt.map(lambda x: len(x) > 0).all():
-                yield tt
+                yield tt, vdid_fix_weights(tt, weights)
                 cnt += 1
 
     return sample, vdid_se
@@ -281,8 +296,8 @@ def vdid(dx: pd.DataFrame,
     ci_sample, ci_fit = ci
 
     # simulate based on the standard error method
-    ci_samp_dict = {sample: f(vdid_ratio(dot_by_columns(matrix, treatment_mod, randomize, weight=weight).stack(), ratio))
-                    for sample, treatment_mod in enumerate(ci_sample(units))}
+    ci_samp_dict = {sample: f(vdid_ratio(dot_by_columns(matrix, treatment_mod, randomize, weight=weight_mod).stack(), ratio))
+                    for sample, (treatment_mod, weight_mod) in enumerate(ci_sample(units, weight))}
 
     ci_samp = pd.DataFrame(ci_samp_dict).rename_axis('sample', axis=1).stack()
 
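
Note on the change above: the new vdid_fix_weights re-normalizes each group's unit weights after a resample (jackknife leave-one-out or bootstrap draw) so that the surviving weights again sum to one, and those re-normalized weights are threaded through dot_by_columns via weight_mod instead of the original full-panel weights. A minimal standalone sketch of that re-normalization step, assuming pandas/numpy and using made-up unit names and weights (illustration only, not the patched library code):

    # Sketch: re-normalize weights after dropping a control unit,
    # mirroring what vdid_fix_weights does per treatment/control group.
    import numpy as np
    import pandas as pd

    # hypothetical weights for the control group only
    weights = {False: pd.Series({'u1': 0.5, 'u2': 0.3, 'u3': 0.2})}

    # a leave-one-out sample of the control group (unit 'u3' dropped)
    treatment_mod = pd.Series({True: np.array(['t1']),
                               False: np.array(['u1', 'u2'])})

    ans = {}
    for k, units in treatment_mod.items():
        w = weights.get(k, None)
        if w is not None:
            w = w.loc[list(set(units))]   # keep only the surviving units
            ans[k] = w / w.sum()          # re-normalize so weights sum to 1

    print(ans[False])  # u1 -> 0.625, u2 -> 0.375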