Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PBC support #201

Merged
merged 47 commits into from
Jan 17, 2024
Merged

Add PBC support #201

merged 47 commits into from
Jan 17, 2024

Conversation

RaulPPelaez
Copy link
Collaborator

@RaulPPelaez RaulPPelaez commented Jul 10, 2023

I introduced a new parameter, box_vecs that is interpreted as a 3x3 tensor defining the periodic box.

The parameter is a string in the format: [[a,b,c],[d,e,f],[g,h,i]]
The string is then transformed into a tensor and passed along to the models.
I also added box as an argument to the forward function of TorchMD_Net.
This means that there are two ways to specify the box:

  • By setting it in the yaml file
  • By passing it to forward
    Forward one takes precedence.

The box can be set globally (all samples in a batch have the same box) by passing a 3x3 tensor to the model (or using the yaml file) or per-sample by passing a tensor with shape (max(batch)+1, 3,3). Note that the box in the yaml can only be used gobally.

Additionally, the LNNP module can now take a per-sample parameter, "box", from the Dataloader and send it through to the model. Box exists in the same conceptual level as "q" or "s" are now.

I added a new dataset, WaterBox, as an example of a dataset that requires a per-sample box.

@RaulPPelaez RaulPPelaez requested a review from raimis August 8, 2023 09:15
@RaulPPelaez RaulPPelaez changed the title Add PBC support to TensorNet Add PBC support Oct 3, 2023
@FranklinHu1
Copy link

@RaulPPelaez I'm very sorry but I made a mistake in my previous testing of the variable PBCs and was only using the CPU platform, not the CUDA platform (as I incorrectly stated in issue #221 ). I went ahead and updated my box_input branch version of the code and tried to run using the CUDA platform this time on an A100 GPU, and get a different error message that points to the same problem:

########## READING CLI ARGUMENTS ##########
Setting up save directory for dynamics outputs
########## SETTING UP SIMULATION SYSTEM ##########
Initializing system based on atoms listed in pdb file
Checking system has no forces or constraints
0
Model outputs forces: True
Torchscript module copied
Model uses passed box vectors: True
Traceback (most recent call last):
  File "/pscratch/sd/f/frankhu/openmm_INXS/run_INXS_dynamics.py", line 117, in <module>
    simulation.context.setVelocitiesToTemperature(300 * kelvin)
  File "/global/homes/f/frankhu/.conda/envs/torchmd-net/lib/python3.10/site-packages/openmm/openmm.py", line 3512, in setVelocitiesToTemperature
    return _openmm.Context_setVelocitiesToTemperature(self, *args)
openmm.OpenMMException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 16, in forward
    model = self.model
    z = self.z
    _1 = (model).forward(z, positions1, None, torch.mul(box_vectors, 10), None, None, None, )
          ~~~~~~~~~~~~~~ <--- HERE
    energy, force, = _1
    if torch.__isnot__(force, None):
  File "code/__torch__/torchmdnet/models/model.py", line 40, in forward
      pass
    representation_model = self.representation_model
    _4 = (representation_model).forward(z, pos, batch0, box, q, s, )
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    x, v, z0, pos0, batch1, = _4
    output_model = self.output_model
  File "code/__torch__/torchmdnet/models/tensornet.py", line 34, in forward
    _3 = uninitialized(Tensor)
    distance = self.distance
    _4 = (distance).forward(pos, batch, box, )
          ~~~~~~~~~~~~~~~~~ <--- HERE
    edge_index, edge_weight, edge_vec, = _4
    if torch.__isnot__(edge_vec, None):
  File "code/__torch__/torchmdnet/models/utils.py", line 113, in forward
    include_transpose = self.include_transpose
    use_periodic = self.use_periodic
    _20 = _16(strategy, pos, batch0, box1, use_periodic, cutoff_lower, cutoff_upper, max_pairs, loop, include_transpose, )
          ~~~ <--- HERE
    edge_index, edge_vec, edge_weight, num_pairs, = _20
    check_errors = self.check_errors
  File "code/__torch__/torchmdnet/extensions.py", line 11, in get_neighbor_pairs_kernel
    loop: bool,
    include_transpose: bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  _0, _1, _2, _3 = ops.torchmdnet_extensions.get_neighbor_pairs(strategy, positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, max_num_pairs, loop, include_transpose)
                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  return (_0, _1, _2, _3)
