Skip to content

Commit

Permalink
#61 Removed need for file saving in ANIPotentialImpl (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
JMorado authored Mar 1, 2024
1 parent 305597f commit 90984d5
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions openmmml/models/anipotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from openmmml.mlpotential import MLPotential, MLPotentialImpl, MLPotentialImplFactory
import openmm
from typing import Iterable, Optional, Union
from typing import Iterable, Optional

class ANIPotentialImplFactory(MLPotentialImplFactory):
"""This is the factory that creates ANIPotentialImpl objects."""
Expand All @@ -52,12 +52,6 @@ class ANIPotentialImpl(MLPotentialImpl):
can optionally use only a single model by specifying the modelIndex argument to
select which one to use. This leads to a large improvement in speed, at the
cost of a small decrease in accuracy.
TorchForce requires the model to be saved to disk in a separate file. By default
it writes a file called 'animodel.pt' in the current working directory. You can
use the filename argument to specify a different name. For example,
>>> system = potential.createSystem(topology, filename='mymodel.pt')
"""

def __init__(self, name):
Expand All @@ -69,7 +63,6 @@ def addForces(self,
system: openmm.System,
atoms: Optional[Iterable[int]],
forceGroup: int,
filename: str = 'animodel.pt',
implementation: str = 'nnpops',
modelIndex: Optional[int] = None,
**args):
Expand Down Expand Up @@ -139,14 +132,13 @@ def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):
is_periodic = (topology.getPeriodicBoxVectors() is not None) or system.usesPeriodicBoundaryConditions()
aniForce = ANIForce(model, species, atoms, is_periodic)

# Convert it to TorchScript and save it.
# Convert it to TorchScript.

module = torch.jit.script(aniForce)
module.save(filename)

# Create the TorchForce and add it to the System.

force = openmmtorch.TorchForce(filename)
force = openmmtorch.TorchForce(module)
force.setForceGroup(forceGroup)
force.setUsesPeriodicBoundaryConditions(is_periodic)
system.addForce(force)
Expand Down

0 comments on commit 90984d5

Please sign in to comment.