From 6d3737e87bdfee05c5c185d77a8a9eb97b889a95 Mon Sep 17 00:00:00 2001 From: Drew McNutt Date: Mon, 28 Jun 2021 15:41:15 -0400 Subject: [PATCH] Collate function for MolDataset The static method `MolDataset.collateMolDataset` can be used as the collate_fn within `torch.utils.data.DataLoader`. This will return: lengths, centers, coordinates, types, radii and labels of all of the examples in a batch. Length corresponds to the number of atoms in example as each example will be padded to the length of the example with the most atoms in the batch. This allows the use of `gridmaker.forward` on the gpu and using molgrid with DataLoader and `torch.utils.data.distributed.DistributedSampler`. --- python/torch_bindings.py | 29 +++++++++++++++++++++++++++++ test/test_example_provider.py | 23 +++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/python/torch_bindings.py b/python/torch_bindings.py index 8f1373d..d6be6d7 100644 --- a/python/torch_bindings.py +++ b/python/torch_bindings.py @@ -220,3 +220,32 @@ def __setstate__(self,state): self.examples.populate(self.types_files) self.num_labels = self.examples.num_labels() + + @staticmethod + def collateMolDataset(batch): + '''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset. + Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch''' + lens = [] + centers = [] + lcoords = [] + ltypes = [] + lradii = [] + labels = [] + for center,coords,types,radii,label in batch: + lens.append(coords.shape[0]) + centers.append(center) + lcoords.append(coords) + ltypes.append(types) + lradii.append(radii.unsqueeze(1)) + labels.append(torch.tensor(label)) + + + lengths = torch.tensor(lens) + lcoords = torch.nn.utils.rnn.pad_sequence(lcoords, batch_first=True) + ltypes = torch.nn.utils.rnn.pad_sequence(ltypes, batch_first=True) + lradii = torch.nn.utils.rnn.pad_sequence(lradii, batch_first=True) + + centers = torch.stack(centers,dim=0) + labels = torch.stack(labels,dim=0) + + return lengths, centers, lcoords, ltypes, lradii, labels diff --git a/test/test_example_provider.py b/test/test_example_provider.py index 57beb44..98b5be0 100644 --- a/test/test_example_provider.py +++ b/test/test_example_provider.py @@ -347,6 +347,29 @@ def test_pytorch_dataset(): center, coords, types, radii, labels = m[-1] assert labels[0] == 0 np.testing.assert_allclose(labels[1], -10.3) + + '''Testing out the collate_fn when used with torch.utils.data.DataLoader''' + torch_loader = torch.utils.data.DataLoader( + m, batch_size=8,collate_fn=molgrid.MolDataset.collateMolDataset) + iterator = iter(torch_loader) + next(iterator) + lengths, center, coords, types, radii, labels = next(iterator) + assert len(lengths) == 8 + assert center.shape[0] == 8 + assert coords.shape[0] == 8 + assert types.shape[0] == 8 + assert radii.shape[0] == 8 + assert radii.shape[0] == 8 + assert labels.shape[0] == 8 + + mcenter, mcoords, mtypes, mradii, mlabels = m[10] + np.testing.assert_allclose(center[2],mcenter) + np.testing.assert_allclose(coords[2][:lengths[2]],mcoords) + np.testing.assert_allclose(types[2][:lengths[2]],mtypes) + np.testing.assert_allclose(radii[2][:lengths[2]],mradii.unsqueeze(1)) + assert len(labels[2]) == len(mlabels) + assert labels[2][0] == mlabels[0] + assert labels[2][1] == mlabels[1] def test_duplicated_examples(): '''This is for files with multiple ligands'''