Skip to content

Commit

Permalink
Merge branch 'main' into fix/docker_image_building_issue_610
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemhenry authored Nov 20, 2023
2 parents ed2a35b + 1e6a365 commit b226323
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 2 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- openff-units==0.2.0
- pint<0.22
- openff-models>=0.0.5
- openfe-analysis>=0.1.2
- click
- typing-extensions
- lomap2>=2.3.0
Expand Down
83 changes: 83 additions & 0 deletions openfe/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from itertools import chain
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import numpy.typing as npt
from openff.units import unit
from typing import Optional, Union
Expand Down Expand Up @@ -202,3 +204,84 @@ def plot_replica_timeseries(

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
return ax


def plot_2D_rmsd(data: list[list[float]],
vmax=5.0) -> plt.Figure:
"""Plots 2D RMSD for many states
Parameters
----------
data : list[list[float]]
for each state, the 2D RMSD
vmax : float, optional
the value to consider "high" in the colourmap to flag bad values,
defaults to 5.0 (A)
Returns
-------
matplotlib Figure
"""
twod_rmsd_arrs = []
for state in data:
# unpack 2D RMSD data
# we store N(N-1)//2 values, so find N then make symmetric array
N = int((1 + np.sqrt(8 * len(state) + 1)) / 2)
arr = np.zeros((N, N))
arr[np.triu_indices_from(arr, k=1)] = state
arr += arr.T

twod_rmsd_arrs.append(arr)

nplots = len(data) + 1 # + colorbar

# plot on 4 x n grid
nrows = nplots // 4 + (1 if nplots % 4 else 0)

fig, axes = plt.subplots(nrows, 4)

for i, (arr, ax) in enumerate(
zip(twod_rmsd_arrs, chain.from_iterable(axes))):
ax.imshow(arr,
vmin=0, vmax=vmax,
cmap=plt.get_cmap('cividis'))
ax.axis('off') # turn off ticks/labels
ax.set_title(f'State {i}')

plt.colorbar(axes[0][0].images[0],
cax=axes[-1][-1],
label="RMSD scale (A)",
orientation="horizontal")

fig.suptitle('Protein 2D RMSD')
fig.tight_layout()

return fig


def plot_ligand_COM_drift(time: list[float], data: list[list[float]]):
fig, ax = plt.subplots()

for i, s in enumerate(data):
ax.plot(time, s, label=f'State {i}')

ax.legend(loc='upper left')
ax.set_xlabel('Time (ps)')
ax.set_ylabel('Distance (A)')
ax.set_title('Ligand COM drift')

return fig


def plot_ligand_RMSD(time: list[float], data: list[list[float]]):
fig, ax = plt.subplots()

for i, s in enumerate(data):
ax.plot(time, s, label=f'State {i}')

ax.legend(loc='upper left')
ax.set_xlabel('Time (ps)')
ax.set_ylabel('RMSD (A)')
ax.set_title('Ligand RMSD')

return fig
35 changes: 34 additions & 1 deletion openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from collections import defaultdict
import uuid
import warnings
import json
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from openff.units import unit
Expand All @@ -40,6 +42,7 @@
from typing import Any, Iterable, Union
import openmmtools
import mdtraj
import subprocess
from rdkit import Chem

import gufe
Expand All @@ -60,6 +63,7 @@
)
from . import _rfe_utils
from ...utils import without_oechem_backend, log_system_probe
from ...analysis import plotting
from openfe.due import due, Doi


Expand Down Expand Up @@ -980,6 +984,33 @@ def run(self, *, dry=False, verbose=True,
else:
return {'debug': {'sampler': sampler}}

@staticmethod
def analyse(where) -> dict:
# don't put energy analysis in here, it uses the open file reporter
# whereas structural stuff requires that the file handle is closed
ret = subprocess.run(['openfe_analysis', str(where)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if ret.returncode:
return {'structural_analysis_error': ret.stderr}

data = json.loads(ret.stdout)

savedir = pathlib.Path(where)
if d := data['protein_2D_RMSD']:
fig = plotting.plot_2D_rmsd(d)
fig.savefig(savedir / "protein_2D_RMSD.png")
plt.close(fig)
f2 = plotting.plot_ligand_COM_drift(data['time(ps)'], data['ligand_wander'])
f2.savefig(savedir / "ligand_COM_drift.png")
plt.close(f2)

f3 = plotting.plot_ligand_RMSD(data['time(ps)'], data['ligand_RMSD'])
f3.savefig(savedir / "ligand_RMSD.png")
plt.close(f3)

return {'structural_analysis': data}

def _execute(
self, ctx: gufe.Context, **kwargs,
) -> dict[str, Any]:
Expand All @@ -988,9 +1019,11 @@ def _execute(
outputs = self.run(scratch_basepath=ctx.scratch,
shared_basepath=ctx.shared)

analysis_outputs = self.analyse(ctx.shared)

return {
'repeat_id': self._inputs['repeat_id'],
'generation': self._inputs['generation'],
**outputs
**outputs,
**analysis_outputs,
}
6 changes: 5 additions & 1 deletion openfe/protocols/openmm_utils/multistate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Union
import warnings

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from openmmtools import multistate
Expand Down Expand Up @@ -77,6 +77,7 @@ def plot(self, filepath: Path, filename_prefix: str):
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'mbar_overlap_matrix.png')
)
plt.close(ax.figure) # type: ignore

# Reverse and forward analysis
ax = plotting.plot_convergence(
Expand All @@ -86,6 +87,7 @@ def plot(self, filepath: Path, filename_prefix: str):
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'forward_reverse_convergence.png')
)
plt.close(ax.figure) # type: ignore

# Replica state timeseries plot
ax = plotting.plot_replica_timeseries(
Expand All @@ -95,6 +97,7 @@ def plot(self, filepath: Path, filename_prefix: str):
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'replica_state_timeseries.png')
)
plt.close(ax.figure) # type: ignore

# Replica exchange transition matrix
if self.sampling_method == 'repex':
Expand All @@ -105,6 +108,7 @@ def plot(self, filepath: Path, filename_prefix: str):
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'replica_exchange_matrix.png')
)
plt.close(ax.figure) # type: ignore

def _analyze(self, forward_reverse_samples: int):
"""
Expand Down

0 comments on commit b226323

Please sign in to comment.