1
+ from __future__ import annotations
2
+
1
3
import json
2
4
import time
3
5
from contextlib import contextmanager
4
6
from typing import Generator , Literal
5
7
8
+ import pandas as pd
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from aviary .core import BaseModelClass
13
+ from aviary .utils import get_metrics
14
+ from aviary .wrenformer .data import df_to_in_mem_dataloader
15
+ from aviary .wrenformer .model import Wrenformer
16
+
17
+ __author__ = "Janosh Riebesell"
18
+ __date__ = "2022-05-10"
19
+
6
20
7
21
def _int_keys (dct : dict ) -> dict :
8
22
# JSON stringifies all dict keys during serialization and does not revert
@@ -45,14 +59,14 @@ def merge_json_on_disk(
45
59
pass
46
60
47
61
def non_serializable_handler (obj : object ) -> str :
48
- # replace functions and classes in dct with string indicating a non-serializable type
62
+ # replace functions and classes in dct with string indicating it's a non-serializable type
49
63
return f"<not serializable: { type (obj ).__qualname__ } >"
50
64
51
65
with open (file_path , "w" ) as file :
52
66
default = (
53
67
non_serializable_handler if on_non_serializable == "annotate" else None
54
68
)
55
- json .dump (dct , file , default = default )
69
+ json .dump (dct , file , default = default , indent = 2 )
56
70
57
71
58
72
@contextmanager
@@ -78,3 +92,110 @@ def print_walltime(
78
92
finally :
79
93
run_time = time .perf_counter () - start_time
80
94
print (f"{ end_desc } took { run_time :.2f} sec" )
95
+
96
+
97
def make_ensemble_predictions(
    checkpoint_paths: list[str],
    df: pd.DataFrame,
    target_col: str | None = None,
    input_col: str = "wyckoff",
    model_class: type[BaseModelClass] = Wrenformer,
    device: str | None = None,
    print_metrics: bool = True,
    warn_target_mismatch: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
    """Make predictions using an ensemble of models restored from checkpoints.

    Note that ``df`` is modified in place: one prediction column is added per
    checkpoint, plus the ensemble columns described below.

    Args:
        checkpoint_paths (list[str]): File paths to model checkpoints created with
            torch.save(). Each checkpoint must hold the keys "model_params" (including
            a "task_dict" entry) and "model_state".
        df (pd.DataFrame): Dataframe to make predictions on. Will be returned with
            additional columns holding model predictions (and uncertainties for robust
            models) for each model checkpoint.
        target_col (str, optional): Column holding target values. Defaults to None. If
            None, no performance metrics are computed and the added columns carry no
            target prefix (e.g. 'pred_1', 'pred_ens').
        input_col (str, optional): Column holding input values. Defaults to 'wyckoff'.
        model_class (type[BaseModelClass], optional): Class instantiated with the
            checkpoint's "model_params". Defaults to Wrenformer.
        device (str, optional): torch.device. Defaults to "cuda" if
            torch.cuda.is_available() else "cpu".
        print_metrics (bool, optional): Whether to compute and print performance
            metrics. Only has an effect if target_col is not None. Defaults to True.
        warn_target_mismatch (bool, optional): Whether to warn if target_col !=
            target_name from model checkpoint. Defaults to False.

    Returns:
        pd.DataFrame: Input dataframe with added columns for model and ensemble
            predictions. If target_col is not None and print_metrics is True, a tuple
            is returned instead whose 2nd item is a dataframe of per-model metrics.
    """
    # TODO: Add support for predicting all tasks a multi-task model was trained on.
    # Currently only handles single targets.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    data_loader = df_to_in_mem_dataloader(
        df=df,
        target_col=target_col,
        input_col=input_col,
        batch_size=512,
        embedding_type="wyckoff",
    )

    print(f"Predicting with {len(checkpoint_paths):,} model checkpoints(s)")

    # Use one prefix for every column this function adds so that per-model and
    # ensemble columns stay consistent (and never read "None_...") without a target.
    col_prefix = f"{target_col}_" if target_col else ""
    # Track the columns written by this call explicitly instead of re-discovering
    # them by regex afterwards, which missed 'pred_{idx}' and could pick up stale
    # or unrelated columns already present in df.
    pred_cols: list[str] = []
    aleatoric_cols: list[str] = []

    for idx, checkpoint_path in enumerate(tqdm(checkpoint_paths), 1):
        # NOTE: torch.load unpickles the file, only pass checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model_params = checkpoint["model_params"]
        # only the first task in task_dict is handled (see TODO above)
        target_name, task_type = next(iter(model_params["task_dict"].items()))
        assert task_type in ("regression", "classification"), f"invalid {task_type = }"
        if target_name != target_col and warn_target_mismatch:
            print(
                f"Warning: {target_col = } does not match {target_name = } in checkpoint. "
                "If this is not by accident, disable this warning by passing "
                "warn_target_mismatch=False."
            )
        model = model_class(**model_params)
        model.to(device)

        model.load_state_dict(checkpoint["model_state"])
        # switch off train-time behavior such as dropout so predictions are deterministic
        model.eval()

        with torch.no_grad():
            predictions = torch.cat([model(*inputs)[0] for inputs, *_ in data_loader])

        if model.robust:
            # robust models emit the prediction and a log std side by side along dim 1
            predictions, aleat_log_std = predictions.chunk(2, dim=1)
            aleat_std = aleat_log_std.exp().cpu().numpy().squeeze()
            aleat_col = f"aleatoric_std_{idx}"
            df[aleat_col] = aleat_std.tolist()
            aleatoric_cols.append(aleat_col)

        predictions = predictions.cpu().numpy().squeeze()
        pred_col = f"{col_prefix}pred_{idx}"
        df[pred_col] = predictions.tolist()
        pred_cols.append(pred_col)

    df_preds = df[pred_cols]
    df[f"{col_prefix}pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
    df[f"{col_prefix}epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)

    if aleatoric_cols:
        aleatoric_std = df[aleatoric_cols].mean(axis=1)
        df[f"{col_prefix}aleatoric_std_ens"] = aleatoric_std
        # combine both uncertainty contributions in quadrature
        df[f"{col_prefix}total_std_ens"] = (
            epistemic_std**2 + aleatoric_std**2
        ) ** 0.5

    if target_col and print_metrics:
        targets = df[target_col]
        all_model_metrics = pd.DataFrame(
            [get_metrics(targets, df_preds[col], task_type) for col in df_preds],
            index=df_preds.columns,
        )

        print("\nSingle model performance:")
        print(all_model_metrics.describe().round(4).loc[["mean", "std"]])

        ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)

        print("\nEnsemble performance:")
        for key, val in ensemble_metrics.items():
            print(f"{key:<8} {val:.3}")
        return df, all_model_metrics

    return df
0 commit comments