Skip to content

Commit b7655b5

Browse files
CompRhys and janosh authored
Adopt protostructure naming (#84)
* doc: adopt protostructure naming * lint: sort imports in notebooks * clean: rename last functions * fix outdated import pymatviz.(utils->powerups).add_identity_line * ruff auto-fixes * fix ruff aviary/utils.py:732:5: PLC0206 Extracting value from dictionary without calling `.items()` and aviary/roost/data.py:116:13: PLR1704 Redefining argument with the local name `idx` * fix save_results_dict doc string * fea: bump version for breaking change * fea: rename parse function used in wren data --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 94fba8c commit b7655b5

21 files changed

+467
-394
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ci:
55

66
repos:
77
- repo: https://github.com/astral-sh/ruff-pre-commit
8-
rev: v0.5.0
8+
rev: v0.5.3
99
hooks:
1010
- id: ruff
1111
args: [--fix]
@@ -30,7 +30,7 @@ repos:
3030
args: [--check-filenames]
3131

3232
- repo: https://github.com/pre-commit/mirrors-mypy
33-
rev: v1.10.1
33+
rev: v1.11.0
3434
hooks:
3535
- id: mypy
3636
exclude: (tests|examples)/

aviary/roost/data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def __getitem__(self, idx: int):
113113
n_elems = len(elements)
114114
self_idx = []
115115
nbr_idx = []
116-
for idx in range(n_elems):
117-
self_idx += [idx] * n_elems
116+
for elem_idx in range(n_elems):
117+
self_idx += [elem_idx] * n_elems
118118
nbr_idx += list(range(n_elems))
119119

120120
# convert all data to tensors

aviary/segments.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
3838
"""
3939
gate = self.gate_nn(x)
4040

41-
gate = gate - scatter_max(gate, index, dim=0)[0][index]
41+
gate -= scatter_max(gate, index, dim=0)[0][index]
4242
gate = gate.exp()
43-
gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
43+
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
4444

4545
x = self.message_nn(x)
4646
return scatter_add(gate * x, index, dim=0)
@@ -78,9 +78,9 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
7878
"""
7979
gate = self.gate_nn(x)
8080

81-
gate = gate - scatter_max(gate, index, dim=0)[0][index]
81+
gate -= scatter_max(gate, index, dim=0)[0][index]
8282
gate = (weights**self.pow) * gate.exp()
83-
gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
83+
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
8484

8585
x = self.message_nn(x)
8686
return scatter_add(gate * x, index, dim=0)

aviary/train.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,13 @@ def train_model(
246246
print("Starting stochastic weight averaging...")
247247
swa_model.update_parameters(model)
248248
swa_scheduler.step()
249+
elif scheduler_name == "ReduceLROnPlateau":
250+
val_metric = val_metrics[target_col][
251+
"MAE" if task_type == reg_key else "Accuracy"
252+
]
253+
lr_scheduler.step(val_metric)
249254
else:
250-
if scheduler_name == "ReduceLROnPlateau":
251-
val_metric = val_metrics[target_col][
252-
"MAE" if task_type == reg_key else "Accuracy"
253-
]
254-
lr_scheduler.step(val_metric)
255-
else:
256-
lr_scheduler.step()
255+
lr_scheduler.step()
257256

258257
model.epoch += 1
259258

aviary/utils.py

+20-34
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,12 @@ def initialize_losses(
237237
raise NameError(
238238
"Only L1 or L2 losses are allowed for robust regression tasks"
239239
)
240+
elif loss_name_dict[name] == "L1":
241+
loss_func_dict[name] = (task, L1Loss())
242+
elif loss_name_dict[name] == "L2":
243+
loss_func_dict[name] = (task, MSELoss())
240244
else:
241-
if loss_name_dict[name] == "L1":
242-
loss_func_dict[name] = (task, L1Loss())
243-
elif loss_name_dict[name] == "L2":
244-
loss_func_dict[name] = (task, MSELoss())
245-
else:
246-
raise NameError(
247-
"Only L1 or L2 losses are allowed for regression tasks"
248-
)
245+
raise NameError("Only L1 or L2 losses are allowed for regression tasks")
249246

250247
return loss_func_dict
251248

@@ -723,46 +720,35 @@ def save_results_dict(
723720
"""Save the results to a file after model evaluation.
724721
725722
Args:
726-
ids (dict[str, list[str | int]]): ): Each key is the name of an identifier
723+
ids (dict[str, list[str | int]]): Each key is the name of an identifier
727724
(e.g. material ID, composition, ...) and its value a list of IDs.
728-
results_dict (dict[str, Any]): ): nested dictionary of results
729-
{name: {col: data}}
730-
model_name (str): ): The name given the model via the --model-name flag.
731-
run_id (str): ): The run ID given to the model via the --run-id flag.
725+
results_dict (dict[str, Any]): nested dictionary of results {name: {col: data}}
726+
model_name (str): The name given the model via the --model-name flag.
727+
run_id (str): The run ID given to the model via the --run-id flag.
732728
"""
733-
results = {}
729+
results: dict[str, np.ndarray] = {}
734730

735-
for target_name in results_dict:
736-
for col, data in results_dict[target_name].items():
731+
for target_name, target_data in results_dict.items():
732+
for col, data in target_data.items():
737733
# NOTE we save pre_logits rather than logits due to fact
738734
# that with the heteroskedastic setup we want to be able to
739735
# sample from the Gaussian distributed pre_logits we parameterize.
740736
if "pre-logits" in col:
741737
for n_ens, y_pre_logit in enumerate(data):
742-
results.update(
743-
{
744-
f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
745-
for lab, val in enumerate(y_pre_logit.T)
746-
}
747-
)
738+
results |= {
739+
f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
740+
for lab, val in enumerate(y_pre_logit.T)
741+
}
748742

749-
elif "pred" in col:
750-
preds = {
743+
elif "pred" in col or "ale" in col:
744+
# elif so that pre-logit-ale doesn't trigger
745+
results |= {
751746
f"{target_name}_{col}_n{n_ens}": val.ravel()
752747
for (n_ens, val) in enumerate(data)
753748
}
754-
results.update(preds)
755-
756-
elif "ale" in col: # elif so that pre-logit-ale doesn't trigger
757-
results.update(
758-
{
759-
f"{target_name}_{col}_n{n_ens}": val.ravel()
760-
for (n_ens, val) in enumerate(data)
761-
}
762-
)
763749

764750
elif col == "target":
765-
results.update({f"{target_name}_target": data})
751+
results |= {f"{target_name}_target": data}
766752

767753
df = pd.DataFrame({**ids, **results})
768754

aviary/wren/data.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ def __getitem__(self, idx: int):
108108
- list[str | int]: identifiers like material_id, composition
109109
"""
110110
row = self.df.iloc[idx]
111-
wyckoff_str = row[self.inputs]
111+
protostructure_label = row[self.inputs]
112112
material_ids = row[self.identifiers].to_list()
113113

114-
parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
114+
parsed_output = parse_protostructure_label(protostructure_label)
115115
spg_num, wyk_site_multiplcities, elements, augmented_wyks = parsed_output
116116

117117
wyk_site_multiplcities = np.atleast_2d(wyk_site_multiplcities).T / np.sum(
@@ -256,21 +256,29 @@ def collate_batch(
256256
)
257257

258258

259-
def parse_aflow_wyckoff_str(
260-
aflow_label: str,
259+
def parse_protostructure_label(
260+
protostructure_label: str,
261261
) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]:
262262
"""Parse the Wren AFLOW-like Wyckoff encoding.
263263
264264
Args:
265-
aflow_label (str): AFLOW-style prototype string with appended chemical system
265+
protostructure_label (str): label constructed as `aflow_label:chemsys` where
266+
aflow_label is an AFLOW-style prototype label chemsys is the alphabetically
267+
sorted chemical system.
266268
267269
Returns:
268270
tuple[str, list[float], list[str], list[str]]: spacegroup number, Wyckoff site
269271
multiplicities, elements symbols and equivalent wyckoff sets
270272
"""
271-
proto, chemsys = aflow_label.split(":")
273+
aflow_label, chemsys = protostructure_label.split(":")
272274
elems = chemsys.split("-")
273-
_, _, spg_num, *wyckoff_letters = proto.split("_")
275+
_, _, spg_num, *wyckoff_letters = aflow_label.split("_")
276+
277+
if len(elems) != len(wyckoff_letters):
278+
raise ValueError(
279+
f"Chemical system {chemsys} does not match Wyckoff letters "
280+
f"{wyckoff_letters}"
281+
)
274282

275283
wyckoff_site_multiplicities = []
276284
elements = []

0 commit comments

Comments (0)