Commit e43685f

Fix get_aflow_label_from_spglib (#75)
* pre-commit autoupdate and run --all-files
* fix UnboundLocalError: local variable 'aflow_label_with_chemsys' referenced before assignment if try case raises
* fix mypy
* ruff unignore and fix RET504 unnecessary-assign
* ruff unignore and fix D107 Missing docstring in __init__
* ruff unignore and fix B904
1 parent a009fe7 commit e43685f

15 files changed: +50 −56 lines

.pre-commit-config.yaml

+2 −2

@@ -5,7 +5,7 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.276
+    rev: v0.0.284
     hooks:
       - id: ruff
         args: [--fix]
@@ -28,7 +28,7 @@ repos:
         exclude_types: [json]
 
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 23.7.0
     hooks:
       - id: black-jupyter
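
The two rev bumps above are what `pre-commit autoupdate` produces; per the commit message, the updated hooks were then run over the whole repo. Assuming pre-commit is installed, the standard invocations are:

pre-commit autoupdate       # bump every hook's rev to its latest tag
pre-commit run --all-files  # re-run all hooks across the entire repo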

aviary/cgcnn/model.py

+10 −3

@@ -129,6 +129,15 @@ def __init__(
         elem_fea_len: int = 64,
         n_graph: int = 4,
     ) -> None:
+        """Initialize DescriptorNetwork.
+
+        Args:
+            elem_emb_len (int): Number of atom features in the input.
+            nbr_fea_len (int): Number of bond features.
+            elem_fea_len (int, optional): Number of hidden atom features in the graph convolution
+                layers. Defaults to 64.
+            n_graph (int, optional): Number of graph convolution layers. Defaults to 4.
+        """
         super().__init__()
 
         self.embedding = nn.Linear(elem_emb_len, elem_fea_len)
@@ -222,6 +231,4 @@ def forward(
         nbr_summed = scatter_add(nbr_msg, self_idx, dim=0)
 
         nbr_summed = self.bn2(nbr_summed)
-        out = self.softplus2(atom_in_fea + nbr_summed)
-
-        return out
+        return self.softplus2(atom_in_fea + nbr_summed)
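
The second hunk is the RET504 (unnecessary-assign) pattern that recurs throughout this commit: a name bound once and immediately returned is replaced by returning the expression directly. A minimal sketch of the rule itself, not aviary code:

def doubled(x: float) -> float:
    out = x * 2  # ruff RET504: unnecessary assignment before return
    return out

def doubled_fixed(x: float) -> float:
    return x * 2  # same behavior, one fewer name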

aviary/core.py

+2 −2

@@ -411,6 +411,7 @@ class Normalizer:
     """Normalize a Tensor and restore it later."""
 
     def __init__(self) -> None:
+        """Initialize Normalizer with mean 0 and std 1."""
         self.mean = torch.tensor(0)
         self.std = torch.tensor(1)
 
@@ -579,8 +580,7 @@ def masked_std(x: Tensor, mask: BoolTensor, dim: int = 0, eps: float = 1e-12) ->
     mean = masked_mean(x, mask, dim=dim)
     squared_diff = (x - mean.unsqueeze(dim=dim)) ** 2
     var = masked_mean(squared_diff, mask, dim=dim)
-    std = (var + eps).sqrt()
-    return std
+    return (var + eps).sqrt()
 
 
 def masked_mean(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
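
For context on `masked_std`: its body is fully visible in the hunk, but `masked_mean`'s is not, so the helper in the sketch below is an assumed implementation (averaging only entries where the mask is True), included just to make the example self-contained:

import torch
from torch import BoolTensor, Tensor

def masked_mean(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
    # assumption: mask is True where entries should count toward the mean
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim)

def masked_std(x: Tensor, mask: BoolTensor, dim: int = 0, eps: float = 1e-12) -> Tensor:
    # identical to the post-fix code above: sqrt of masked variance + eps
    mean = masked_mean(x, mask, dim=dim)
    squared_diff = (x - mean.unsqueeze(dim=dim)) ** 2
    var = masked_mean(squared_diff, mask, dim=dim)
    return (var + eps).sqrt()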

aviary/predict.py

+2 −2

@@ -72,14 +72,14 @@ def make_ensemble_predictions(
     ):
         try:
             checkpoint = torch.load(checkpoint_path, map_location=device)
-        except Exception as exc:  # noqa: PERF203
+        except Exception as exc:
            raise RuntimeError(f"Failed to load {checkpoint_path=}") from exc
 
         model_params = checkpoint.get("model_params")
         if model_params is None:
             raise ValueError(f"model_params not found in {checkpoint_path=}")
 
-        target_name, task_type = list(model_params["task_dict"].items())[0]
+        target_name, task_type = next(iter(model_params["task_dict"].items()))
         assert task_type in ("regression", "classification"), f"invalid {task_type = }"
         if warn_target_mismatch and target_name != target_col:
             print(
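
`next(iter(d.items()))` grabs the first key-value pair without materializing every item in a list first, which is the point of the `list(...)[0]` replacements here and in the data loaders below. A toy illustration (the single-entry task_dict is made up):

task_dict = {"e_form": "regression"}  # hypothetical example

# before: builds a throwaway list of all items
target_name, task_type = list(task_dict.items())[0]

# after: lazily pulls just the first item
target_name, task_type = next(iter(task_dict.items()))
assert (target_name, task_type) == ("e_form", "regression")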

aviary/roost/data.py

+5 −5

@@ -55,7 +55,7 @@ def __init__(
         with open(elem_embedding) as file:
             self.elem_features = json.load(file)
 
-        self.elem_emb_len = len(list(self.elem_features.values())[0])
+        self.elem_emb_len = len(next(iter(self.elem_features.values())))
 
         self.n_targets = []
         for target, task in self.task_dict.items():
@@ -98,14 +98,14 @@ def __getitem__(self, idx: int):
 
         try:
             elem_fea = np.vstack([self.elem_features[element] for element in elements])
-        except AssertionError:
+        except AssertionError as exc:
             raise AssertionError(
                 f"{material_ids} ({composition}) contains element types not in embedding"
-            )
-        except ValueError:
+            ) from exc
+        except ValueError as exc:
             raise ValueError(
                 f"{material_ids} ({composition}) composition cannot be parsed into elements"
-            )
+            ) from exc
 
         nele = len(elements)
         self_idx = []
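
The `as exc` / `from exc` additions are the B904 fix named in the commit message: re-raising inside an `except` block without `from` discards the causal link between the original and the new exception. A minimal sketch with a made-up embedding dict:

elem_features = {"Fe": [0.1, 0.2]}  # hypothetical embedding

try:
    fea = elem_features["Xx"]
except KeyError as exc:
    # chains the KeyError as __cause__, so both tracebacks
    # are printed if this ValueError propagates
    raise ValueError("element 'Xx' not in embedding") from exc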

aviary/segments.py

+2 −6

@@ -40,9 +40,7 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
         gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
 
         x = self.message_nn(x)
-        out = scatter_add(gate * x, index, dim=0)
-
-        return out
+        return scatter_add(gate * x, index, dim=0)
 
     def __repr__(self) -> str:
         gate_nn, message_nn = self.gate_nn, self.message_nn
@@ -82,9 +80,7 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
         gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
 
         x = self.message_nn(x)
-        out = scatter_add(gate * x, index, dim=0)
-
-        return out
+        return scatter_add(gate * x, index, dim=0)
 
     def __repr__(self) -> str:
         pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn

aviary/utils.py

+1 −3

@@ -805,9 +805,7 @@ def get_metrics(
         class1_probas = predictions[:, 1]
         metrics["ROCAUC"] = roc_auc_score(targets, class1_probas)
 
-    metrics = {key: round(float(val), prec) for key, val in metrics.items()}
-
-    return metrics
+    return {key: round(float(val), prec) for key, val in metrics.items()}
 
 
 def as_dict_handler(obj: Any) -> dict[str, Any] | None:

aviary/wren/data.py

+4 −2

@@ -59,15 +59,17 @@ def __init__(
         with open(elem_embedding) as emb_file:
             self.elem_features = json.load(emb_file)
 
-        self.elem_emb_len = len(list(self.elem_features.values())[0])
+        self.elem_emb_len = len(next(iter(self.elem_features.values())))
 
         if sym_emb in ["bra-alg-off", "spg-alg-off"]:
             sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json"
 
         with open(sym_emb) as sym_file:
             self.sym_features = json.load(sym_file)
 
-        self.sym_emb_len = len(list(list(self.sym_features.values())[0].values())[0])
+        self.sym_emb_len = len(
+            next(iter(next(iter(self.sym_features.values())).values()))
+        )
 
         self.n_targets = []
         for target, task in self.task_dict.items():
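
The `sym_emb_len` change applies the same first-item idiom twice over a two-level dict. Assuming the symmetry-embedding JSON maps spacegroup → Wyckoff letter → feature vector (the nesting the old double `list(...)[0]` also implied), a toy version:

# hypothetical nested embedding: {spg: {wyckoff_letter: feature_vector}}
sym_features = {"194": {"a": [0, 1, 0], "b": [1, 0, 0]}}

# first feature vector of the first spacegroup, no intermediate lists
first_vec = next(iter(next(iter(sym_features.values())).values()))
assert len(first_vec) == 3  # this length becomes sym_emb_len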

aviary/wren/model.py

+2 −4

@@ -236,7 +236,7 @@ def forward(
             aug_cry_idx (Tensor): Mapping from the crystal idx to augmentation idx
 
         Returns:
-            Tensor: returns the crystal features of the materials in the batch
+            Tensor: crystal features of the materials in the batch
         """
         # embed the original features into the graph layer description
         elem_fea = self.elem_embed(elem_fea)
@@ -254,12 +254,10 @@ def forward(
             for attnhead in self.cry_pool
         ]
 
-        crystal_features = scatter_mean(
+        return scatter_mean(
             torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0
         )
 
-        return crystal_features
-
     def __repr__(self) -> str:
         return (
             f"{type(self).__name__}(n_graph={len(self.graphs)}, cry_heads="

aviary/wren/utils.py

+5 −14

@@ -177,14 +177,12 @@ def get_aflow_label_from_spglib(
         aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer(
             spg_analyzer, errors
         )
+        return aflow_label_with_chemsys
 
     except ValueError as exc:
-        if errors == "raise":
-            raise
         if errors == "annotate":
             return f"invalid spglib: {exc}"
-
-        return aflow_label_with_chemsys
+        raise  # we only get here if errors == "raise"
 
 
 def get_aflow_label_from_spg_analyzer(
@@ -297,9 +295,7 @@ def canonicalize_elem_wyks(elem_wyks: str, spg_num: int) -> str:
         scores.append(score)
         sorted_iso.append(sorted_el_wyks)
 
-    canonical = sorted(zip(scores, sorted_iso), key=lambda x: (x[0], x[1]))[0][1]
-
-    return canonical
+    return sorted(zip(scores, sorted_iso), key=lambda x: (x[0], x[1]))[0][1]
 
 
 def sort_and_score_wyks(wyks: str) -> tuple[str, int]:
@@ -372,8 +368,6 @@ def count_wyckoff_positions(aflow_label: str) -> int:
     Returns:
         int: number of distinct Wyckoff positions
     """
-    num_wyk = 0
-
     aflow_label, _ = aflow_label.split(":")  # remove chemical system
     # discard prototype formula and spg symbol and spg number
     wyk_letters = aflow_label.split("_", maxsplit=3)[-1]
@@ -382,9 +376,7 @@ def count_wyckoff_positions(aflow_label: str) -> int:
     wyk_list = re.split("[A-z]", wyk_letters)[:-1]  # split on every letter
 
     # count 1 for letters without prefix
-    num_wyk = sum(1 if len(x) == 0 else int(x) for x in wyk_list)
-
-    return num_wyk
+    return sum(1 if len(x) == 0 else int(x) for x in wyk_list)
 
 
 def count_crystal_dof(aflow_label: str) -> int:
@@ -488,5 +480,4 @@ def count_distinct_wyckoff_letters(aflow_str: str) -> int:
     aflow_str, _ = aflow_str.split(":")  # drop chemical system
     _, _, _, wyckoff_letters = aflow_str.split("_", 3)  # drop prototype, Pearson, spg
     wyckoff_letters = wyckoff_letters.translate(remove_digits).replace("_", "")
-    n_uniq = len(set(wyckoff_letters))
-    return n_uniq
+    return len(set(wyckoff_letters))  # number of distinct Wyckoff letters
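
The first hunk in this file is the bug the commit is named after. In the old control flow, if the `try` body raised and `errors` was neither "raise" nor "annotate", execution fell through to `return aflow_label_with_chemsys` with that name never having been bound, hence the UnboundLocalError. Returning inside the `try` makes the unbound path unreachable. A miniature of both shapes (`compute` is a stand-in for get_aflow_label_from_spg_analyzer):

def compute() -> str:
    raise ValueError("spglib could not determine symmetry")  # force except path

def old_shape(errors: str) -> str:
    try:
        label = compute()
    except ValueError as exc:
        if errors == "raise":
            raise
        if errors == "annotate":
            return f"invalid spglib: {exc}"
    return label  # UnboundLocalError: 'label' was never assigned

def new_shape(errors: str) -> str:
    try:
        return compute()  # returning inside try leaves no unbound path
    except ValueError as exc:
        if errors == "annotate":
            return f"invalid spglib: {exc}"
        raise  # only reached when errors == "raise"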

aviary/wrenformer/data.py

+3 −6

@@ -94,12 +94,10 @@ def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
     )
     element_ratios = element_ratios.repeat(n_augments, 1, 1)
 
-    combined_features = torch.cat(
+    return torch.cat(  # combined features
         [element_ratios, element_features, symmetry_features], dim=-1
     ).float()
 
-    return combined_features
-
 
 def get_composition_embedding(formula: str) -> Tensor:
     """Concatenate matscholar element embeddings with element ratios in composition.
@@ -121,9 +119,8 @@ def get_composition_embedding(formula: str) -> Tensor:
     element_ratios = torch.tensor(elem_weights)
     element_features = torch.tensor(element_features)
 
-    combined_features = torch.cat([element_ratios, element_features], dim=1).float()
-
-    return combined_features
+    # combined features
+    return torch.cat([element_ratios, element_features], dim=1).float()
 
 
 def df_to_in_mem_dataloader(

examples/inputs/poscar_to_df.py

+2 −0

@@ -1,4 +1,6 @@
 # %%
+from __future__ import annotations
+
 import glob
 import os
 
examples/wrenformer/mat_bench/make_plots.py

+2 −0

@@ -1,4 +1,6 @@
 # %%
+from __future__ import annotations
+
 import json
 import logging
 import re

examples/wrenformer/mat_bench/plotting_functions.py

+8 −4

@@ -1,16 +1,20 @@
-from typing import Any, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
-import pandas as pd
 import plotly.express as px
 import plotly.graph_objs as go
 import plotly.io as pio
 from matbench.constants import CLF_KEY, REG_KEY
 from matbench.metadata import mbv01_metadata
 from matbench.metadata import mbv01_metadata as matbench_metadata
-from plotly.graph_objs._figure import Figure
 from sklearn.metrics import accuracy_score, auc, roc_curve
 
+if TYPE_CHECKING:
+    import pandas as pd
+    from plotly.graph_objs._figure import Figure
+
 __author__ = "Janosh Riebesell"
 __date__ = "2022-04-25"
@@ -79,7 +83,7 @@ def scale_clf_task(series: pd.Series) -> pd.Series:
 
 
 def plot_leaderboard(
-    df: pd.DataFrame, html_path: Optional[str] = None, **kwargs: Any
+    df: pd.DataFrame, html_path: str | None = None, **kwargs: Any
 ) -> Figure:
     """Generate the Matbench scaled errors graph seen on
     https://matbench.materialsproject.org. Adapted from https://bit.ly/38fDdgt.
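
The import reshuffle here is the standard PEP 563 pattern: with `from __future__ import annotations`, annotations are stored as strings, so modules used only in type hints can move into an `if TYPE_CHECKING:` block (resolved by type checkers, skipped at runtime), and `Optional[str]` can be written `str | None` even on Pythons older than 3.10. A self-contained sketch with a hypothetical function:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd  # only type checkers resolve this import

def n_rows(df: pd.DataFrame, limit: int | None = None) -> int:
    # the annotations above are never evaluated at runtime,
    # so pandas need not be importable just to call this
    return len(df) if limit is None else min(len(df), limit)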

pyproject.toml

+0 −3

@@ -105,19 +105,16 @@ select = [
     "YTT", # flake8-2020
 ]
 ignore = [
-    "B904", # Within an except clause, raise exceptions with raise ... from err
     "C408", # Unnecessary dict call - rewrite as a literal
     "D100", # Missing docstring in public module
     "D104", # Missing docstring in public package
     "D105", # Missing docstring in magic method
-    "D107", # Missing docstring in __init__
     "D205", # 1 blank line required between summary line and description
     "E731", # Do not assign a lambda expression, use a def
     "PD901", # pandas-df-variable-name
     "PLC1901", # compare-to-empty-string
     "PLR", # pylint refactor
     "PT006", # pytest-parametrize-names-wrong-type
-    "RET504", # unnecessary-assign
 ]
 pydocstyle.convention = "google"
 isort.known-third-party = ["wandb"]
