Skip to content

Commit

Permalink
blacken code
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Feb 12, 2025
1 parent 2be8620 commit e24eaa7
Show file tree
Hide file tree
Showing 24 changed files with 369 additions and 203 deletions.
1 change: 1 addition & 0 deletions benchmarks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def benchmark_pdb(pdb_file, **kwargs):
"2L emb 64": {"num_layers": 2, "embedding_dimension": 64},
}


def benchmark_all():
timings = {}
for pdb_file in os.listdir("systems"):
Expand Down
1 change: 1 addition & 0 deletions benchmarks/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional
from torch_cluster import radius_graph


class Distance(nn.Module):
def __init__(
self,
Expand Down
16 changes: 10 additions & 6 deletions examples/openmm-integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import openmm
import openmmtorch
except ImportError:
raise ImportError("Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)")
raise ImportError(
"Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)"
)

import sys
import torch
Expand Down Expand Up @@ -34,9 +36,9 @@ def __init__(self, embeddings, model):
def forward(self, positions):
# OpenMM works with nanometer positions and kilojoule per mole energies
# Depending on the model, you might need to convert the units
positions = positions.to(torch.float32) * 10.0 # nm -> A
positions = positions.to(torch.float32) * 10.0 # nm -> A
energy = self.model(z=self.embeddings, pos=positions)[0]
return energy * 96.4916 # eV -> kJ/mol
return energy * 96.4916 # eV -> kJ/mol


pdb = PDBFile("../benchmarks/systems/chignolin.pdb")
Expand All @@ -54,9 +56,11 @@ def forward(self, positions):
for atom in pdb.topology.atoms():
system.addParticle(atom.element.mass)
system.addForce(torch_force)
integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 2*femtosecond)
platform = Platform.getPlatformByName('CPU')
integrator = LangevinMiddleIntegrator(298.15 * kelvin, 1 / picosecond, 2 * femtosecond)
platform = Platform.getPlatformByName("CPU")
simulation = Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
simulation.reporters.append(StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True))
simulation.reporters.append(
StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True)
)
simulation.step(10)
8 changes: 6 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir

for sample_cached, sample in zip(dl_cached, dl):
assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords"
assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers"
assert np.allclose(
sample_cached.z, sample.z
), "Sample has incorrect atom numbers"
assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy"
assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces"
assert np.allclose(
sample_cached.neg_dy, sample.neg_dy
), "Sample has incorrect forces"
4 changes: 1 addition & 3 deletions tests/test_mdcath.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,7 @@ def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
data.flush()
data.close()

dataset = MDCATH(
root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list
)
dataset = MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list)
dl = DataLoader(
dataset,
batch_size=batch_size,
Expand Down
87 changes: 65 additions & 22 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,74 +38,115 @@ def test_atomref(model_name, enable_atomref):

# check if the output of both models differs by the expected atomref contribution
if enable_atomref:
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(
1
)
else:
expected_offset = 0
torch.testing.assert_close(x_atomref, x_no_atomref + expected_offset)


@mark.parametrize("trainable", [True, False])
def test_atomref_trainable(trainable):
dataset = DummyDataset(has_atomref=True)
atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable)
assert atomref.atomref.weight.requires_grad == trainable


def test_learnableatomref():
atomref = LearnableAtomref(max_z=100)
assert atomref.atomref.weight.requires_grad == True


def test_zbl():
pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr
pos = torch.tensor(
[[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
dtype=torch.float32,
) # Atom positions in Bohr
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8) # Mapping of atom types to atomic numbers
atomic_number = torch.tensor(
[1, 6, 8], dtype=torch.int8
) # Mapping of atom types to atomic numbers
distance_scale = 5.29177210903e-11 # Convert Bohr to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules

# Use the ZBL class to compute the energy.

zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale)
energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {})[0]
zbl = ZBL(
10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale
)
energy = zbl.post_reduce(
torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {}
)[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
delta = pos1 - pos2
r = torch.sqrt(torch.dot(delta, delta))
x = r / (0.8854/(z1**0.23 + z2**0.23))
phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x)
cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0)
return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r
x = r / (0.8854 / (z1**0.23 + z2**0.23))
phi = (
0.1818 * torch.exp(-3.2 * x)
+ 0.5099 * torch.exp(-0.9423 * x)
+ 0.2802 * torch.exp(-0.4029 * x)
+ 0.02817 * torch.exp(-0.2016 * x)
)
cutoff = 0.5 * (torch.cos(r * torch.pi / 10.0) + 1.0)
return cutoff * phi * (138.935 / 5.29177210903e-2) * z1 * z2 / r

