From b4772e25926757c6b032e59e3770b2f4eb4e4129 Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Wed, 12 Feb 2025 11:20:44 +0100 Subject: [PATCH] blacken code (#359) --- benchmarks/inference.py | 1 + benchmarks/neighbors.py | 1 + examples/openmm-integration.py | 16 +++-- tests/test_datasets.py | 8 ++- tests/test_mdcath.py | 4 +- tests/test_priors.py | 87 +++++++++++++++++------ tests/utils.py | 6 +- torchmdnet/data.py | 1 + torchmdnet/datasets/ani.py | 10 +-- torchmdnet/datasets/comp6.py | 23 +++--- torchmdnet/datasets/md17.py | 105 +++++++++++++++------------- torchmdnet/datasets/md22.py | 35 +++++----- torchmdnet/datasets/mdcath.py | 102 +++++++++++++++++---------- torchmdnet/datasets/water.py | 41 +++++++---- torchmdnet/models/__init__.py | 7 +- torchmdnet/models/output_modules.py | 18 ++--- torchmdnet/models/torchmd_et.py | 1 + torchmdnet/models/wrappers.py | 1 + torchmdnet/priors/base.py | 10 ++- torchmdnet/priors/coulomb.py | 58 +++++++++++---- torchmdnet/priors/d2.py | 13 +++- torchmdnet/priors/zbl.py | 10 ++- torchmdnet/scripts/train.py | 2 +- torchmdnet/utils.py | 12 ++-- 24 files changed, 369 insertions(+), 203 deletions(-) diff --git a/benchmarks/inference.py b/benchmarks/inference.py index 1fd271b3c..9d694bf5f 100644 --- a/benchmarks/inference.py +++ b/benchmarks/inference.py @@ -122,6 +122,7 @@ def benchmark_pdb(pdb_file, **kwargs): "2L emb 64": {"num_layers": 2, "embedding_dimension": 64}, } + def benchmark_all(): timings = {} for pdb_file in os.listdir("systems"): diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index d4e6d3ad0..2db969d92 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -10,6 +10,7 @@ from typing import Optional from torch_cluster import radius_graph + class Distance(nn.Module): def __init__( self, diff --git a/examples/openmm-integration.py b/examples/openmm-integration.py index 47953e9b5..ad4e7043f 100644 --- a/examples/openmm-integration.py +++ b/examples/openmm-integration.py @@ -5,7 +5,9 @@ import openmm import openmmtorch except ImportError: - raise ImportError("Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)") + raise ImportError( + "Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)" + ) import sys import torch @@ -34,9 +36,9 @@ def __init__(self, embeddings, model): def forward(self, positions): # OpenMM works with nanometer positions and kilojoule per mole energies # Depending on the model, you might need to convert the units - positions = positions.to(torch.float32) * 10.0 # nm -> A + positions = positions.to(torch.float32) * 10.0 # nm -> A energy = self.model(z=self.embeddings, pos=positions)[0] - return energy * 96.4916 # eV -> kJ/mol + return energy * 96.4916 # eV -> kJ/mol pdb = PDBFile("../benchmarks/systems/chignolin.pdb") @@ -54,9 +56,11 @@ def forward(self, positions): for atom in pdb.topology.atoms(): system.addParticle(atom.element.mass) system.addForce(torch_force) -integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 2*femtosecond) -platform = Platform.getPlatformByName('CPU') +integrator = LangevinMiddleIntegrator(298.15 * kelvin, 1 / picosecond, 2 * femtosecond) +platform = Platform.getPlatformByName("CPU") simulation = Simulation(pdb.topology, system, integrator, platform) simulation.context.setPositions(pdb.positions) -simulation.reporters.append(StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True)) +simulation.reporters.append( + StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True) +) simulation.step(10) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8cdf0d1ae..7402f8f31 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -328,6 +328,10 @@ def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir for sample_cached, sample in zip(dl_cached, dl): assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords" - assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers" + assert np.allclose( + sample_cached.z, sample.z + ), "Sample has incorrect atom numbers" assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy" - assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces" \ No newline at end of file + assert np.allclose( + sample_cached.neg_dy, sample.neg_dy + ), "Sample has incorrect forces" diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py index 0c7443426..0d2add0bc 100644 --- a/tests/test_mdcath.py +++ b/tests/test_mdcath.py @@ -167,9 +167,7 @@ def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list): data.flush() data.close() - dataset = MDCATH( - root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list - ) + dataset = MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list) dl = DataLoader( dataset, batch_size=batch_size, diff --git a/tests/test_priors.py b/tests/test_priors.py index c77d49263..5f28caf69 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -38,74 +38,115 @@ def test_atomref(model_name, enable_atomref): # check if the output of both models differs by the expected atomref contribution if enable_atomref: - expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1) + expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze( + 1 + ) else: expected_offset = 0 torch.testing.assert_close(x_atomref, x_no_atomref + expected_offset) + @mark.parametrize("trainable", [True, False]) def test_atomref_trainable(trainable): dataset = DummyDataset(has_atomref=True) atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable) assert atomref.atomref.weight.requires_grad == trainable + def test_learnableatomref(): atomref = LearnableAtomref(max_z=100) assert atomref.atomref.weight.requires_grad == True + def test_zbl(): - pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr + pos = torch.tensor( + [[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + dtype=torch.float32, + ) # Atom positions in Bohr types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types - atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8) # Mapping of atom types to atomic numbers + atomic_number = torch.tensor( + [1, 6, 8], dtype=torch.int8 + ) # Mapping of atom types to atomic numbers distance_scale = 5.29177210903e-11 # Convert Bohr to meters - energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules + energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules # Use the ZBL class to compute the energy. - zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale) - energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {})[0] + zbl = ZBL( + 10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale + ) + energy = zbl.post_reduce( + torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {} + )[0] # Compare to the expected value. def compute_interaction(pos1, pos2, z1, z2): - delta = pos1-pos2 + delta = pos1 - pos2 r = torch.sqrt(torch.dot(delta, delta)) - x = r / (0.8854/(z1**0.23 + z2**0.23)) - phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x) - cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0) - return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r + x = r / (0.8854 / (z1**0.23 + z2**0.23)) + phi = ( + 0.1818 * torch.exp(-3.2 * x) + + 0.5099 * torch.exp(-0.9423 * x) + + 0.2802 * torch.exp(-0.4029 * x) + + 0.02817 * torch.exp(-0.2016 * x) + ) + cutoff = 0.5 * (torch.cos(r * torch.pi / 10.0) + 1.0) + return cutoff * phi * (138.935 / 5.29177210903e-2) * z1 * z2 / r expected = 0 for i in range(len(pos)): for j in range(i): - expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]) + expected += compute_interaction( + pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]] + ) torch.testing.assert_close(expected, energy, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_coulomb(dtype): - pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm + pos = torch.tensor( + [[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], + dtype=dtype, + ) # Atom positions in nm charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types distance_scale = 1e-9 # Convert nm to meters - energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules + energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules lower_switch_distance = 0.9 upper_switch_distance = 1.3 # Use the Coulomb class to compute the energy. - coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale) - energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0] + coulomb = Coulomb( + lower_switch_distance, + upper_switch_distance, + 5, + distance_scale=distance_scale, + energy_scale=energy_scale, + ) + energy = coulomb.post_reduce( + torch.zeros((1,)), + types, + pos, + torch.zeros_like(types), + extra_args={"partial_charges": charge}, + )[0] # Compare to the expected value. def compute_interaction(pos1, pos2, z1, z2): - delta = pos1-pos2 + delta = pos1 - pos2 r = torch.sqrt(torch.dot(delta, delta)) if r < lower_switch_distance: return 0 - energy = 138.935*z1*z2/r + energy = 138.935 * z1 * z2 / r if r < upper_switch_distance: - energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance)) + energy *= 0.5 - 0.5 * torch.cos( + torch.pi + * (r - lower_switch_distance) + / (upper_switch_distance - lower_switch_distance) + ) return energy expected = 0 @@ -120,10 +161,12 @@ def test_multiple_priors(dtype): # Create a model from a config file. dataset = DummyDataset(has_atomref=True) - config_file = join(dirname(__file__), 'priors.yaml') - args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype) + config_file = join(dirname(__file__), "priors.yaml") + args = load_example_args( + "equivariant-transformer", config_file=config_file, dtype=dtype + ) prior_models = create_prior_models(args, dataset) - args['prior_args'] = [p.get_init_args() for p in prior_models] + args["prior_args"] = [p.get_init_args() for p in prior_models] model = LNNP(args, prior_model=prior_models) priors = model.model.prior_model diff --git a/tests/utils.py b/tests/utils.py index ef8bcddb9..effd5781d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,7 +11,9 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs): if config_file is None: if model_name == "tensornet": - config_file = join(dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml") + config_file = join( + dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml" + ) else: config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") with open(config_file, "r") as f: @@ -84,7 +86,7 @@ def _get_atomref(self): return self.atomref DummyDataset.get_atomref = _get_atomref - self.atomic_number = torch.arange(max(atom_types)+1) + self.atomic_number = torch.arange(max(atom_types) + 1) self.distance_scale = 1.0 self.energy_scale = 1.0 diff --git a/torchmdnet/data.py b/torchmdnet/data.py index 986e19f79..dd1138c7d 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -14,6 +14,7 @@ from torchmdnet.models.utils import scatter import warnings + class DataModule(LightningDataModule): """A LightningDataModule for loading datasets from the torchmdnet.datasets module. diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py index e7ca1add0..279b5b64c 100644 --- a/torchmdnet/datasets/ani.py +++ b/torchmdnet/datasets/ani.py @@ -48,12 +48,12 @@ def raw_file_names(self): def get_atomref(self, max_z=100): """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - Args: - max_z (int): Maximum atomic number + Args: + max_z (int): Maximum atomic number - Returns: - torch.Tensor: Atomic energy reference values for each element in the dataset. - """ + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ refs = pt.zeros(max_z) for key, val in self._ELEMENT_ENERGIES.items(): refs[key] = val * self.HARTREE_TO_EV diff --git a/torchmdnet/datasets/comp6.py b/torchmdnet/datasets/comp6.py index a810a3d4a..8057c1f32 100644 --- a/torchmdnet/datasets/comp6.py +++ b/torchmdnet/datasets/comp6.py @@ -62,12 +62,12 @@ def raw_url(self): def get_atomref(self, max_z=100): """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - Args: - max_z (int): Maximum atomic number + Args: + max_z (int): Maximum atomic number - Returns: - torch.Tensor: Atomic energy reference values for each element in the dataset. - """ + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ refs = pt.zeros(max_z) for key, val in self._ELEMENT_ENERGIES.items(): refs[key] = val * self.HARTREE_TO_EV @@ -142,6 +142,7 @@ def raw_url_name(self): def raw_file_names(self): return ["ani_md_bench.h5"] + class DrugBank(COMP6Base): """ DrugBank Benchmark. This benchmark is developed through a subsampling of the @@ -247,7 +248,7 @@ def __init__( self.subsets = [ DS(root, transform, pre_transform, pre_filter) - for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8) + for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8) ] self.num_samples = sum(len(subset) for subset in self.subsets) @@ -347,12 +348,12 @@ def sample_iter(self, mol_ids=False): def get_atomref(self, max_z=100): """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - Args: - max_z (int): Maximum atomic number + Args: + max_z (int): Maximum atomic number - Returns: - torch.Tensor: Atomic energy reference values for each element in the dataset. - """ + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ refs = pt.zeros(max_z) for key, val in self._ELEMENT_ENERGIES.items(): refs[key] = val * self.HARTREE_TO_EV diff --git a/torchmdnet/datasets/md17.py b/torchmdnet/datasets/md17.py index ecee51a0e..ab829da5f 100644 --- a/torchmdnet/datasets/md17.py +++ b/torchmdnet/datasets/md17.py @@ -17,39 +17,42 @@ # extracted from PyG MD17 dataset class + class MD17(InMemoryDataset): - gdml_url = 'http://quantum-machine.org/gdml/data/npz' - revised_url = ('https://archive.materialscloud.org/record/' - 'file?filename=rmd17.tar.bz2&record_id=466') + gdml_url = "http://quantum-machine.org/gdml/data/npz" + revised_url = ( + "https://archive.materialscloud.org/record/" + "file?filename=rmd17.tar.bz2&record_id=466" + ) file_names = { - 'benzene': 'md17_benzene2017.npz', - 'uracil': 'md17_uracil.npz', - 'naphtalene': 'md17_naphthalene.npz', - 'aspirin': 'md17_aspirin.npz', - 'salicylic_acid': 'md17_salicylic.npz', - 'malonaldehyde': 'md17_malonaldehyde.npz', - 'ethanol': 'md17_ethanol.npz', - 'toluene': 'md17_toluene.npz', - 'paracetamol': 'paracetamol_dft.npz', - 'azobenzene': 'azobenzene_dft.npz', - 'revised_benzene': 'rmd17_benzene.npz', - 'revised_uracil': 'rmd17_uracil.npz', - 'revised_naphthalene': 'rmd17_naphthalene.npz', - 'revised_aspirin': 'rmd17_aspirin.npz', - 'revised_salicylic_acid': 'rmd17_salicylic.npz', - 'revised_malonaldehyde': 'rmd17_malonaldehyde.npz', - 'revised_ethanol': 'rmd17_ethanol.npz', - 'revised_toluene': 'rmd17_toluene.npz', - 'revised_paracetamol': 'rmd17_paracetamol.npz', - 'revised_azobenzene': 'rmd17_azobenzene.npz', - 'benzene_CCSD_T': 'benzene_ccsd_t.zip', - 'aspirin_CCSD': 'aspirin_ccsd.zip', - 'malonaldehyde_CCSD_T': 'malonaldehyde_ccsd_t.zip', - 'ethanol_CCSD_T': 'ethanol_ccsd_t.zip', - 'toluene_CCSD_T': 'toluene_ccsd_t.zip', - 'benzene_FHI-aims': 'benzene2018_dft.npz', + "benzene": "md17_benzene2017.npz", + "uracil": "md17_uracil.npz", + "naphtalene": "md17_naphthalene.npz", + "aspirin": "md17_aspirin.npz", + "salicylic_acid": "md17_salicylic.npz", + "malonaldehyde": "md17_malonaldehyde.npz", + "ethanol": "md17_ethanol.npz", + "toluene": "md17_toluene.npz", + "paracetamol": "paracetamol_dft.npz", + "azobenzene": "azobenzene_dft.npz", + "revised_benzene": "rmd17_benzene.npz", + "revised_uracil": "rmd17_uracil.npz", + "revised_naphthalene": "rmd17_naphthalene.npz", + "revised_aspirin": "rmd17_aspirin.npz", + "revised_salicylic_acid": "rmd17_salicylic.npz", + "revised_malonaldehyde": "rmd17_malonaldehyde.npz", + "revised_ethanol": "rmd17_ethanol.npz", + "revised_toluene": "rmd17_toluene.npz", + "revised_paracetamol": "rmd17_paracetamol.npz", + "revised_azobenzene": "rmd17_azobenzene.npz", + "benzene_CCSD_T": "benzene_ccsd_t.zip", + "aspirin_CCSD": "aspirin_ccsd.zip", + "malonaldehyde_CCSD_T": "malonaldehyde_ccsd_t.zip", + "ethanol_CCSD_T": "ethanol_ccsd_t.zip", + "toluene_CCSD_T": "toluene_ccsd_t.zip", + "benzene_FHI-aims": "benzene2018_dft.npz", } def __init__( @@ -66,19 +69,21 @@ def __init__( raise ValueError(f"Unknown dataset name '{name}'") self.name = name - self.revised = 'revised' in name - self.ccsd = 'CCSD' in self.name + self.revised = "revised" in name + self.ccsd = "CCSD" in self.name super().__init__(root, transform, pre_transform, pre_filter) if len(self.processed_file_names) == 1 and train is not None: raise ValueError( f"'{self.name}' dataset does not provide pre-defined splits " - f"but the 'train' argument is set to '{train}'") + f"but the 'train' argument is set to '{train}'" + ) elif len(self.processed_file_names) == 2 and train is None: raise ValueError( f"'{self.name}' dataset does provide pre-defined splits but " - f"the 'train' argument was not specified") + f"the 'train' argument was not specified" + ) idx = 0 if train is None or train else 1 self.data, self.slices = torch.load(self.processed_paths[idx]) @@ -89,36 +94,36 @@ def mean(self) -> float: @property def raw_dir(self) -> str: if self.revised: - return osp.join(self.root, 'raw') - return osp.join(self.root, self.name, 'raw') + return osp.join(self.root, "raw") + return osp.join(self.root, self.name, "raw") @property def processed_dir(self) -> str: - return osp.join(self.root, 'processed', self.name) + return osp.join(self.root, "processed", self.name) @property def raw_file_names(self) -> str: name = self.file_names[self.name] if self.revised: - return osp.join('rmd17', 'npz_data', name) + return osp.join("rmd17", "npz_data", name) elif self.ccsd: - return name[:-4] + '-train.npz', name[:-4] + '-test.npz' + return name[:-4] + "-train.npz", name[:-4] + "-test.npz" return name @property def processed_file_names(self) -> List[str]: if self.ccsd: - return ['train.pt', 'test.pt'] + return ["train.pt", "test.pt"] else: - return ['data.pt'] + return ["data.pt"] def download(self): if self.revised: path = download_url(self.revised_url, self.raw_dir) - extract_tar(path, self.raw_dir, mode='r:bz2') + extract_tar(path, self.raw_dir, mode="r:bz2") os.unlink(path) else: - url = f'{self.gdml_url}/{self.file_names[self.name]}' + url = f"{self.gdml_url}/{self.file_names[self.name]}" path = download_url(url, self.raw_dir) if self.ccsd: extract_zip(path, self.raw_dir) @@ -130,15 +135,15 @@ def process(self): raw_data = np.load(raw_path) if self.revised: - z = torch.from_numpy(raw_data['nuclear_charges']).long() - pos = torch.from_numpy(raw_data['coords']).float() - energy = torch.from_numpy(raw_data['energies']).float() - force = torch.from_numpy(raw_data['forces']).float() + z = torch.from_numpy(raw_data["nuclear_charges"]).long() + pos = torch.from_numpy(raw_data["coords"]).float() + energy = torch.from_numpy(raw_data["energies"]).float() + force = torch.from_numpy(raw_data["forces"]).float() else: - z = torch.from_numpy(raw_data['z']).long() - pos = torch.from_numpy(raw_data['R']).float() - energy = torch.from_numpy(raw_data['E']).float() - force = torch.from_numpy(raw_data['F']).float() + z = torch.from_numpy(raw_data["z"]).long() + pos = torch.from_numpy(raw_data["R"]).float() + energy = torch.from_numpy(raw_data["E"]).float() + force = torch.from_numpy(raw_data["F"]).float() data_list = [] for i in range(pos.size(0)): diff --git a/torchmdnet/datasets/md22.py b/torchmdnet/datasets/md22.py index ce8adfb96..959b0c354 100644 --- a/torchmdnet/datasets/md22.py +++ b/torchmdnet/datasets/md22.py @@ -15,18 +15,19 @@ extract_zip, ) + class MD22(InMemoryDataset): - gdml_url = 'http://quantum-machine.org/gdml/data/npz' + gdml_url = "http://quantum-machine.org/gdml/data/npz" file_names = { - 'AT-AT-CG-CG': 'md22_AT-AT-CG-CG.npz', - 'AT-AT': 'md22_AT-AT.npz', - 'Ac-Ala3-NHMe': 'md22_Ac-Ala3-NHMe.npz', - 'DHA': 'md22_DHA.npz', - 'buckyball-catcher': 'md22_buckyball-catcher.npz', - 'dw-nanotube': 'md22_dw_nanotube.npz', - 'stachyose': 'md22_stachyose.npz', + "AT-AT-CG-CG": "md22_AT-AT-CG-CG.npz", + "AT-AT": "md22_AT-AT.npz", + "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz", + "DHA": "md22_DHA.npz", + "buckyball-catcher": "md22_buckyball-catcher.npz", + "dw-nanotube": "md22_dw_nanotube.npz", + "stachyose": "md22_stachyose.npz", } def __init__( @@ -53,11 +54,11 @@ def mean(self) -> float: @property def raw_dir(self) -> str: - return osp.join(self.root, self.name, 'raw') + return osp.join(self.root, self.name, "raw") @property def processed_dir(self) -> str: - return osp.join(self.root, 'processed', self.name) + return osp.join(self.root, "processed", self.name) @property def raw_file_names(self) -> str: @@ -66,21 +67,21 @@ def raw_file_names(self) -> str: @property def processed_file_names(self) -> List[str]: - return ['data.pt'] + return ["data.pt"] def download(self): - url = f'{self.gdml_url}/{self.file_names[self.name]}' + url = f"{self.gdml_url}/{self.file_names[self.name]}" path = download_url(url, self.raw_dir) def process(self): it = zip(self.raw_paths, self.processed_paths) for raw_path, processed_path in it: raw_data = np.load(raw_path) - - z = torch.from_numpy(raw_data['z']).long() - pos = torch.from_numpy(raw_data['R']).float() - energy = torch.from_numpy(raw_data['E']).float() - force = torch.from_numpy(raw_data['F']).float() + + z = torch.from_numpy(raw_data["z"]).long() + pos = torch.from_numpy(raw_data["R"]).float() + energy = torch.from_numpy(raw_data["E"]).float() + force = torch.from_numpy(raw_data["F"]).float() data_list = [] for i in range(pos.size(0)): diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 4ea072672..5fadc923f 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -14,7 +14,8 @@ import urllib.request from collections import defaultdict -logger = logging.getLogger('MDCATH') +logger = logging.getLogger("MDCATH") + def load_pdb_list(pdb_list): """Load PDB list from a file or return list directly.""" @@ -26,6 +27,7 @@ def load_pdb_list(pdb_list): return [line.strip() for line in file] raise ValueError("Invalid PDB list. Please provide a list or a path to a file.") + class MDCATH(Dataset): def __init__( self, @@ -69,7 +71,7 @@ def __init__( skip_frames: int Number of frames to skip in the trajectory. Default is 1. pdb_list: list or str - List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file' will be downloaded. + List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file' will be downloaded. The filters will be applied to the PDB IDs in this list in any case. Default is None. min_gyration_radius: float Minimum gyration radius (in nm) of the protein structure. Default is None. @@ -106,7 +108,6 @@ def __init__( # Calculate the total size of the dataset in MB self.total_size_mb = self.calculate_dataset_size() - logger.info(f"Total number of domains: {len(self.processed.keys())}") logger.info(f"Total number of conformers: {self.num_conformers}") logger.info(f"Total size of dataset: {self.total_size_mb} MB") @@ -120,23 +121,27 @@ def raw_dir(self): # Override the raw_dir property to return the root directory # The files will be downloaded to the root directory, compatible only with original mdcath dataset return self.root - + def _ensure_source_file(self): """Ensure the source file is downloaded before processing.""" source_path = os.path.join(self.root, self.source_file) if not os.path.exists(source_path): - assert self.source_file == "mdcath_source.h5", "Only 'mdcath_source.h5' is supported as source file for download." + assert ( + self.source_file == "mdcath_source.h5" + ), "Only 'mdcath_source.h5' is supported as source file for download." logger.info(f"Downloading source file {self.source_file}") urllib.request.urlretrieve(opj(self.url, self.source_file), source_path) - + def download(self): for pdb_id in self.processed.keys(): file_name = f"{self.file_basename}_{pdb_id}.h5" file_path = opj(self.raw_dir, file_name) if not os.path.exists(file_path): - assert self.file_basename == "mdcath_dataset", "Only 'mdcath_dataset' is supported as file_basename for download." + assert ( + self.file_basename == "mdcath_dataset" + ), "Only 'mdcath_dataset' is supported as file_basename for download." # Download the file if it does not exist - urllib.request.urlretrieve(opj(self.url, 'data', file_name), file_path) + urllib.request.urlretrieve(opj(self.url, "data", file_name), file_path) def calculate_dataset_size(self): total_size_bytes = 0 @@ -145,21 +150,27 @@ def calculate_dataset_size(self): total_size_bytes += os.path.getsize(opj(self.root, file_name)) total_size_mb = round(total_size_bytes / (1024 * 1024), 4) return total_size_mb - + def _filter_and_prepare_data(self): source_info_path = os.path.join(self.root, self.source_file) - + self.processed = defaultdict(list) self.num_conformers = 0 with h5py.File(source_info_path, "r") as file: domains = file.keys() if self.pdb_list is None else self.pdb_list - + for pdb_id in tqdm(domains, desc="Processing mdcath source"): pdb_group = file[pdb_id] - if self.numAtoms is not None and pdb_group.attrs["numProteinAtoms"] > self.numAtoms: + if ( + self.numAtoms is not None + and pdb_group.attrs["numProteinAtoms"] > self.numAtoms + ): continue - if self.numResidues is not None and pdb_group.attrs["numResidues"] > self.numResidues: + if ( + self.numResidues is not None + and pdb_group.attrs["numResidues"] > self.numResidues + ): continue self._process_temperatures(pdb_id, pdb_group) @@ -170,16 +181,24 @@ def _process_temperatures(self, pdb_id, pdb_group): def _evaluate_replica(self, pdb_id, temp, replica, pdb_group): conditions = [ - self.numFrames is not None and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames, - self.min_gyration_radius is not None and pdb_group[temp][replica].attrs["min_gyration_radius"] < self.min_gyration_radius, - self.max_gyration_radius is not None and pdb_group[temp][replica].attrs["max_gyration_radius"] > self.max_gyration_radius, + self.numFrames is not None + and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames, + self.min_gyration_radius is not None + and pdb_group[temp][replica].attrs["min_gyration_radius"] + < self.min_gyration_radius, + self.max_gyration_radius is not None + and pdb_group[temp][replica].attrs["max_gyration_radius"] + > self.max_gyration_radius, self._evaluate_structure(pdb_group, temp, replica), - self.numNoHAtoms is not None and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms, + self.numNoHAtoms is not None + and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms, ] if any(conditions): return - - num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"] / self.skip_frames) + + num_frames = math.ceil( + pdb_group[temp][replica].attrs["numFrames"] / self.skip_frames + ) self.processed[pdb_id].append((temp, replica, num_frames)) self.num_conformers += num_frames @@ -193,32 +212,39 @@ def len(self): return self.num_conformers def _setup_idx(self): - files = [opj(self.root, f"{self.file_basename}_{pdb_id}.h5") for pdb_id in self.processed.keys()] + files = [ + opj(self.root, f"{self.file_basename}_{pdb_id}.h5") + for pdb_id in self.processed.keys() + ] self.idx = [] for i, (pdb, group_info) in enumerate(self.processed.items()): for temp, replica, num_frames in group_info: # build the catalog here for each conformer - d = [(pdb, files[i], temp, replica, conf_id) for conf_id in range(num_frames)] - self.idx.extend(d) - - assert (len(self.idx) == self.num_conformers), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" - - + d = [ + (pdb, files[i], temp, replica, conf_id) + for conf_id in range(num_frames) + ] + self.idx.extend(d) + + assert ( + len(self.idx) == self.num_conformers + ), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" + def process_specific_group(self, pdb, file, temp, repl, conf_idx): # do not use attributes from h5group because is will cause memory leak # use the read_direct and np.s_ to get the coords and forces of interest directly - conf_idx = conf_idx*self.skip_frames - slice_idxs = np.s_[conf_idx:conf_idx+1] + conf_idx = conf_idx * self.skip_frames + slice_idxs = np.s_[conf_idx : conf_idx + 1] with h5py.File(file, "r") as f: z = f[pdb]["z"][:] coords = np.zeros((z.shape[0], 3)) forces = np.zeros((z.shape[0], 3)) - - group = f[f'{pdb}/{temp}/{repl}'] - group['coords'].read_direct(coords, slice_idxs) - group['forces'].read_direct(forces, slice_idxs) - + group = f[f"{pdb}/{temp}/{repl}"] + + group["coords"].read_direct(coords, slice_idxs) + group["forces"].read_direct(forces, slice_idxs) + # coords and forces shape (num_atoms, 3) assert ( coords.shape[0] == forces.shape[0] @@ -228,7 +254,7 @@ def process_specific_group(self, pdb, file, temp, repl, conf_idx): ), f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" return (z, coords, forces) - + def get(self, element): data = Data() if self.idx is None: @@ -236,9 +262,11 @@ def get(self, element): self._setup_idx() # fields_data is a tuple with the file, pdb, temp, replica, conf_idx pdb_id, file_path, temp, replica, conf_idx = self.idx[element] - z, coords, forces = self.process_specific_group(pdb_id, file_path, temp, replica, conf_idx) + z, coords, forces = self.process_specific_group( + pdb_id, file_path, temp, replica, conf_idx + ) data.z = torch.tensor(z, dtype=torch.long) data.pos = torch.tensor(coords, dtype=torch.float) data.neg_dy = torch.tensor(forces, dtype=torch.float) - data.info = f'{pdb_id}_{temp}_{replica}_{conf_idx}' - return data \ No newline at end of file + data.info = f"{pdb_id}_{temp}_{replica}_{conf_idx}" + return data diff --git a/torchmdnet/datasets/water.py b/torchmdnet/datasets/water.py index 476d6de3b..524b84182 100644 --- a/torchmdnet/datasets/water.py +++ b/torchmdnet/datasets/water.py @@ -6,8 +6,9 @@ import zipfile import re + def create_numpy_arrays(file_path): - with open(file_path, 'r') as file: + with open(file_path, "r") as file: num_atoms = int(file.readline().strip()) file.seek(0) num_conformations = sum(1 for line in file if line.strip().isdigit()) @@ -21,13 +22,17 @@ def create_numpy_arrays(file_path): for i in range(num_conformations): _ = file.readline() properties_line = file.readline() - tot_energy_match = re.search(r'TotEnergy=(-?\d+\.\d+)', properties_line) + tot_energy_match = re.search(r"TotEnergy=(-?\d+\.\d+)", properties_line) pbc_match = re.search(r'pbc="([T|F] [T|F] [T|F])"', properties_line) lattice_match = re.search(r'Lattice="([-?\d+.\d+\s]+)"', properties_line) energies[i] = float(tot_energy_match.group(1)) if tot_energy_match else None - pbc = [s == 'T' for s in pbc_match.group(1).split()] if pbc_match else None + pbc = [s == "T" for s in pbc_match.group(1).split()] if pbc_match else None assert pbc == [True, True, True] or pbc == [False, False, False] - box_vectors[i] = [float(x) for x in lattice_match.group(1).split()] if lattice_match else None + box_vectors[i] = ( + [float(x) for x in lattice_match.group(1).split()] + if lattice_match + else None + ) for j in range(num_atoms): atom_line = file.readline().strip().split() positions[i, j] = [float(x) for x in atom_line[1:4]] @@ -35,8 +40,9 @@ def create_numpy_arrays(file_path): atomic_numbers[i, j] = int(atom_line[7]) return energies, forces, positions, atomic_numbers, box_vectors + class WaterBox(InMemoryDataset): - """ WaterBox dataset from [1]_. + """WaterBox dataset from [1]_. The dataset consists of 1593 water molecules in a cubic box with periodic boundary conditions. The molecules are sampled from a molecular dynamics simulation of liquid water. @@ -58,7 +64,8 @@ class WaterBox(InMemoryDataset): ---------- [1] Ab initio thermodynamics of liquid and solid water. Bingqing et. al. https://arxiv.org/abs/1811.08630 """ - url = 'https://archive.materialscloud.org/record/file?record_id=71&filename=training-set.zip' + + url = "https://archive.materialscloud.org/record/file?record_id=71&filename=training-set.zip" def __init__(self, root, transform=None, pre_transform=None): super(WaterBox, self).__init__(root, transform, pre_transform) @@ -66,25 +73,31 @@ def __init__(self, root, transform=None, pre_transform=None): @property def raw_file_names(self): - return ['dataset_1593.xyz'] + return ["dataset_1593.xyz"] @property def processed_file_names(self): - return ['data.pt'] + return ["data.pt"] def download(self): r = requests.get(self.url) if r.status_code != 200: - raise Exception(f"Failed to download file from {self.url}. Status code: {r.status_code}") - zip_path = os.path.join(self.raw_dir, 'training-set.zip') - with open(zip_path, 'wb') as f: + raise Exception( + f"Failed to download file from {self.url}. Status code: {r.status_code}" + ) + zip_path = os.path.join(self.raw_dir, "training-set.zip") + with open(zip_path, "wb") as f: f.write(r.content) - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(self.raw_dir) def process(self): - dataset_xyz_path = os.path.join(self.raw_dir, 'training-set', 'dataset_1593.xyz') - energies, forces, positions, atomic_numbers, box_vectors = create_numpy_arrays(dataset_xyz_path) + dataset_xyz_path = os.path.join( + self.raw_dir, "training-set", "dataset_1593.xyz" + ) + energies, forces, positions, atomic_numbers, box_vectors = create_numpy_arrays( + dataset_xyz_path + ) data_list = [] for i in range(len(energies)): diff --git a/torchmdnet/models/__init__.py b/torchmdnet/models/__init__.py index 41e00338b..e0139eaee 100644 --- a/torchmdnet/models/__init__.py +++ b/torchmdnet/models/__init__.py @@ -2,4 +2,9 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -__all_models__ = ["graph-network", "transformer", "equivariant-transformer", "tensornet"] +__all_models__ = [ + "graph-network", + "transformer", + "equivariant-transformer", + "tensornet", +] diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index bf408aa36..1aa4f2b7e 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -66,7 +66,7 @@ def __init__( allow_prior_model=True, reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(Scalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op @@ -96,7 +96,7 @@ def __init__( allow_prior_model=True, reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(EquivariantScalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op @@ -138,7 +138,7 @@ def __init__( activation="silu", reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(DipoleMoment, self).__init__( hidden_channels, @@ -146,7 +146,7 @@ def __init__( allow_prior_model=False, reduce_op=reduce_op, dtype=dtype, - **kwargs + **kwargs, ) atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) @@ -171,7 +171,7 @@ def __init__( activation="silu", reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(EquivariantDipoleMoment, self).__init__( hidden_channels, @@ -179,7 +179,7 @@ def __init__( allow_prior_model=False, reduce_op=reduce_op, dtype=dtype, - **kwargs + **kwargs, ) atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) @@ -205,7 +205,7 @@ def __init__( activation="silu", reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(ElectronicSpatialExtent, self).__init__( allow_prior_model=False, reduce_op=reduce_op @@ -248,7 +248,7 @@ def __init__( activation="silu", reduce_op="sum", dtype=torch.float, - **kwargs + **kwargs, ): super(EquivariantVectorOutput, self).__init__( hidden_channels, @@ -256,7 +256,7 @@ def __init__( allow_prior_model=False, reduce_op="sum", dtype=dtype, - **kwargs + **kwargs, ) def pre_reduce(self, x, v, z, pos, batch): diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 5ff168d54..9ab7748d6 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -15,6 +15,7 @@ ) from torchmdnet.utils import deprecated_class + class TorchMD_ET(nn.Module): r"""Equivariant Transformer's architecture. From Equivariant Transformers for Neural Network based Molecular Potentials; P. Tholke and G. de Fabritiis. diff --git a/torchmdnet/models/wrappers.py b/torchmdnet/models/wrappers.py index 444805e06..f9637d0f8 100644 --- a/torchmdnet/models/wrappers.py +++ b/torchmdnet/models/wrappers.py @@ -34,6 +34,7 @@ class AtomFilter(BaseWrapper): """ Remove atoms with Z > remove_threshold from the model's output. """ + def __init__(self, model, remove_threshold): super(AtomFilter, self).__init__(model) self.remove_threshold = remove_threshold diff --git a/torchmdnet/priors/base.py b/torchmdnet/priors/base.py index be7593826..01e028f2f 100644 --- a/torchmdnet/priors/base.py +++ b/torchmdnet/priors/base.py @@ -38,7 +38,15 @@ def pre_reduce(self, x, z, pos, batch, extra_args: Optional[Dict[str, Tensor]]): """ return x - def post_reduce(self, y, z, pos, batch, box: Optional[Tensor], extra_args: Optional[Dict[str, Tensor]]): + def post_reduce( + self, + y, + z, + pos, + batch, + box: Optional[Tensor], + extra_args: Optional[Dict[str, Tensor]], + ): r"""Post-reduce method of the prior model. Args: diff --git a/torchmdnet/priors/coulomb.py b/torchmdnet/priors/coulomb.py index 449e2c530..934546f52 100644 --- a/torchmdnet/priors/coulomb.py +++ b/torchmdnet/priors/coulomb.py @@ -7,6 +7,7 @@ from torchmdnet.models.utils import OptimizedDistance, scatter from typing import Optional, Dict + class Coulomb(BasePrior): """This class implements a Coulomb potential, scaled by a cosine switching function to reduce its effect at short distances. @@ -33,32 +34,55 @@ class Coulomb(BasePrior): The Dataset used with this class must include a `partial_charges` field for each sample, and provide `distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments. """ - def __init__(self, lower_switch_distance, upper_switch_distance, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None): + + def __init__( + self, + lower_switch_distance, + upper_switch_distance, + max_num_neighbors, + distance_scale=None, + energy_scale=None, + box_vecs=None, + dataset=None, + ): super(Coulomb, self).__init__() if distance_scale is None: distance_scale = dataset.distance_scale if energy_scale is None: energy_scale = dataset.energy_scale - self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors) + self.distance = OptimizedDistance( + 0, torch.inf, max_num_pairs=-max_num_neighbors + ) self.lower_switch_distance = lower_switch_distance self.upper_switch_distance = upper_switch_distance self.max_num_neighbors = max_num_neighbors self.distance_scale = float(distance_scale) self.energy_scale = float(energy_scale) self.initial_box = box_vecs + def get_init_args(self): - return {'lower_switch_distance': self.lower_switch_distance, - 'upper_switch_distance': self.upper_switch_distance, - 'max_num_neighbors': self.max_num_neighbors, - 'distance_scale': self.distance_scale, - 'energy_scale': self.energy_scale, - 'initial_box': self.initial_box} + return { + "lower_switch_distance": self.lower_switch_distance, + "upper_switch_distance": self.upper_switch_distance, + "max_num_neighbors": self.max_num_neighbors, + "distance_scale": self.distance_scale, + "energy_scale": self.energy_scale, + "initial_box": self.initial_box, + } def reset_parameters(self): pass - def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extra_args: Optional[Dict[str, torch.Tensor]] = None): - """ Compute the Coulomb energy for each sample in a batch. + def post_reduce( + self, + y, + z, + pos, + batch, + box: Optional[torch.Tensor] = None, + extra_args: Optional[Dict[str, torch.Tensor]] = None, + ): + """Compute the Coulomb energy for each sample in a batch. Parameters ---------- @@ -81,17 +105,21 @@ def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extr Tensor of shape (batch_size, 1) containing the energies of each sample in the batch. """ # Convert to nm and calculate distance. - x = 1e9*self.distance_scale*pos + x = 1e9 * self.distance_scale * pos box = box if box is not None else self.initial_box edge_index, distance, _ = self.distance(x, batch, box=box) # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair # appears twice. - q = extra_args['partial_charges'][edge_index] + q = extra_args["partial_charges"][edge_index] lower = torch.tensor(self.lower_switch_distance) upper = torch.tensor(self.upper_switch_distance) - phase = (torch.max(lower, torch.min(upper, distance))-lower)/(upper-lower) - energy = (0.5-0.5*torch.cos(torch.pi*phase))*q[0]*q[1]/distance - energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum") + phase = (torch.max(lower, torch.min(upper, distance)) - lower) / (upper - lower) + energy = (0.5 - 0.5 * torch.cos(torch.pi * phase)) * q[0] * q[1] / distance + energy = ( + 0.5 + * (2.30707e-28 / self.energy_scale / self.distance_scale) + * scatter(energy, batch[edge_index[0]], dim=0, reduce="sum") + ) energy = energy.reshape(y.shape) return y + energy diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py index 6b053cce3..13addd741 100644 --- a/torchmdnet/priors/d2.py +++ b/torchmdnet/priors/d2.py @@ -7,6 +7,7 @@ import torch as pt from typing import Optional, Dict + class D2(BasePrior): """ Dispersive correction term as used in DFT-D2. @@ -104,7 +105,7 @@ class D2(BasePrior): [29.99, 1.881], # 54 Xe ], dtype=pt.float64, - ) #::meta private: + ) #::meta private: C_6_R_r[:, 1] *= 0.1 # Å --> nm def __init__( @@ -159,7 +160,15 @@ def get_init_args(self): "energy_scale": self.energy_scale, } - def post_reduce(self, y, z, pos, batch, box: Optional[pt.Tensor] = None, extra_args: Optional[Dict[str, pt.Tensor]] = None): + def post_reduce( + self, + y, + z, + pos, + batch, + box: Optional[pt.Tensor] = None, + extra_args: Optional[Dict[str, pt.Tensor]] = None, + ): # Convert to interal units: nm and J/mol # NOTE: float32 is overflowed, if m and J are used diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py index de6515c49..4b12f5772 100644 --- a/torchmdnet/priors/zbl.py +++ b/torchmdnet/priors/zbl.py @@ -71,7 +71,15 @@ def get_init_args(self): def reset_parameters(self): pass - def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extra_args: Optional[Dict[str, torch.Tensor]] = None): + def post_reduce( + self, + y, + z, + pos, + batch, + box: Optional[torch.Tensor] = None, + extra_args: Optional[Dict[str, torch.Tensor]] = None, + ): edge_index, distance, _ = self.distance(pos, batch, box) if edge_index.shape[1] == 0: return y diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 1f46273dd..b8c2941ad 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -225,7 +225,7 @@ def main(): args.early_stopping_monitor, patience=args.early_stopping_patience ) callbacks.append(early_stopping) - + check_logs(args.log_dir) csv_logger = CSVLogger(args.log_dir, name="", version="") _logger = [csv_logger] diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py index 1a1f7cd5a..0a0305f30 100644 --- a/torchmdnet/utils.py +++ b/torchmdnet/utils.py @@ -372,7 +372,9 @@ def write_as_hdf5(files, hdf5_dataset, tile_embed=True): num_samples = coord_data.shape[0] group.create_dataset("pos", data=coord_data) if tile_embed: - group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1))) + group.create_dataset( + "types", data=np.tile(embed_data, (num_samples, 1)) + ) else: group.create_dataset("types", data=embed_data) if "y" in files: @@ -402,12 +404,14 @@ def wrapped_init(self, *args, **kwargs): cls.__init__ = wrapped_init return cls + def check_logs(log_dir): - import os + import os import time - metr_file_path = os.path.join(log_dir, 'metrics.csv') + + metr_file_path = os.path.join(log_dir, "metrics.csv") if os.path.exists(metr_file_path): # we make a backup of the metrics file (rename) bckp_date = f'{time.strftime("%Y%m%d")}-{time.strftime("%H%M%S")}' os.rename(metr_file_path, metr_file_path.replace(".csv", f"_{bckp_date}.csv")) - return \ No newline at end of file + return