Skip to content

Commit

Permalink
Collate function for MolDataset
Browse files Browse the repository at this point in the history
The static method `MolDataset.collateMolDataset` can be used as the
collate_fn within `torch.utils.data.DataLoader`. This will return: lengths,
centers, coordinates, types, radii and labels of all of the examples in
a batch. Length corresponds to the number of atoms in example as each
example will be padded to the length of the example with the most atoms
in the batch. This allows the use of `gridmaker.forward` on the gpu and
using molgrid with DataLoader and `torch.utils.data.distributed.DistributedSampler`.
  • Loading branch information
drewnutt committed Jun 28, 2021
1 parent 24e1d2a commit 6d3737e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/torch_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,32 @@ def __setstate__(self,state):
self.examples.populate(self.types_files)

self.num_labels = self.examples.num_labels()

@staticmethod
def collateMolDataset(batch):
'''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
lens = []
centers = []
lcoords = []
ltypes = []
lradii = []
labels = []
for center,coords,types,radii,label in batch:
lens.append(coords.shape[0])
centers.append(center)
lcoords.append(coords)
ltypes.append(types)
lradii.append(radii.unsqueeze(1))
labels.append(torch.tensor(label))


lengths = torch.tensor(lens)
lcoords = torch.nn.utils.rnn.pad_sequence(lcoords, batch_first=True)
ltypes = torch.nn.utils.rnn.pad_sequence(ltypes, batch_first=True)
lradii = torch.nn.utils.rnn.pad_sequence(lradii, batch_first=True)

centers = torch.stack(centers,dim=0)
labels = torch.stack(labels,dim=0)

return lengths, centers, lcoords, ltypes, lradii, labels
23 changes: 23 additions & 0 deletions test/test_example_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,29 @@ def test_pytorch_dataset():
center, coords, types, radii, labels = m[-1]
assert labels[0] == 0
np.testing.assert_allclose(labels[1], -10.3)

'''Testing out the collate_fn when used with torch.utils.data.DataLoader'''
torch_loader = torch.utils.data.DataLoader(
m, batch_size=8,collate_fn=molgrid.MolDataset.collateMolDataset)
iterator = iter(torch_loader)
next(iterator)
lengths, center, coords, types, radii, labels = next(iterator)
assert len(lengths) == 8
assert center.shape[0] == 8
assert coords.shape[0] == 8
assert types.shape[0] == 8
assert radii.shape[0] == 8
assert radii.shape[0] == 8
assert labels.shape[0] == 8

mcenter, mcoords, mtypes, mradii, mlabels = m[10]
np.testing.assert_allclose(center[2],mcenter)
np.testing.assert_allclose(coords[2][:lengths[2]],mcoords)
np.testing.assert_allclose(types[2][:lengths[2]],mtypes)
np.testing.assert_allclose(radii[2][:lengths[2]],mradii.unsqueeze(1))
assert len(labels[2]) == len(mlabels)
assert labels[2][0] == mlabels[0]
assert labels[2][1] == mlabels[1]

def test_duplicated_examples():
'''This is for files with multiple ligands'''
Expand Down

0 comments on commit 6d3737e

Please sign in to comment.