Skip to content

Commit

Permalink
Remove replica states from results dictionary (#603)
Browse files Browse the repository at this point in the history
* [WIP] Remove replica states from results dictionary

* Make it so that you can access replica states from file

* fix missing import

* fix key & remove tests

* fix FEAnalysis test

* add file not found errors

* plural on replica

* fix error messages being raised

* Fix slow test + add "reading the replica states" test to slow test

* Add longer test for HFEs

* Add new json files
  • Loading branch information
IAlibay authored Nov 3, 2023
1 parent e60580c commit 68a08d2
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 35 deletions.
2 changes: 1 addition & 1 deletion openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def run(self, dry=False, verbose=True,

if not dry:
nc = self.shared_basepath / settings['simulation_settings'].output_filename
chk = self.shared_basepath / settings['simulation_settings'].checkpoint_storage
chk = settings['simulation_settings'].checkpoint_storage
return {
'nc': nc,
'last_checkpoint': chk,
Expand Down
41 changes: 35 additions & 6 deletions openfe/protocols/openmm_afe/equil_solvation_afe_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"""
from __future__ import annotations

import pathlib
import logging

from collections import defaultdict
import gufe
from gufe.components import Component
Expand All @@ -38,6 +38,7 @@
import numpy as np
import numpy.typing as npt
from openff.units import unit
from openmmtools import multistate
from typing import Dict, Optional, Union
from typing import Any, Iterable

Expand Down Expand Up @@ -281,13 +282,41 @@ def get_replica_states(self) -> dict[str, list[npt.NDArray]]:
the thermodynamic cycle, with lists of replica states
timeseries for each repeat of that simulation type.
"""
replica_states: dict[str, list[npt.NDArray]] = {}
replica_states: dict[str, list[npt.NDArray]] = {
'solvent': [], 'vacuum': []
}

def is_file(filename: str):
p = pathlib.Path(filename)

if not p.exists():
errmsg = f"File could not be found {p}"
raise ValueError(errmsg)

return p

def get_replica_state(nc, chk):
nc = is_file(nc)
dir_path = nc.parents[0]
chk = is_file(dir_path / chk).name

reporter = multistate.MultiStateReporter(
storage=nc, checkpoint_storage=chk, open_mode='r'
)

retval = np.asarray(reporter.read_replica_thermodynamic_states())
reporter.close()

return retval

for key in ['solvent', 'vacuum']:
replica_states[key] = [
pus[0].outputs['replica_states']
for pus in self.data[key].values()
]
for pus in self.data[key].values():
states = get_replica_state(
pus[0].outputs['nc'],
pus[0].outputs['last_checkpoint'],
)
replica_states[key].append(states)

return replica_states

def equilibration_iterations(self) -> dict[str, list[float]]:
Expand Down
22 changes: 20 additions & 2 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,26 @@ def get_replica_states(self) -> list[npt.NDArray]:
replica_states : List[npt.NDArray]
List of replica states for each repeat
"""
replica_states = [pus[0].outputs['replica_states']
for pus in self.data.values()]
def is_file(filename: str):
p = pathlib.Path(filename)
if not p.exists():
errmsg = f"File could not be found {p}"
raise ValueError(errmsg)
return p

replica_states = []

for pus in self.data.values():
nc = is_file(pus[0].outputs['nc'])
dir_path = nc.parents[0]
chk = is_file(dir_path / pus[0].outputs['last_checkpoint']).name
reporter = multistate.MultiStateReporter(
storage=nc, checkpoint_storage=chk, open_mode='r'
)
replica_states.append(
np.asarray(reporter.read_replica_thermodynamic_states())
)
reporter.close()

return replica_states

Expand Down
2 changes: 1 addition & 1 deletion openfe/protocols/openmm_utils/multistate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def unit_results_dict(self):
'forward_and_reverse_energies': self.forward_and_reverse_free_energies,
'production_iterations': self.production_iterations,
'equilibration_iterations': self.equilibration_iterations,
'replica_states': self.replica_states}
}

if hasattr(self, '_exchange_matrix'):
results_dict['replica_exchange_statistics'] = self.replica_exchange_statistics
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def rfe_transformation_json() -> str:
"""string of a RFE result of quickrun"""
d = resources.files('openfe.tests.data.openmm_rfe')

with gzip.open((d / 'Transformation-e1702a3efc0fa735d5c14fc7572b5278_results.json.gz').as_posix(), 'r') as f: # type: ignore
with gzip.open((d / 'RFE-ProtocolUnitResult-0f3457edf947483aa03d0f4fe88bf566.json.gz').as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore


Expand Down
106 changes: 106 additions & 0 deletions openfe/tests/protocols/test_openmm_afe_slow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe

from gufe.protocols import execute_DAG
import pytest
from openff.units import unit
from openmm import Platform
import os
import pathlib

import openfe
from openfe.protocols import openmm_afe


@pytest.fixture
def available_platforms() -> set[str]:
return {Platform.getPlatform(i).getName() for i in range(Platform.getNumPlatforms())}


@pytest.fixture
def set_openmm_threads_1():
# for vacuum sims, we want to limit threads to one
# this fixture sets OPENMM_CPU_THREADS='1' for a single test, then reverts to previously held value
previous: str | None = os.environ.get('OPENMM_CPU_THREADS')

try:
os.environ['OPENMM_CPU_THREADS'] = '1'
yield
finally:
if previous is None:
del os.environ['OPENMM_CPU_THREADS']
else:
os.environ['OPENMM_CPU_THREADS'] = previous


@pytest.mark.integration # takes too long to be a slow test ~ 4 mins locally
@pytest.mark.flaky(reruns=3) # pytest-rerunfailures; we can get bad minimisation
@pytest.mark.parametrize('platform', ['CPU', 'CUDA'])
def test_openmm_run_engine(platform,
available_platforms,
benzene_modifications,
set_openmm_threads_1, tmpdir):
if platform not in available_platforms:
pytest.skip(f"OpenMM Platform: {platform} not available")

# Run a really short calculation to check everything is going well
s = openmm_afe.AbsoluteSolvationProtocol.default_settings()
s.alchemsampler_settings.n_repeats = 1
s.solvent_simulation_settings.output_indices = "resname UNK"
s.vacuum_simulation_settings.equilibration_length = 0.1 * unit.picosecond
s.vacuum_simulation_settings.production_length = 0.1 * unit.picosecond
s.solvent_simulation_settings.equilibration_length = 0.1 * unit.picosecond
s.solvent_simulation_settings.production_length = 0.1 * unit.picosecond
s.vacuum_engine_settings.compute_platform = platform
s.solvent_engine_settings.compute_platform = platform
s.integrator_settings.n_steps = 5 * unit.timestep
s.vacuum_simulation_settings.checkpoint_interval = 5 * unit.timestep
s.solvent_simulation_settings.checkpoint_interval = 5 * unit.timestep
s.alchemsampler_settings.n_replicas = 14
s.alchemical_settings.lambda_elec_windows = 5
s.alchemical_settings.lambda_vdw_windows = 9

protocol = openmm_afe.AbsoluteSolvationProtocol(
settings=s,
)

stateA = openfe.ChemicalSystem({
'benzene': benzene_modifications['benzene'],
'solvent': openfe.SolventComponent()
})

stateB = openfe.ChemicalSystem({
'solvent': openfe.SolventComponent(),
})

# Create DAG from protocol, get the vacuum and solvent units
# and eventually dry run the first solvent unit
dag = protocol.create(
stateA=stateA,
stateB=stateB,
mapping=None,
)


cwd = pathlib.Path(str(tmpdir))
r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd,
keep_shared=True)

assert r.ok()
for pur in r.protocol_unit_results:
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
checkpoint = pur.outputs['last_checkpoint']
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
assert (unit_shared / checkpoint).exists()
nc = pur.outputs['nc']
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()

# Test results methods that need files present
results = protocol.gather([r])
states = results.get_replica_states()
assert len(states.items()) == 2
assert len(states['solvent']) == 1
assert states['solvent'][0].shape[1] == 14
20 changes: 9 additions & 11 deletions openfe/tests/protocols/test_openmm_afe_solvation_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,15 @@ def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-3.00208997)
assert est.m == pytest.approx(-2.977553138764437)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.1577349)
assert est.m == pytest.approx(0.19617297299036018)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

Expand Down Expand Up @@ -664,15 +664,6 @@ def test_get_replica_transition_statistics(self, key, protocolresult):
assert rpx1['eigenvalues'].shape == (15,)
assert rpx1['matrix'].shape == (15, 15)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
def test_get_replica_states(self, key, protocolresult):
rep = protocolresult.get_replica_states()

assert isinstance(rep, dict)
assert isinstance(rep[key], list)
assert len(rep[key]) == 3
assert rep[key][0].shape == (251, 15)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
def test_equilibration_iterations(self, key, protocolresult):
eq = protocolresult.equilibration_iterations()
Expand All @@ -690,3 +681,10 @@ def test_production_iterations(self, key, protocolresult):
assert isinstance(prod[key], list)
assert len(prod[key]) == 3
assert all(isinstance(v, float) for v in prod[key])

def test_filenotfound_replica_states(self, protocolresult):
errmsg = "File could not be found"

with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()

16 changes: 7 additions & 9 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,15 +1289,15 @@ def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-15.768768285032115)
assert est.m == pytest.approx(3.5531577581450953)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.03662634237353985)
assert est.m == pytest.approx(0.03431704941311493)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

Expand Down Expand Up @@ -1347,13 +1347,6 @@ def test_get_replica_transition_statistics(self, protocolresult):
assert rpx1['eigenvalues'].shape == (11,)
assert rpx1['matrix'].shape == (11, 11)

def test_get_replica_states(self, protocolresult):
rep = protocolresult.get_replica_states()

assert isinstance(rep, list)
assert len(rep) == 3
assert rep[0].shape == (6, 11)

def test_equilibration_iterations(self, protocolresult):
eq = protocolresult.equilibration_iterations()

Expand All @@ -1368,6 +1361,11 @@ def test_production_iterations(self, protocolresult):
assert len(prod) == 3
assert all(isinstance(v, float) for v in prod)

def test_filenotfound_replica_states(self, protocolresult):
errmsg = "File could not be found"

with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()

@pytest.mark.parametrize('mapping_name,result', [
["benzene_to_toluene_mapping", 0],
Expand Down
10 changes: 8 additions & 2 deletions openfe/tests/protocols/test_openmm_rfe_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,18 @@ def test_openmm_run_engine(benzene_vacuum_system, platform,
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
checkpoint = pur.outputs['last_checkpoint']
assert checkpoint == unit_shared / "checkpoint.nc"
assert checkpoint.exists()
assert checkpoint == "checkpoint.nc"
assert (unit_shared / checkpoint).exists()
nc = pur.outputs['nc']
assert nc == unit_shared / "simulation.nc"
assert nc.exists()

# Test results methods that need files present
results = p.gather([r])
states = results.get_replica_states()
assert len(states) == 1
assert states[0].shape[1] == 11


@pytest.mark.integration # takes ~7 minutes to run
@pytest.mark.flaky(reruns=3)
Expand Down
2 changes: 1 addition & 1 deletion openfe/tests/protocols/test_openmmutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def analyzer(self, reporter):

def test_free_energies(self, analyzer):
ret_dict = analyzer.unit_results_dict
assert len(ret_dict.items()) == 8
assert len(ret_dict.items()) == 7
assert pytest.approx(ret_dict['unit_estimate'].m) == -47.9606
assert pytest.approx(ret_dict['unit_estimate_error'].m) == 0.02396789
# forward and reverse (since we do this ourselves)
Expand Down
2 changes: 1 addition & 1 deletion openfe/tests/protocols/test_solvation_afe_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_key_stable(self):

class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocolResult
key = "AbsoluteSolvationProtocolResult-8caab27e7ad1bd544a787ac639f5f447"
key = "AbsoluteSolvationProtocolResult-e7d74b8ccc009d071b8c6eb0420da4bf"
repr = f"<{key}>"

@pytest.fixture()
Expand Down

0 comments on commit 68a08d2

Please sign in to comment.