Skip to content

Commit

Permalink
Fix vdid bug when providing weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian Blank committed Aug 16, 2024
1 parent ed6f6ec commit 40a167e
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions azcausal/estimators/panel/vdid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,17 @@ def group_by_index(dx):
return dx.groupby(list(dx.index.names))


def dot_by_columns(ds, columns, name, weight=None):
def dot_by_columns(ds, dim, name, weight=None):
if weight is None:
weight = dict()
counts = columns.map(len)
counts = dim.map(len)

avg = dict()
for k, v in columns.items():
for k, v in dim.items():
v = [e for e in v if e in ds.columns]

if k in weight:
if weight.get(k, None) is not None:
w = np.array([weight[k].get(e, 0.0) for e in v])
w = (w / np.abs(w).sum()) * np.abs(weight[k]).sum()

avg[k] = ds[v].values @ w
else:
avg[k] = np.sum(ds[v], axis=1) / counts[k]
Expand All @@ -110,7 +108,7 @@ def vdid_avg_by(dx, label, col, dim=None, weight=None):
if weight is None:
avg = group_by_index(dx.droplevel(col, axis='index')).sum().divide(counts, axis='index', level=label)
else:
avg = dot_by_columns(dx.droplevel(label, axis='index').unstack(col), dim, label, weight=weight).stack()
avg = dot_by_columns(dx.droplevel(label, axis='index').unstack(col).fillna(0.0), dim, label, weight=weight).stack()

return avg, counts

Expand Down Expand Up @@ -138,17 +136,34 @@ def vdid_take_by_column(dx, col):
return dx


def vdid_fix_weights(treatment, weights):
    """Restrict the weights to the units of a (resampled) assignment and re-normalize them.

    Parameters
    ----------
    treatment
        Mapping-like object (e.g. a ``pd.Series``) whose ``items()`` yield
        ``(group, units)`` pairs, where ``units`` is an iterable of unit labels.
    weights
        Object with a ``get(group, default)`` method (``dict`` or ``pd.Series``)
        returning a label-indexed ``pd.Series`` of unit weights for that group,
        or ``None`` if the group is unweighted. May itself be ``None`` when no
        weights were provided at all.

    Returns
    -------
    pd.Series
        One entry per group that has weights; each entry is that group's weight
        series restricted to the group's distinct units and scaled to sum to one.
        Groups without weights are omitted, so the result is empty when
        ``weights`` is ``None``.
    """
    ans = dict()

    # No weights at all: nothing to fix. Guarding here avoids calling `.get` on
    # None when an unweighted estimate is resampled.
    if weights is None:
        return pd.Series(ans)

    for k, x in treatment.items():
        weight = weights.get(k, None)

        if weight is not None:
            # A unit may occur more than once in `x`; keep one entry per unit.
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # result does not depend on hash ordering.
            x = list(dict.fromkeys(x))
            weight = weight.loc[x]

            # Re-normalize so the remaining weights sum to one.
            ans[k] = weight / weight.sum()

    return pd.Series(ans)


def vdid_jackknife():
def sample(treatment: pd.Series) -> Iterator[pd.Series]:
def sample(treatment: pd.Series, weights: pd.Series) -> Iterator[pd.Series]:
treat, contr = treatment[True], treatment[False]

if len(contr) > 1:
for i in range(len(contr)):
yield pd.Series({True: treat, False: np.delete(contr, i)})
treatment_mod = pd.Series({True: treat, False: np.delete(contr, i)})
weight_mod = vdid_fix_weights(treatment_mod, weights)
yield treatment_mod, weight_mod

if len(treat) > 1:
for i in range(len(treat)):
yield pd.Series({True: np.delete(treat, i), False: contr})
treatment_mod = pd.Series({True: np.delete(treat, i), False: contr})
weight_mod = vdid_fix_weights(treatment_mod, weights)
yield treatment_mod, weight_mod

def fit(dse: pd.DataFrame):
n = len(dse)
Expand All @@ -172,7 +187,7 @@ def sample(treatment: pd.Series) -> Iterator[pd.Series]:


def vdid_bootstrap(n_samples=1000, seed=1):
def sample(treatment: pd.Series) -> Iterator[pd.Series]:
def sample(treatment: pd.Series, weights: pd.Series) -> Iterator[pd.Series]:

H = dict()
for unit in treatment[True]:
Expand All @@ -193,7 +208,7 @@ def sample(treatment: pd.Series) -> Iterator[pd.Series]:

# if we have at least one treatment and one control
if tt.map(lambda x: len(x) > 0).all():
yield tt
yield tt, vdid_fix_weights(tt, weights)
cnt += 1

return sample, vdid_se
Expand Down Expand Up @@ -281,8 +296,8 @@ def vdid(dx: pd.DataFrame,
ci_sample, ci_fit = ci

# simulate based on the standard error method
ci_samp_dict = {sample: f(vdid_ratio(dot_by_columns(matrix, treatment_mod, randomize, weight=weight).stack(), ratio))
for sample, treatment_mod in enumerate(ci_sample(units))}
ci_samp_dict = {sample: f(vdid_ratio(dot_by_columns(matrix, treatment_mod, randomize, weight=weight_mod).stack(), ratio))
for sample, (treatment_mod, weight_mod) in enumerate(ci_sample(units, weight))}

ci_samp = pd.DataFrame(ci_samp_dict).rename_axis('sample', axis=1).stack()

Expand Down

0 comments on commit 40a167e

Please sign in to comment.