diff --git a/CHANGELOG.md b/CHANGELOG.md
index 03b15de..6713564 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,20 +3,19 @@ All notable changes to this project will be documented in this file.

 ## [0.11.0]
+
+Multi-fidelity learning implemented & new pretrained models released
+
 ### Added
 - Build multi-fidelity model, SevenNet-MF, based on given modality in the yaml
 - Modality support for sevenn_inference, sevenn_get_modal, and SevenNetCalculator
-- [cli] sevenn_cp tool for checkpoint summary, input generation, multi-modal routines
+- sevenn_cp tool for checkpoint summary, input generation, multi-modal routines
 - Modality append / assign using sevenn_cp
 - Loss weighting for energy, force and stress for corresponding data label
 - Ignore unlabelled data when calculating loss (e.g. stress data for non-pbc structure)
 - Dict style dataset input for multi-modal and data-weight
 - (experimental) cuEquivariance support
-
-### Added (code)
-- sevenn.train.modal_dataset SevenNetMultiModalDataset
-- sevenn.scripts.backward_compatibility.py
-- sevenn.checkpoint.py
+- Downloading large checkpoints from URL (7net-MF-ompa, 7net-omat)
 - D3 wB97M param

 ### Changed
diff --git a/README.md b/README.md
index b19707d..b216fb4 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,8 @@

 SevenNet (Scalable EquiVariance Enabled Neural Network) is a graph neural network (GNN) interatomic potential package that supports parallel molecular dynamics simulations with [`LAMMPS`](https://lammps.org). Its underlying GNN model is based on [`NequIP`](https://github.com/mir-group/nequip).

-> [!CAUTION]
-> SevenNet+LAMMPS parallel after the commit id of `14851ef (v0.9.3 ~ 0.9.5)` has a serious bug.
-> It gives wrong forces when the number of mpi processes is greater than two. The corresponding pip version is yanked for this reason. The bug is fixed for the main branch since `v0.10.0`, and pip (`v0.9.3.post0`).
-
+> [!NOTE]
+> We will soon release a CUDA-accelerated version of SevenNet, which will significantly increase the speed of our pre-trained models on [Matbench Discovery](https://matbench-discovery.materialsproject.org/).

 ## Features
 - Pre-trained GNN interatomic potential and fine-tuning interface.
@@ -19,29 +17,66 @@ SevenNet (Scalable EquiVariance Enabled Neural Networ

 ## Pre-trained models
 So far, we have released several pre-trained SevenNet models. Each model differs in hyperparameters and training sets, resulting in different accuracy and speed. Please read the descriptions below carefully and choose the model that best suits your purpose.
-We provide the training set MAEs (energy, force, and stress) F1 score for WBM dataset and $\kappa_{\mathrm{SRME}}$ from phonondb. For details on these metrics and performance comparisons with other pre-trained models, please visit [Matbench Discovery](https://matbench-discovery.materialsproject.org/).
+We provide the training set MAEs (energy, force, and stress), the F1 score and RMSD for the WBM dataset, as well as $\kappa_{\mathrm{SRME}}$ from phonondb and the CPS (Combined Performance Score). For details on these metrics and performance comparisons with other pre-trained models, please visit [Matbench Discovery](https://matbench-discovery.materialsproject.org/).

 These models can be used as interatomic potentials in LAMMPS and can also be loaded through the ASE calculator using each model's `keywords`. Please refer to [ASE calculator](#ase_calculator) for how to load a model through the ASE calculator.
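+For example, a minimal sketch of loading a model by its keyword through ASE (the bulk-Si structure here is just a placeholder):
+
+```python
+from ase.build import bulk
+from sevenn.calculator import SevenNetCalculator
+
+calc = SevenNetCalculator(model='7net-l3i5', device='cpu')  # or '7net-0', etc.; multi-fidelity models also need `modal`
+atoms = bulk('Si')  # placeholder structure
+atoms.calc = calc
+print(atoms.get_potential_energy())  # eV
+```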
 Additionally, `keywords` can be used in other parts of SevenNet, such as `sevenn_inference`, `sevenn_get_model`, and `checkpoint` in `input.yaml` for fine-tuning.

 **Acknowledgments**: The models trained on [`MPtrj`](https://figshare.com/articles/dataset/Materials_Project_Trjectory_MPtrj_Dataset/23713842) were supported by the Neural Processing Research Center program of Samsung Advanced Institute of Technology, Samsung Electronics Co., Ltd. The computations for training models were carried out using the Samsung SSC-21 cluster.

+---
+
+### **SevenNet-MF-ompa (17Mar2025)**
+> Model keywords: `7net-mf-ompa` | `SevenNet-mf-ompa`
+
+**This is our recommended pre-trained model.**
+
+This model leverages [multi-fidelity learning](https://pubs.acs.org/doi/10.1021/jacs.4c14455) to train simultaneously on the [MPtrj](https://figshare.com/articles/dataset/Materials_Project_Trjectory_MPtrj_Dataset/23713842), [sAlex](https://huggingface.co/datasets/fairchem/OMAT24), and [OMat24](https://huggingface.co/datasets/fairchem/OMAT24) datasets. As of March 17, 2025, it has achieved state-of-the-art performance on [Matbench Discovery](https://matbench-discovery.materialsproject.org/) in terms of CPS (Combined Performance Score). We have found that this model performs best on most tasks, except for isolated-molecule energies, where it is slightly worse than SevenNet-l3i5.
+
+```python
+from sevenn.calculator import SevenNetCalculator
+# "mpa" refers to the MPtrj + sAlex modal, used for evaluating Matbench Discovery.
+calc = SevenNetCalculator('7net-mf-ompa', modal='mpa')  # Use modal='omat24' for OMat24-trained modal weights.
+```
+Theoretically, the `mpa` modal should produce PBE52 results, while the `omat24` modal yields PBE54 results.
+
+When using the command-line interface of SevenNet, include the `--modal mpa` or `--modal omat24` option to select the desired modality.
+
+#### **Matbench Discovery**
+| CPS | F1 | $\kappa_{\mathrm{SRME}}$ | RMSD |
+|:---:|:---:|:---:|:---:|
+|**0.883**|**0.901**|0.317|**0.0115**|
+
+[Detailed instructions for multi-fidelity](https://github.com/MDIL-SNU/SevenNet/blob/main/sevenn/pretrained_potentials/SevenNet_MF_0/README.md)
+
+[Link to the full-information checkpoint](https://figshare.com/articles/software/7net_MF_ompa/28590722?file=53029859)

 ---

+### **SevenNet-omat (17Mar2025)**
+> Model keywords: `7net-omat` | `SevenNet-omat`
+
+This model was trained solely on the [OMat24](https://huggingface.co/datasets/fairchem/OMAT24) dataset. It achieves state-of-the-art (SOTA) performance in $\kappa_{\mathrm{SRME}}$ on [Matbench Discovery](https://matbench-discovery.materialsproject.org/); however, its F1 score is not available due to a difference in the POTCAR version. Similar to `SevenNet-MF-ompa`, this model outperforms `SevenNet-l3i5` on most tasks, except for isolated-molecule energies.
+
+[Link to the full-information checkpoint](https://figshare.com/articles/software/SevenNet_omat/28593938).
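+
+A minimal sketch of loading this model through ASE (this checkpoint is single-fidelity, so no `modal` argument is needed):
+
+```python
+from sevenn.calculator import SevenNetCalculator
+calc = SevenNetCalculator('7net-omat')
+```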
+
+#### **Matbench Discovery**
+* $\kappa_{\mathrm{SRME}}$: **0.221**
+---

 ### **SevenNet-l3i5 (12Dec2024)**
-> Keywords in ASE: `7net-l3i5` and `SevenNet-l3i5`
+> Model keywords: `7net-l3i5` | `SevenNet-l3i5`
+
+The model increases the maximum spherical harmonic degree ($l_{\mathrm{max}}$) to 3, compared to `SevenNet-0` with $l_{\mathrm{max}}$ of 2. While **l3i5** offers improved accuracy across various systems compared to `SevenNet-0`, it is approximately four times slower.
 As of March 17, 2025, this model has achieved state-of-the-art (SOTA) performance among compliant models on the CPS metric, newly introduced in [Matbench Discovery](https://matbench-discovery.materialsproject.org/).

-The model increases the maximum spherical harmonic degree ($l_{\mathrm{max}}$) to 3, compared to **SevenNet-0 (11Jul2024)** with $l_{\mathrm{max}}$ of 2.
-While **l3i5** offers improved accuracy across various systems compared to **SevenNet-0 (11Jul2024)**, it is approximately four times slower.
+#### **Matbench Discovery**
+| CPS | F1 | $\kappa_{\mathrm{SRME}}$ | RMSD |
+|:---:|:---:|:---:|:---:|
+|0.764|0.76|0.55|0.0182|

-* Training set MAE: 8.3 meV/atom (energy), 0.029 eV/Ang. (force), and 2.33 kbar (stress)
-* Matbench F1 score: 0.76, $\kappa_{\mathrm{SRME}}$: 0.560
-* Training time: 381 GPU-days on A100

 ---

 ### **SevenNet-0 (11Jul2024)**
-> Keywords in ASE: `7net-0`, `SevenNet-0`, `7net-0_11Jul2024`, and `SevenNet-0_11Jul2024`
+> Model keywords: `7net-0` | `SevenNet-0` | `7net-0_11Jul2024` | `SevenNet-0_11Jul2024`

 The model architecture is mainly in line with [GNoME](https://github.com/google-deepmind/materials_discovery), a pretrained model that utilizes the NequIP architecture.
 It has five interaction blocks with node features that consist of 128 scalars (*l*=0), 64 vectors (*l*=1), and 32 tensors (*l*=2).
@@ -50,9 +85,11 @@ The model was trained with [MPtrj](https://figshare.com/articles/dataset/Materia

 This model is loaded as the default pre-trained model in ASE calculator.
 For more information, click [here](sevenn/pretrained_potentials/SevenNet_0__11Jul2024).

-* Training set MAE: 11.5 meV/atom (energy), 0.041 eV/Ang. (force), and 2.78 kbar (stress)
-* Matbench F1 score: 0.67, $\kappa_{\mathrm{SRME}}$: 0.767
-* Training time: 90 GPU-days on A100
+#### **Matbench Discovery**
+| F1 | $\kappa_{\mathrm{SRME}}$ |
+|:---:|:---:|
+|0.67|0.767|
+

 ---

 In addition to these latest models, you can find our legacy models in [pretrained_potentials](./sevenn/pretrained_potentials).
@@ -106,7 +143,6 @@ The model can be loaded through the following Python code.

 ```python
 from sevenn.calculator import SevenNetCalculator
 calc = SevenNetCalculator(model='7net-0', device='cpu')
 ```
-
 SevenNet supports CUDA accelerated D3Calculator.
 ```python
 from sevenn.calculator import SevenNetD3Calculator
diff --git a/pyproject.toml b/pyproject.toml
index 543ff05..093a0c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     "numpy",
     "matscipy",
     "pandas",
+    "requests",
 ]

 [project.optional-dependencies]
 test = ["matscipy", "pytest-cov>=5"]
diff --git a/setup.cfg b/setup.cfg
index 84ac35f..1505c8b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,5 +10,5 @@ include_trailing_comma=True
 force_grid_wrap=0
 use_parentheses=True
 line_length=80
-known_third_party=ase,braceexpand,e3nn,numpy,packaging,pandas,pytest,sklearn,torch,torch_geometric,tqdm,yaml
+known_third_party=ase,braceexpand,e3nn,numpy,packaging,pandas,pytest,requests,sklearn,torch,torch_geometric,tqdm,yaml
 known_first_party=
diff --git a/sevenn/_const.py b/sevenn/_const.py
index 05414e4..40351d3 100644
--- a/sevenn/_const.py
+++ b/sevenn/_const.py
@@ -48,20 +48,18 @@ ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}

 _prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
-SEVENNET_0_11Jul2024 = (
-    f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
-)
-SEVENNET_0_22May2024 = (
-    f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
-)
-SEVENNET_l3i5 = (
-    f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
-)
-SEVENNET_MF_0 = (
-    f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
-)
-
-
+SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
+SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
+SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
+SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
+SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
+SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'
+
+_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
+CHECKPOINT_DOWNLOAD_LINKS = {
+    SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
+    SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
+}

 # to avoid torch script to compile torch_geometry.data
 AtomGraphDataType = Dict[str, torch.Tensor]
@@ -143,7 +141,9 @@ def error_record_condition(x):
     },
     KEY.CUTOFF: float,
     KEY.NUM_CONVOLUTION: int,
-    KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) or x in [
+    KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
+    or x
+    in [
         'avg_num_neigh',
         'sqrt_avg_num_neigh',
     ],
diff --git a/sevenn/calculator.py b/sevenn/calculator.py
index e22e2d4..8004d82 100644
--- a/sevenn/calculator.py
+++ b/sevenn/calculator.py
@@ -2,7 +2,7 @@
 import os
 import pathlib
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union

 import numpy as np
 import torch
@@ -22,17 +22,12 @@

 class SevenNetCalculator(Calculator):
-    """ASE calculator for SevenNet models
+    """Supported properties:
+        'free_energy', 'energy', 'forces', 'stress', 'energies'
+        free_energy equals energy. 'energies' stores per-atom energies.

-    Multi-GPU parallel MD is not supported for this mode.
-    Use LAMMPS for multi-GPU parallel MD.
-    This class is for convenience who want to run SevenNet models with ase.
-
-    Note than ASE calculator is designed to be interface of other programs.
-    But in this class, we simply run torch model inside ASE calculator.
-    So there is no FileIO things.
-
-    Here, free_energy = energy
+    Multi-GPU acceleration is not supported with the ASE calculator.
+    Use LAMMPS for multi-GPU acceleration.
     """

     def __init__(
@@ -42,14 +37,28 @@
         device: Union[torch.device, str] = 'auto',
         modal: Optional[str] = None,
         enable_cueq: bool = False,
-        sevennet_config: Optional[Any] = None,  # hold meta information
+        sevennet_config: Optional[Dict] = None,  # Not used in logic, just meta info
         **kwargs,
     ):
-        """Initialize the calculator
-
-        Args:
-            model (SevenNet): path to the checkpoint file, or pretrained
-            device (str, optional): Torch device to use. Defaults to "auto".
+        """Initialize SevenNetCalculator.
+
+        Parameters
+        ----------
+        model: str | Path | AtomGraphSequential, default='7net-0'
+            Name of a pretrained model (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0),
+            a path to a checkpoint or deployed model, or the model itself
+        file_type: str, default='checkpoint'
+            One of 'checkpoint' | 'torchscript' | 'model_instance'
+        device: str | torch.device, default='auto'
+            If 'auto', use CUDA when available
+        modal: str | None, default=None
+            Modal (fidelity) if the given model is multi-modal. For 7net-mf-ompa,
+            it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24);
+            case-insensitive
+        enable_cueq: bool, default=False
+            If True, use cuEquivariance to accelerate inference
+        sevennet_config: dict | None, default=None
+            Not used in logic; may carry meta information about this calculator
         """
         super().__init__(**kwargs)
         self.sevennet_config = None
@@ -131,18 +140,21 @@ def __init__(

         self.model = model_loaded

-        if isinstance(self.model, AtomGraphSequential) and modal:
-            if self.model.modal_map is None:
-                raise ValueError('Modality given, but model has no modal_map')
-            if modal not in self.model.modal_map:
-                _modals = list(self.model.modal_map.keys())
-                raise ValueError(f'Unknown modal {modal} (not in {_modals})')
+        self.modal = None
+        if isinstance(self.model, AtomGraphSequential):
+            modal_map = self.model.modal_map
+            if modal_map:
+                modal_ava = list(modal_map.keys())
+                if not modal:
+                    raise ValueError(f'modal argument missing (avail: {modal_ava})')
+                elif modal not in modal_ava:
+                    raise ValueError(f'unknown modal {modal} (not in {modal_ava})')
+                self.modal = modal
+            elif modal:  # modal_map is empty or None
+                warnings.warn(f'modal={modal} is ignored as model has no modal_map')

         self.model.to(self.device)
         self.model.eval()
-
-        self.modal = modal
-
         self.implemented_properties = [
             'free_energy',
             'energy',
@@ -216,6 +228,31 @@ def __init__(
         cn_cutoff: float = 1600,  # au^2, 0.52917726 angstrom = 1 au
         **kwargs,
     ):
+        """Initialize SevenNetD3Calculator. CUDA is required.
+
+        Parameters
+        ----------
+        model: str | Path | AtomGraphSequential
+            Name of a pretrained model (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0),
+            a path to a checkpoint or deployed model, or the model itself
+        file_type: str, default='checkpoint'
+            One of 'checkpoint' | 'torchscript' | 'model_instance'
+        device: str | torch.device, default='auto'
+            If 'auto', use CUDA when available
+        modal: str | None, default=None
+            Modal (fidelity) if the given model is multi-modal. For 7net-mf-ompa,
+            it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24)
+        enable_cueq: bool, default=False
+            If True, use cuEquivariance to accelerate inference
+        damping_type: str, default='damp_bj'
+            Damping type of D3, one of 'damp_bj' | 'damp_zero'
+        functional_name: str, default='pbe'
+            Target functional name of the D3 parameters
+        vdw_cutoff: float, default=9000
+            vdW cutoff of the D3 calculator, in au^2
+        cn_cutoff: float, default=1600
+            Coordination-number cutoff of the D3 calculator, in au^2
+        """
         d3_calc = D3Calculator(
             damping_type=damping_type,
             functional_name=functional_name,
@@ -267,9 +304,7 @@ def _load(name: str) -> ctypes.CDLL:

     load(
         name=name,
-        sources=[
-            os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')
-        ],
+        sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')],
         extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'],
         build_directory=compile_dir,
         verbose=True,
diff --git a/sevenn/util.py b/sevenn/util.py
index c732550..2077f43 100644
--- a/sevenn/util.py
+++ b/sevenn/util.py
@@ -1,12 +1,17 @@
 import os
+import os.path as osp
 import pathlib
+import shutil
 from typing import Dict, List, Tuple, Union

 import numpy as np
+import requests
 import torch
 import torch.nn
 from e3nn.o3 import FullTensorProduct, Irreps
+from tqdm import tqdm

+import sevenn._const as _const
 import sevenn._keys as KEY
 from sevenn.checkpoint import SevenNetCheckpoint

@@ -185,12 +190,52 @@ def infer_irreps_out(
     return Irreps(new_irreps_elem)


-def pretrained_name_to_path(name: str) -> str:
-    import sevenn._const as _const
+def download_checkpoint(path: str, url: str):
+    fname = osp.basename(path)
+    temp_path = path + '.partial'
+    try:
+        # Raises PermissionError if the target directory is not writable
+        os.makedirs(osp.dirname(path), exist_ok=True)
+        response = requests.get(url, stream=True, timeout=30)
+        response.raise_for_status()  # Raise an exception for bad status codes
+
+        total_size = int(response.headers.get('content-length', 0))
+        block_size = 1024  # 1 KB chunks
+
+        progress_bar = tqdm(
+            total=total_size,
+            unit='B',
+            unit_scale=True,
+            desc=f'Downloading {fname}',
+        )
+
+        with open(temp_path, 'wb') as file:
+            for data in response.iter_content(block_size):
+                progress_bar.update(len(data))
+                file.write(data)
+        progress_bar.close()
+
+        shutil.move(temp_path, path)
+        print(f'Checkpoint downloaded: {path}')
+        return path
+    except PermissionError:
+        raise
+    except Exception as e:
+        # Clean up partial downloads on failure. May be skipped if the error
+        # was already handled elsewhere (e.g. inside tqdm).
+        print(f'Download failed: {str(e)}')
+        if os.path.exists(temp_path):
+            print(f'Cleaning up partial download: {temp_path}')
+            os.remove(temp_path)
+        raise
+
+
+def pretrained_name_to_path(name: str) -> str:
     name = name.lower()
     heads = ['sevennet', '7net']
     checkpoint_path = None
+    url = None
+
     if (  # TODO: regex
         name in [f'{n}-0_11july2024' for n in heads]
         or name in [f'{n}-0_11jul2024' for n in heads]
@@ -203,21 +248,44 @@ def pretrained_name_to_path(name: str) -> str:
         checkpoint_path = _const.SEVENNET_l3i5
     elif name in [f'{n}-mf-0' for n in heads]:
         checkpoint_path = _const.SEVENNET_MF_0
+    elif name in [f'{n}-mf-ompa' for n in heads]:
+        checkpoint_path = _const.SEVENNET_MF_ompa
+    elif name in [f'{n}-omat' for n in heads]:
+        checkpoint_path = _const.SEVENNET_omat
     else:
-        raise ValueError('Not a valid potential')
+        raise ValueError('Not a valid pretrained model name')
+    url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)
+
+    paths = [
+        checkpoint_path,
+        checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')),
+    ]
+
+    for path in paths:
+        if osp.exists(path):
+            return path
+
+    # File not found; check the URL and try to download
+    if url is None:
+        raise FileNotFoundError(checkpoint_path)

-    return checkpoint_path
+    try:
+        return download_checkpoint(paths[0], url)  # 7net package path
+    except PermissionError:
+        return download_checkpoint(paths[1], url)  # ~/.cache


 def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
-    if os.path.isfile(checkpoint):
+    suggests = ['7net-0', '7net-l3i5', '7net-mf-ompa', '7net-omat']
+    if osp.isfile(checkpoint):
         checkpoint_path = checkpoint
     else:
         try:
             checkpoint_path = pretrained_name_to_path(str(checkpoint))
         except ValueError:
             raise ValueError(
-                f'Given {checkpoint} is not exists and not a pre-trained name'
+                f'Given {checkpoint} does not exist and is not a pretrained model name.\n'
+                f'Valid pretrained model names: {suggests}'
             )
     return SevenNetCheckpoint(checkpoint_path)
diff --git a/tests/unit_tests/test_pretrained.py b/tests/unit_tests/test_pretrained.py
index 58e33d0..4676ca4 100644
--- a/tests/unit_tests/test_pretrained.py
+++ b/tests/unit_tests/test_pretrained.py
@@ -196,9 +196,9 @@ def test_7net_mf_0(atoms_pbc, atoms_mol):
     g2_ref_e = torch.tensor([-14.172412872314453])
     g2_ref_f = torch.tensor(
         [
-            [4.6566129e-10, -1.3429364e+01, 6.9344816e+00],
-            [2.3283064e-09, 8.9132404e+00, -9.6807365e+00],
-            [-2.7939677e-09, 4.5161238e+00, 2.7462559e+00],
+            [4.6566129e-10, -1.3429364e01, 6.9344816e00],
+            [2.3283064e-09, 8.9132404e00, -9.6807365e00],
+            [-2.7939677e-09, 4.5161238e00, 2.7462559e00],
         ]
     )

@@ -208,3 +208,137 @@
     assert acl(g2.inferred_total_energy, g2_ref_e)
     assert acl(g2.inferred_force, g2_ref_f)
+
+
+def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol):
+    cp_path = pretrained_name_to_path('7net-mf-ompa')
+    model, config = model_from_checkpoint(cp_path)
+    cutoff = config['cutoff']
+
+    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
+    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
+
+    # mpa
+    g1[KEY.DATA_MODALITY] = 'mpa'
+    g2[KEY.DATA_MODALITY] = 'mpa'
+
+    model.set_is_batch_data(False)
+    g1 = model(g1)
+    g2 = model(g2)
+
+    model.set_is_batch_data(True)
+
+    g1_ref_e = torch.tensor([-3.490943193435669])
+    g1_ref_f = torch.tensor(
+        [
+            [1.2680445e01, -2.7985498e-04, -2.7979910e-04],
+            [-1.2680446e01, 2.7984008e-04, 2.7981028e-04],
+        ]
+    )
+    g1_ref_s = -1 * torch.tensor(
+        # xx, yy, zz, xy, yz, zx
+        [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467]
+    )
+
+    g2_ref_e = torch.tensor([-12.597525596618652])
+    g2_ref_f = torch.tensor(
+        [
+            [0.0, -12.245223, 7.26795],
+            [0.0, 8.816763, -9.423925],
+            [0.0, 3.4284601, 2.1559749],
+        ]
+    )
+    assert acl(g1.inferred_total_energy, g1_ref_e)
+    assert acl(g1.inferred_force, g1_ref_f)
+    assert acl(g1.inferred_stress, g1_ref_s)
+
+    assert acl(g2.inferred_total_energy, g2_ref_e)
+    assert acl(g2.inferred_force, g2_ref_f)
+
+
+def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol):
+    cp_path = pretrained_name_to_path('7net-mf-ompa')
+    model, config = model_from_checkpoint(cp_path)
+    cutoff = config['cutoff']
+
+    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
+    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
+
+    # omat24
+    g1[KEY.DATA_MODALITY] = 'omat24'
+    g2[KEY.DATA_MODALITY] = 'omat24'
+
+    model.set_is_batch_data(False)
+    g1 = model(g1)
+    g2 = model(g2)
+
+    model.set_is_batch_data(True)
+
+    g1_ref_e = torch.tensor([-3.5094668865203857])
+    g1_ref_f = torch.tensor(
+        [
+            [1.2562084e01, -1.4219694e-03, -1.4219843e-03],
+            [-1.2562084e01, 1.4219508e-03, 1.4219955e-03],
+        ]
+    )
+    g1_ref_s = -1 * torch.tensor(
+        # xx, yy, zz, xy, yz, zx
+        [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343]
+    )
+
+    g2_ref_e = torch.tensor([-12.6202974319458])
+    g2_ref_f = torch.tensor(
+        [
+            [0.0, -12.205926, 7.2050343],
+            [0.0, 8.790399, -9.368677],
+            [0.0, 3.4155273, 2.163643],
+        ]
+    )
+    assert acl(g1.inferred_total_energy, g1_ref_e)
+    assert acl(g1.inferred_force, g1_ref_f)
+    assert acl(g1.inferred_stress, g1_ref_s)
+
+    assert acl(g2.inferred_total_energy, g2_ref_e)
+    assert acl(g2.inferred_force, g2_ref_f)
+
+
+def test_7net_omat(atoms_pbc, atoms_mol):
+    cp_path = pretrained_name_to_path('7net-omat')
+    model, config = model_from_checkpoint(cp_path)
+    cutoff = config['cutoff']
+
+    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
+    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
+
+    model.set_is_batch_data(False)
+    g1 = model(g1)
+    g2 = model(g2)
+
+    model.set_is_batch_data(True)
+
+    g1_ref_e = torch.tensor([-3.5033323764801025])
+    g1_ref_f = torch.tensor(
+        [
+            [12.533154, 0.02358698, 0.02358694],
+            [-12.533153, -0.02358699, -0.02358697],
+        ]
+    )
+    g1_ref_s = -1 * torch.tensor(
+        # xx, yy, zz, xy, yz, zx
+        [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445]
+    )
+
+    g2_ref_e = torch.tensor([-12.403768539428711])
+    g2_ref_f = torch.tensor(
+        [
+            [0.0, -12.848297, 7.11432],
+            [0.0, 9.265477, -9.564951],
+            [0.0, 3.58282, 2.4506311],
+        ]
+    )
+    assert acl(g1.inferred_total_energy, g1_ref_e)
+    assert acl(g1.inferred_force, g1_ref_f)
+    assert acl(g1.inferred_stress, g1_ref_s)
+
+    assert acl(g2.inferred_total_energy, g2_ref_e)
+    assert acl(g2.inferred_force, g2_ref_f)
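+
+
+def test_pretrained_name_to_path_download():
+    # Sketch: resolving a keyword should yield an existing local checkpoint,
+    # downloading it (to the package dir, or to ~/.cache/sevennet on
+    # PermissionError) when it is not bundled. Assumes network access in that case.
+    import os.path as osp
+
+    path = pretrained_name_to_path('7net-omat')
+    assert osp.isfile(path)
+    assert path.endswith('checkpoint_sevennet_omat.pth')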