
Commit 5d012d7

Merge pull request #50 from CompRhys/wrenformer-ensemble-preds
Wrenformer ensemble predictions
2 parents: 79417d5 + 4d0470f

File tree: 6 files changed, +231 −22 lines

aviary/utils.py (+7 −7)
@@ -769,8 +769,8 @@ def save_results_dict(


 def get_metrics(
-    targets: np.ndarray,
-    predictions: np.ndarray,
+    targets: np.ndarray | pd.Series,
+    predictions: np.ndarray | pd.Series,
     type: Literal["regression", "classification"],
     prec: int = 4,
 ) -> dict:
@@ -791,17 +791,17 @@ def get_metrics(
     metrics = {}

     if type == "regression":
-        metrics["mae"] = np.abs(targets - predictions).mean()
-        metrics["rmse"] = ((targets - predictions) ** 2).mean() ** 0.5
-        metrics["r2"] = r2_score(targets, predictions)
+        metrics["MAE"] = np.abs(targets - predictions).mean()
+        metrics["RMSE"] = ((targets - predictions) ** 2).mean() ** 0.5
+        metrics["R2"] = r2_score(targets, predictions)
     elif type == "classification":
         pred_labels = predictions.argmax(axis=1)

         metrics["accuracy"] = accuracy_score(targets, pred_labels)
         metrics["balanced_accuracy"] = balanced_accuracy_score(targets, pred_labels)
-        metrics["f1"] = f1_score(targets, pred_labels)
+        metrics["F1"] = f1_score(targets, pred_labels)
         class1_probas = predictions[:, 1]
-        metrics["rocauc"] = roc_auc_score(targets, class1_probas)
+        metrics["ROCAUC"] = roc_auc_score(targets, class1_probas)

     metrics = {key: round(float(val), prec) for key, val in metrics.items()}
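
Note: after this change, metric keys come back uppercase and pandas Series are accepted directly. A minimal usage sketch (toy numbers, not from this PR):

    import numpy as np
    import pandas as pd

    from aviary.utils import get_metrics

    targets = pd.Series([0.1, 0.5, 0.9])  # pd.Series now works alongside np.ndarray
    preds = np.array([0.12, 0.48, 0.95])
    print(get_metrics(targets, preds, "regression"))  # {'MAE': ..., 'RMSE': ..., 'R2': ...}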

aviary/wrenformer/data.py (+12 −6)
@@ -125,9 +125,9 @@ def get_composition_embedding(formula: str) -> Tensor:

 def df_to_in_mem_dataloader(
     df: pd.DataFrame,
-    target_col: str,
     input_col: str = "wyckoff",
-    id_col: str = "material_id",
+    target_col: str = None,
+    id_col: str = None,
     embedding_type: Literal["wyckoff", "composition"] = "wyckoff",
     device: str = None,
     **kwargs,
@@ -137,10 +137,12 @@ def df_to_in_mem_dataloader(

     Args:
         df (pd.DataFrame): Expected to have columns input_col, target_col, id_col.
-        target_col (str): Column name holding the target values.
         input_col (str): Column name holding the input values (Aflow Wyckoff labels or composition
             strings) from which initial embeddings will be constructed. Defaults to "wyckoff".
-        id_col (str): Column name holding material identifiers. Defaults to "material_id".
+        target_col (str): Column name holding the target values. Defaults to None. Only leave this
+            empty if making predictions since target tensor will be set to list of Nones.
+        id_col (str): Column name holding sample IDs. Defaults to None. If None, IDs will be
+            the dataframe index.
         embedding_type ('wyckoff' | 'composition'): Defaults to "wyckoff".
         device (str): torch.device to load tensors onto. Defaults to
             "cuda" if torch.cuda.is_available() else "cpu".
@@ -162,14 +164,18 @@ def df_to_in_mem_dataloader(
         if embedding_type == "wyckoff"
         else get_composition_embedding
     )
-    targets = torch.tensor(df[target_col], device=device)
+    targets = (
+        torch.tensor(df[target_col].to_numpy(), device=device)
+        if target_col in df
+        else np.empty(len(df))
+    )
     if targets.dtype == torch.bool:
         targets = targets.long()  # convert binary classification targets to 0 and 1
     inputs = np.empty(len(initial_embeddings), dtype=object)
     for idx, tensor in enumerate(initial_embeddings):
         inputs[idx] = tensor.to(device)

-    ids = df[id_col].to_numpy()
+    ids = (df[id_col] if id_col in df else df.index).to_numpy()
     data_loader = InMemoryDataLoader(
         [inputs, targets, ids], collate_fn=collate_batch, **kwargs
     )
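
Note: with target_col and id_col now optional, the loader can be built for pure inference. A sketch assuming a hypothetical dataframe df_test that has a "wyckoff" column but no target column:

    from aviary.wrenformer.data import df_to_in_mem_dataloader

    data_loader = df_to_in_mem_dataloader(df=df_test, input_col="wyckoff", batch_size=128)
    # targets are placeholder values (np.empty) and sample IDs fall back to df_test.index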

aviary/wrenformer/utils.py (+123 −2)
@@ -1,8 +1,22 @@
+from __future__ import annotations
+
 import json
 import time
 from contextlib import contextmanager
 from typing import Generator, Literal

+import pandas as pd
+import torch
+from tqdm import tqdm
+
+from aviary.core import BaseModelClass
+from aviary.utils import get_metrics
+from aviary.wrenformer.data import df_to_in_mem_dataloader
+from aviary.wrenformer.model import Wrenformer
+
+__author__ = "Janosh Riebesell"
+__date__ = "2022-05-10"
+

 def _int_keys(dct: dict) -> dict:
     # JSON stringifies all dict keys during serialization and does not revert
@@ -45,14 +59,14 @@ def merge_json_on_disk(
         pass

     def non_serializable_handler(obj: object) -> str:
-        # replace functions and classes in dct with string indicating a non-serializable type
+        # replace functions and classes in dct with string indicating it's a non-serializable type
         return f"<not serializable: {type(obj).__qualname__}>"

     with open(file_path, "w") as file:
         default = (
             non_serializable_handler if on_non_serializable == "annotate" else None
         )
-        json.dump(dct, file, default=default)
+        json.dump(dct, file, default=default, indent=2)


 @contextmanager
@@ -78,3 +92,110 @@ def print_walltime(
     finally:
         run_time = time.perf_counter() - start_time
         print(f"{end_desc} took {run_time:.2f} sec")
+
+
+def make_ensemble_predictions(
+    checkpoint_paths: list[str],
+    df: pd.DataFrame,
+    target_col: str = None,
+    input_col: str = "wyckoff",
+    model_class: type[BaseModelClass] = Wrenformer,
+    device: str = None,
+    print_metrics: bool = True,
+    warn_target_mismatch: bool = False,
+) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
+    """Make predictions using an ensemble of Wrenformer models.
+
+    Args:
+        checkpoint_paths (list[str]): File paths to model checkpoints created with torch.save().
+        df (pd.DataFrame): Dataframe to make predictions on. Will be returned with additional
+            columns holding model predictions (and uncertainties for robust models) for each
+            model checkpoint.
+        target_col (str): Column holding target values. Defaults to None. If None, will not print
+            performance metrics.
+        input_col (str, optional): Column holding input values. Defaults to 'wyckoff'.
+        device (str, optional): torch.device. Defaults to "cuda" if torch.cuda.is_available()
+            else "cpu".
+        print_metrics (bool, optional): Whether to print performance metrics. Defaults to True
+            if target_col is not None.
+        warn_target_mismatch (bool, optional): Whether to warn if target_col != target_name from
+            model checkpoint. Defaults to False.
+
+    Returns:
+        pd.DataFrame: Input dataframe with added columns for model and ensemble predictions. If
+            target_col is not None, returns a 2nd dataframe containing model and ensemble metrics.
+    """
+    # TODO: Add support for predicting all tasks a multi-task model was trained on. Currently
+    # only handles single targets.
+    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+    data_loader = df_to_in_mem_dataloader(
+        df=df,
+        target_col=target_col,
+        input_col=input_col,
+        batch_size=512,
+        embedding_type="wyckoff",
+    )
+
+    print(f"Predicting with {len(checkpoint_paths):,} model checkpoint(s)")
+
+    for idx, checkpoint_path in enumerate(tqdm(checkpoint_paths), 1):
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+        model_params = checkpoint["model_params"]
+        target_name, task_type = list(model_params["task_dict"].items())[0]
+        assert task_type in ("regression", "classification"), f"invalid {task_type = }"
+        if target_name != target_col and warn_target_mismatch:
+            print(
+                f"Warning: {target_col = } does not match {target_name = } in checkpoint. "
+                "If this is not by accident, disable this warning by passing "
+                "warn_target_mismatch=False."
+            )
+        model = model_class(**model_params)
+        model.to(device)
+
+        model.load_state_dict(checkpoint["model_state"])
+
+        with torch.no_grad():
+            predictions = torch.cat([model(*inputs)[0] for inputs, *_ in data_loader])
+
+        if model.robust:
+            predictions, aleat_log_std = predictions.chunk(2, dim=1)
+            aleat_std = aleat_log_std.exp().cpu().numpy().squeeze()
+            df[f"aleatoric_std_{idx}"] = aleat_std.tolist()
+
+        predictions = predictions.cpu().numpy().squeeze()
+        pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
+        df[pred_col] = predictions.tolist()
+
+    df_preds = df.filter(regex=r"_pred_\d")
+    df[f"{target_col}_pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
+    df[f"{target_col}_epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)
+
+    if df.columns.str.startswith("aleatoric_std_").sum() > 0:
+        aleatoric_std = df.filter(regex=r"aleatoric_std_\d").mean(axis=1)
+        df[f"{target_col}_aleatoric_std_ens"] = aleatoric_std
+        df[f"{target_col}_total_std_ens"] = (
+            epistemic_std**2 + aleatoric_std**2
+        ) ** 0.5
+
+    if target_col and print_metrics:
+        targets = df[target_col]
+        all_model_metrics = pd.DataFrame(
+            [
+                get_metrics(targets, df_preds[pred_col], task_type)
+                for pred_col in df_preds
+            ],
+            index=df_preds.columns,
+        )
+
+        print("\nSingle model performance:")
+        print(all_model_metrics.describe().round(4).loc[["mean", "std"]])

+        ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)
+
+        print("\nEnsemble performance:")
+        for key, val in ensemble_metrics.items():
+            print(f"{key:<8} {val:.3}")
+        return df, all_model_metrics
+
+    return df
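
Note: a typical call to the new function, mirroring the example script added below (the checkpoint glob is a placeholder):

    from glob import glob

    from aviary.wrenformer.utils import make_ensemble_predictions

    checkpoint_paths = glob("models/*.pth")  # hypothetical checkpoint location
    test_df, metrics_df = make_ensemble_predictions(
        checkpoint_paths, df=test_df, target_col="e_form"
    )
    # for robust models, the total uncertainty combines both sources in quadrature:
    # total_std = (epistemic_std**2 + aleatoric_std**2) ** 0.5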
examples/mp_wbm: new ensemble prediction script (+81 −0)
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import os
+from glob import glob
+
+import pandas as pd
+import wandb
+
+from aviary import ROOT
+from aviary.wrenformer.utils import make_ensemble_predictions
+
+__author__ = "Janosh Riebesell"
+__date__ = "2022-06-23"
+
+"""Script that downloads checkpoints for an ensemble of Wrenformer models trained on
+the MP+WBM dataset and makes predictions on the test set, then prints ensemble metrics.
+"""
+
+
+data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
+target_col = "e_form"
+test_size = 0.05
+df = pd.read_json(data_path)
+# shuffle with same random seed as in run_wrenformer() to get identical train/test split
+df = df.sample(frac=1, random_state=0)
+train_df = df.sample(frac=1 - test_size, random_state=0)
+test_df = df.drop(train_df.index)
+
+
+load_checkpoints_from_wandb = True
+
+if load_checkpoints_from_wandb:
+    wandb.login()
+    wandb_api = wandb.Api()
+
+    runs = wandb_api.runs("aviary/mp-wbm", filters={"tags": {"$in": ["ensemble-id-2"]}})
+
+    print(
+        f"Loading checkpoints for the following run IDs:\n{', '.join(run.id for run in runs)}\n"
+    )
+
+    checkpoint_paths: list[str] = []
+    for run in runs:
+        run_path = "/".join(run.path)
+        checkpoint_dir = f"{ROOT}/.wandb_checkpoints/{run_path}"
+        os.makedirs(checkpoint_dir, exist_ok=True)
+
+        checkpoint_path = f"{checkpoint_dir}/checkpoint.pth"
+        checkpoint_paths.append(checkpoint_path)
+
+        # download checkpoint from wandb if not already present
+        if os.path.isfile(checkpoint_path):
+            continue
+        wandb.restore("checkpoint.pth", root=checkpoint_dir, run_path=run_path)
+else:
+    # load checkpoints from local run dirs
+    checkpoint_paths = glob(
+        f"{ROOT}/examples/mp_wbm/job-logs/wandb/run-20220621_13*/files/checkpoint.pth"
+    )
+
+print(f"Predicting with {len(checkpoint_paths):,} model checkpoint(s)")
+
+test_df, ensemble_metrics = make_ensemble_predictions(
+    checkpoint_paths, df=test_df, target_col=target_col
+)
+
+test_df.to_csv(f"{ROOT}/examples/mp_wbm/ensemble-predictions.csv")
+
+
+# print output:
+# Predicting with 10 model checkpoint(s)
+#
+# Single model performance:
+#          MAE    RMSE      R2
+# mean  0.0369  0.1218  0.9864
+# std   0.0005  0.0014  0.0003
+#
+# Ensemble performance:
+# MAE    0.0308
+# RMSE   0.118
+# R2     0.987
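
Note: the test split is recovered purely from the seeded shuffle, so no split indices need to be stored. An illustrative check of that assumption:

    # seeded sampling is deterministic: re-running yields the identical split
    assert df.sample(frac=1 - test_size, random_state=0).index.equals(
        df.sample(frac=1 - test_size, random_state=0).index
    )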

examples/wrenformer.py (+4 −3)
@@ -317,7 +317,7 @@ def run_wrenformer(

     # save model checkpoint
     if checkpoint is not None:
-        state_dict = {
+        checkpoint_dict = {
             "model_params": model_params,
             "model_state": inference_model.state_dict(),
             "optimizer_state": optimizer_instance.state_dict(),
@@ -327,16 +327,17 @@ def run_wrenformer(
             "metrics": test_metrics,
             "run_name": run_name,
             "normalizer_dict": normalizer_dict,
+            "run_params": run_params,
         }
         if checkpoint == "local":
             os.makedirs(f"{ROOT}/models", exist_ok=True)
             checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
-            torch.save(state_dict, checkpoint_path)
+            torch.save(checkpoint_dict, checkpoint_path)
         if checkpoint == "wandb":
             assert (
                 wandb_project and wandb.run is not None
             ), "can't save model checkpoint to Weights and Biases, wandb.run is None"
-            torch.save(state_dict, f"{wandb.run.dir}/checkpoint.pth")
+            torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")

     # record test set metrics and scatter/ROC plots to wandb
     if wandb_project:
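
Note: the renamed checkpoint_dict is the on-disk format that make_ensemble_predictions loads. A sketch of reading one back (the path is hypothetical):

    import torch

    checkpoint = torch.load("models/2022-06-23-wrenformer.pth", map_location="cpu")
    model_params = checkpoint["model_params"]  # kwargs to re-instantiate the model
    model_state = checkpoint["model_state"]  # weights for model.load_state_dict()
    run_params = checkpoint["run_params"]  # training run metadata, newly saved here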

tests/test_wrenformer.py (+4 −4)
@@ -16,9 +16,9 @@ def test_wrenformer_regression(df_matbench_phonons_wyckoff):
         epochs=30,
     )

-    assert test_metrics["mae"] < 260, test_metrics
-    assert test_metrics["rmse"] < 420, test_metrics
-    assert test_metrics["r2"] > 0.1, test_metrics
+    assert test_metrics["MAE"] < 260, test_metrics
+    assert test_metrics["RMSE"] < 420, test_metrics
+    assert test_metrics["R2"] > 0.1, test_metrics


 def test_wrenformer_classification(df_matbench_phonons_wyckoff):
@@ -36,4 +36,4 @@ def test_wrenformer_classification(df_matbench_phonons_wyckoff):
     )

     assert test_metrics["accuracy"] > 0.7, test_metrics
-    assert test_metrics["rocauc"] > 0.8, test_metrics
+    assert test_metrics["ROCAUC"] > 0.8, test_metrics