expected = 0
for i in range(len(pos)):
for j in range(i):
expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
expected += compute_interaction(
pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]
)
torch.testing.assert_close(expected, energy, rtol=1e-4, atol=1e-4)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_coulomb(dtype):
pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm
pos = torch.tensor(
[[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]],
dtype=dtype,
) # Atom positions in nm
charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
distance_scale = 1e-9 # Convert nm to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
lower_switch_distance = 0.9
upper_switch_distance = 1.3

# Use the Coulomb class to compute the energy.

coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale)
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0]
coulomb = Coulomb(
lower_switch_distance,
upper_switch_distance,
5,
distance_scale=distance_scale,
energy_scale=energy_scale,
)
energy = coulomb.post_reduce(
torch.zeros((1,)),
types,
pos,
torch.zeros_like(types),
extra_args={"partial_charges": charge},
)[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
delta = pos1 - pos2
r = torch.sqrt(torch.dot(delta, delta))
if r < lower_switch_distance:
return 0
energy = 138.935*z1*z2/r
energy = 138.935 * z1 * z2 / r
if r < upper_switch_distance:
energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance))
energy *= 0.5 - 0.5 * torch.cos(
torch.pi
* (r - lower_switch_distance)
/ (upper_switch_distance - lower_switch_distance)
)
return energy

expected = 0
Expand All @@ -120,10 +161,12 @@ def test_multiple_priors(dtype):
# Create a model from a config file.

dataset = DummyDataset(has_atomref=True)
config_file = join(dirname(__file__), 'priors.yaml')
args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype)
config_file = join(dirname(__file__), "priors.yaml")
args = load_example_args(
"equivariant-transformer", config_file=config_file, dtype=dtype
)
prior_models = create_prior_models(args, dataset)
args['prior_args'] = [p.get_init_args() for p in prior_models]
args["prior_args"] = [p.get_init_args() for p in prior_models]
model = LNNP(args, prior_model=prior_models)
priors = model.model.prior_model

Expand Down
6 changes: 4 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
if config_file is None:
if model_name == "tensornet":
config_file = join(dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml")
config_file = join(
dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml"
)
else:
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
with open(config_file, "r") as f:
Expand Down Expand Up @@ -84,7 +86,7 @@ def _get_atomref(self):
return self.atomref

DummyDataset.get_atomref = _get_atomref
self.atomic_number = torch.arange(max(atom_types)+1)
self.atomic_number = torch.arange(max(atom_types) + 1)
self.distance_scale = 1.0
self.energy_scale = 1.0

Expand Down
1 change: 1 addition & 0 deletions torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmdnet.models.utils import scatter
import warnings


class DataModule(LightningDataModule):
"""A LightningDataModule for loading datasets from the torchmdnet.datasets module.
Expand Down
10 changes: 5 additions & 5 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def raw_file_names(self):
def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV
Expand Down
23 changes: 12 additions & 11 deletions torchmdnet/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def raw_url(self):
def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV
Expand Down Expand Up @@ -142,6 +142,7 @@ def raw_url_name(self):
def raw_file_names(self):
return ["ani_md_bench.h5"]


class DrugBank(COMP6Base):
"""
DrugBank Benchmark. This benchmark is developed through a subsampling of the
Expand Down Expand Up @@ -247,7 +248,7 @@ def __init__(

self.subsets = [
DS(root, transform, pre_transform, pre_filter)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
]

self.num_samples = sum(len(subset) for subset in self.subsets)
Expand Down Expand Up @@ -347,12 +348,12 @@ def sample_iter(self, mol_ids=False):
def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV
Expand Down
Loading

0 comments on commit e24eaa7

Please sign in to comment.