Skip to content

Commit

Permalink
Fix for Issue #77
Browse files Browse the repository at this point in the history
ExampleDataset can now be iterated over and has tests.
  • Loading branch information
dkoes committed Dec 6, 2021
1 parent 2773558 commit 5b618a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/bindings.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,9 @@ MAKE_ALL_GRIDS()
.def("get_type_names", &ExampleDataset::get_type_names)
.def("__getitem__", +[](const ExampleDataset& D, int i) {
if(i < 0) i = D.size()+i; //index from back
if(i < 0 || (size_t)i >= D.size()) {
throw std::out_of_range("Index "+itoa(i)+" invalid for ExampleDataset with size "+itoa(D.size()));
}
return D[i];
});

Expand Down
29 changes: 29 additions & 0 deletions test/test_example_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import molgrid
import numpy as np
import os
import torch

from pytest import approx
from numpy import around

datadir = os.path.dirname(__file__) + '/data'

#make sure we can map and iterate
def test_example_dataset():
fname = datadir + "/small.types"
e = molgrid.ExampleDataset(data_root=datadir + "/structs")
e.populate(fname)

assert len(e) == 1000
assert e[-1].labels[1] == approx(-10.3)
assert e[3].labels[1] == approx(-6.05)

for ex in e:
pass
assert ex.labels[1] == approx(-10.3)

for ex in e:
break

assert ex.labels[1] == approx(6.05)

0 comments on commit 5b618a9

Please sign in to comment.