def is_current_stream_capturing() -> bool:

Traceback of TorchScript, original code (most recent call last):
  File "/pscratch/sd/f/frankhu/openmm_INXS/torch_force_generator_multi.py", line 81, in forward
        positions = positions * 10 #nm -> A
        #Multiply box vectors by 10 to ensure unit consistency
        energy, force = self.model.forward(self.z,
                        ~~~~~~~~~~~~~~~~~~ <--- HERE
                                           positions,
                                           box = box_vectors * 10)
  File "/global/u2/f/frankhu/torchmd-net/torchmdnet/models/model.py", line 329, in forward
            pos.requires_grad_(True)
        # run the potentially wrapped representation model
        x, v, z, pos, batch = self.representation_model(z, pos, batch, box=box, q=q, s=s)
                              ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        # apply the output network
        x = self.output_model.pre_reduce(x, v, z, pos, batch)
  File "/global/u2/f/frankhu/torchmd-net/torchmdnet/models/tensornet.py", line 225, in forward
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        # Obtain graph, with distances and relative position vectors
        edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
                                            ~~~~~~~~~~~~~ <--- HERE
        # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
        assert (
  File "/global/u2/f/frankhu/torchmd-net/torchmdnet/models/utils.py", line 255, in forward
        if batch is None:
            batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        edge_index, edge_vec, edge_weight, num_pairs = get_neighbor_pairs_kernel(
                                                       ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            strategy=self.strategy,
            positions=pos,
  File "/global/u2/f/frankhu/torchmd-net/torchmdnet/extensions/__init__.py", line 97, in get_neighbor_pairs_kernel
    This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
    """
    return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        strategy,
        positions,
RuntimeError: Expected "box_vectors" to be on CPU

My file for running dynamics is as follows:

#import openmmtools
import torch
#Specific imports for ET
from openmmtorch import TorchForce
import torchmdnet.neighbors
import torchmdnet.extensions
import torch_cluster, torch_geometric
#Simulation-specific imports
import sys
from openmm import LangevinIntegrator, AndersenThermostat, VerletIntegrator, LangevinMiddleIntegrator, System, Vec3
from openmm.app import Simulation, PDBReporter, StateDataReporter, CutoffPeriodic, CheckpointReporter, DCDReporter 
from openmm.unit import kelvin, picosecond, femtosecond, nanometers
import openmm
import mdtraj as md
from reporters import VelocityReporter, ForceReporter
#Miscellaneous imports
import argparse, shutil, os

#Note: Because the entire system is being modeled using ML, we can forego the use of any xml files embedded directly into 
#the subdirectories of openMM. Instead, initialize the system from the pdb files by setting atomic positions.

parser = argparse.ArgumentParser(description = "method for modifying file names at the top level")
parser.add_argument('current_step', type = int, help = 'index for the current stage of dynamics, zero-indexed')
parser.add_argument('model_path_dir', type = str, help = 'directory where the force model is stored')
parser.add_argument('save_path', type = str, help = 'directory name to save dynamics results to')
parser.add_argument('system_topology', type = str, help = 'pdb file that contains the topology for the system of interest')
parser.add_argument('topology_interpretation', type = str, help = "whether to use 'openmm' or 'mdtraj' for interpreting the pdb topology file in cases where the pdb is not formatted correctly for openmm")
parser.add_argument('pass_PBCs', type = str, help = 'determines if vectors specifying the periodic boundary conditions are passed to the torchforce forward pass')
#Example on how to restart found at https://github.com/openmm/openmm/issues/3837
if __name__ == "__main__":
    print("########## READING CLI ARGUMENTS ##########")
    args = parser.parse_args()
    curr_step = args.current_step
    mod_path_dir = args.model_path_dir
    pdb_path = args.system_topology
    save_path = f'dynamics_results/{args.save_path}'
    if not os.path.isdir(save_path):
        print("Setting up save directory for dynamics outputs")
        os.mkdir(save_path)
    
    print("########## SETTING UP SIMULATION SYSTEM ##########")
    #If curr_step == 0, it's the first trajectory. If it's greater than 
    #   0, then it searches for a checkpoint file to load in. Also affects
    #   some initialization steps.
    print("Initializing system based on atoms listed in pdb file")
    system = System()
    
    if args.topology_interpretation == 'openmm':
        pdb = openmm.app.PDBFile(pdb_path)
        for atom in pdb.topology.atoms():
            system.addParticle(atom.element.mass)
        system.setDefaultPeriodicBoxVectors(*pdb.topology.getPeriodicBoxVectors())
    elif args.topology_interpretation == 'mdtraj':
        #The use of mdtraj is necessary to circumvent some of the issues 
        #    with pdb files not being properly formatted for openmm (e.g., when run through i-pi)
        pdb = md.load(pdb_path)
        openmm_topo = pdb.topology.to_openmm()
        for atom in openmm_topo.atoms():
            system.addParticle(atom.element.mass)
        #Convert the periodic box vectors correctly
        unitcell_vectors = pdb.unitcell_vectors[0]
        v1 = Vec3(*unitcell_vectors[0])
        v2 = Vec3(*unitcell_vectors[1])
        v3 = Vec3(*unitcell_vectors[2])
        q_tot = openmm.unit.quantity.Quantity((v1, v2, v3), nanometers)
        system.setDefaultPeriodicBoxVectors(*q_tot)
        openmm_topo.setPeriodicBoxVectors(q_tot)
    
    print("Checking system has no forces or constraints")
    assert(system.getNumForces() == 0)
    print(system.getNumConstraints())
    assert(system.getNumConstraints() == 0)
    
    force = TorchForce(f'models/{mod_path_dir}/generated_mod.pt')
    force.setOutputsForces(True) #Return force predictions directly
    if args.pass_PBCs.lower() == 'true':
        #Toggle if box vectors are passed
        force.setUsesPeriodicBoundaryConditions(True)
    print("Model outputs forces:", force.getOutputsForces())
    system.addForce(force)
    assert(system.getNumForces() == 1)

    #Copy the model torchscript module as a sanity check
    shutil.copy(f'models/{mod_path_dir}/generated_mod.pt', 
                f'{save_path}/generated_mod.pt')
    print("Torchscript module copied")
    
    #Copy this script as well to keep track of other simulation parameters
    shutil.copy('run_INXS_dynamics.py', f'{save_path}/run_INXS_dynamics.py')
 
    temperature = 300 * kelvin
    frictionCoeff = 1 / picosecond
    timeStep = 0.5 * femtosecond
    integrator = LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)
    platform = openmm.Platform.getPlatformByName("CUDA")
    
    if args.topology_interpretation == 'openmm':
        simulation = Simulation(pdb.topology, system, integrator, platform)
    elif args.topology_interpretation == 'mdtraj':
        simulation = Simulation(openmm_topo, system, integrator, platform)
        
    print("Model uses passed box vectors:", simulation.system.usesPeriodicBoundaryConditions())
    
    if curr_step == 0:
        
        if args.topology_interpretation == 'openmm':
            simulation.context.setPositions(pdb.positions)
        elif args.topology_interpretation == 'mdtraj':
            #Convert the positions correctly from the mdtraj pdb file
            all_pos = pdb.xyz[0]
            all_vecs = [Vec3(*elem) for elem in all_pos]
            pos_quant = openmm.unit.quantity.Quantity(all_vecs, nanometers)
            simulation.context.setPositions(pos_quant)
        
        # print("Minimizing...")
        # simulation.minimizeEnergy()
        simulation.context.setVelocitiesToTemperature(300 * kelvin)
        #print("Equilibrating...")
        #Equilibrate system for 3000 steps
        #simulation.step(3000)
    elif curr_step > 0:
        #Load in the previous step checkpoint
        print(f"Loading in state_{curr_step - 1}.chk")
        simulation.loadCheckpoint(f'{save_path}/state_{curr_step - 1}.chk')
     
    #DCD file + reporters for appending
    simulation.reporters.append(DCDReporter(f'{save_path}/trajectory.dcd', 4, append = True if curr_step > 0 else False))
    simulation.reporters.append(StateDataReporter(f'{save_path}/properties_{curr_step}.out', 4, step=True,
        time=True, potentialEnergy=True, temperature=True, volume=True, density=True,
        progress=True, remainingTime=True, speed=True, totalSteps=2_000_000,
        separator='\t'))
    simulation.reporters.append(CheckpointReporter(f'{save_path}/state_{curr_step}.chk', 4))
    #Add a reporter so we can get velocity information out
    simulation.reporters.append(VelocityReporter(f'{save_path}/velocities_{curr_step}.xyz', 4))
    simulation.reporters.append(ForceReporter(f'{save_path}/forces_{curr_step}.xyz', 4))
    
    print("########## RUNNING PRODUCTION ##########")
    simulation.step(2_000_000)
    print("Done!")

Sorry again about not catching this earlier. The good news is that the testing I previously did shows that dynamics run correctly, but right now only on the CPU platform.

@RaulPPelaez
Copy link
Collaborator Author

RaulPPelaez commented Dec 18, 2023

I fixed that in this commit 06b7b72 . I think you just need to reinstall with the latest version of this PR.
Sometimes setup.py might incorrectly decide not to recompile the CUDA parts, so remove the build directory in the root of the repo before pip installing again.

@RaulPPelaez
Copy link
Collaborator Author

I managed to allow a per-sample box. This enables training on datasets that provide the box per sample. I added the water box dataset suggested by @sef43 as an example.
@sef43 could you review?

@guillemsimeon
Copy link
Collaborator

Wow cool! Thanks! I would love to test TensorNet on this water box dataset.

@RaulPPelaez
Copy link
Collaborator Author

Its just 1.5K conformations AFAIK, so I would not expect any miracles from it hehe.

@guillemsimeon
Copy link
Collaborator

on MD17 we train on 950 frames, since it is just one system. But we never trained using PBC and I am excited about it! It opens a bunch of possibilities

@sef43
Copy link
Collaborator

sef43 commented Dec 18, 2023

i don’t see the water example code in this pr?

If i remember correctly that dataset is sufficient to train a water model that has stable NVT dynamics with similar MLPs, correct NPT dynamics is another story though…

@RaulPPelaez
Copy link
Collaborator Author

Oops! sorry, git push failed and I did not noticed. Commits are there now

@raimis raimis removed their request for review January 16, 2024 14:38
@sef43
Copy link
Collaborator

sef43 commented Jan 16, 2024

I can not get this to work properly. When I load/create a model and give it a box size it does not seem to change the output of the forward pass. Here is an example test script and the output. I would expect the output to be different for these box sizes. Am I calling the models in the wrong way?

script:

import pytest
from pytest import mark
import pickle
from os.path import exists, dirname, join
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping
from utils import load_example_args, create_example_batch





def test_forward_box(model_name, use_batch=False, explicit_q_s=False, precision=32):
    z, pos, batch = create_example_batch()
    pos = pos.to(dtype=dtype_mapping[precision])
    model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
    batch = batch if use_batch else None
    
    boxes = [torch.eye(3)*L for L in [0.0001,0.1,1,10,1000]]
    ys=[]
    #print(boxes)
    for box in boxes:
        if explicit_q_s:
            y,_ = model(z, pos, batch=batch, box=box, q=None, s=None)
        else:
            y,_ = model(z, pos, batch=batch, box=box)

        print("y = ", y.item())
    



for model_name in models.__all_models__:
    test_forward_box(model_name)

output:

(torchmd-net) steve@metro06:~/torchmd-net/tests$ python mytest.py 
[W LinearAlgebra.cpp:2928] Warning: at::frobenius_norm is deprecated and it is just left for JIT compatibility. It will be removed in a future PyTorch release. Please use `linalg.vector_norm(A, 2., dim, keepdim)` instead (function operator())
y =  3.3409483432769775
y =  3.3409483432769775
y =  3.3409483432769775
y =  3.3409483432769775
y =  3.3409483432769775
/home/steve/miniconda3/envs/torchmd-net/lib/python3.11/site-packages/torchmdnet/models/model.py:65: DeprecationWarning: TorchMD_T is deprecated and will be removed in a future version.
  representation_model = TorchMD_T(
y =  -5.708318710327148
y =  -5.708318710327148
y =  -5.708318710327148
y =  -5.708318710327148
y =  -5.708318710327148
y =  0.4629291296005249
y =  0.4629291296005249
y =  0.4629291296005249
y =  0.4629291296005249
y =  0.4629291296005249
y =  0.1535973697900772
y =  0.1535973697900772
y =  0.1535973697900772
y =  0.1535973697900772
y =  0.1535973697900772


@RaulPPelaez
Copy link
Collaborator Author

Thanks a bunch Steve!
Your test should have failed because you are not allowed to have a box with size <2*cutoff.
This made me uncover the bug, the neighbor list was ignoring the box when the module was constructed without one but then a box was being passed to forward.

import pytest
import torch
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.utils import dtype_mapping
from utils import load_example_args, create_example_batch

@pytest.mark.parametrize("model_name", models.__all_models__)
def test_forward_box(model_name, use_batch=False, precision=32):
    z, pos, batch = create_example_batch(n_atoms=100)
    pos = pos.to(dtype=dtype_mapping[precision])*100
    args = load_example_args(model_name, prior_model=None, precision=precision)
    model = create_model(args)
    batch = batch if use_batch else None

    boxes = [torch.eye(3)*L*args["cutoff_upper"] for L in [2, 4,8,16]]
    for box in boxes:
        y,_ = model(z, pos, batch=batch, box=box)
        print(f"Box: {box[0][0].item()}\tEnergy: {y.item():.7f}")

Prints:

$ pytest -v -s test_box.py 
============================================================================ test session starts ============================================================================
platform linux -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0 -- /home/raul/miniforge3/envs/torchmdnet/bin/python3.11
cachedir: .pytest_cache
rootdir: /home/raul/work/bcn/torchmd-net
plugins: typeguard-2.13.3, anyio-3.7.1, cov-4.1.0
collected 4 items                                                                                                                                                           

test_box.py::test_forward_box[graph-network] [W LinearAlgebra.cpp:2785] Warning: at::frobenius_norm is deprecated and it is just left for JIT compatibility. It will be removed in a future PyTorch release. Please use `linalg.vector_norm(A, 2., dim, keepdim)` instead (function operator())
Box: 10.0	Energy: 77.3235855
Box: 20.0	Energy: 59.7313576
Box: 40.0	Energy: 58.5310135
Box: 80.0	Energy: 58.3661118
PASSED
test_box.py::test_forward_box[transformer] Box: 10.0	Energy: 28.0527191
Box: 20.0	Energy: 27.3246365
Box: 40.0	Energy: 26.7726440
Box: 80.0	Energy: 26.4065723
PASSED
test_box.py::test_forward_box[equivariant-transformer] Box: 10.0	Energy: -2.5009341
Box: 20.0	Energy: -5.6181102
Box: 40.0	Energy: -6.9233050
Box: 80.0	Energy: -7.0186620
PASSED
test_box.py::test_forward_box[tensornet] Box: 10.0	Energy: 2.7247550
Box: 20.0	Energy: 0.3784023
Box: 40.0	Energy: 0.0480490
Box: 80.0	Energy: 0.1820417
PASSED

@sef43
Copy link
Collaborator

sef43 commented Jan 17, 2024

Great! ready to merge from my point of view

@RaulPPelaez
Copy link
Collaborator Author

Please approve the review.

@sef43
Copy link
Collaborator

sef43 commented Jan 17, 2024

I dont think I have the correct permission to properly approve

Copy link
Collaborator

@guillemsimeon guillemsimeon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fantastique

@guillemsimeon guillemsimeon merged commit ad1adfc into torchmd:main Jan 17, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants