Skip to content

Commit

Permalink
Merge pull request fastai#409 from mcskinner/sampler-test
Browse files Browse the repository at this point in the history
add test for `text.py` sampler implementations
  • Loading branch information
jph00 authored Apr 29, 2018
2 parents 2eb29a3 + 9f21874 commit 4b46cf9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fastai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def __iter__(self):
sort_idx = sum([sorted(s, key=self.key, reverse=True) for s in ck_idx], [])
sz = self.bs
ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([ck[0] for ck in ck_idx])
ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:]))
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

from fastai.text import SortSampler, SortishSampler


def test_sort_sampler_sorts_all_descending():
bs = 4
n = bs*100
data = 2 * np.arange(n)
samp = list(SortSampler(data, lambda i: data[i]))

# The sample is a permutation of the indices.
assert sorted(samp) == list(range(n))
# And that "permutation" is for descending data order.
assert all(s1 > s2 for s1, s2 in zip(samp, samp[1:]))


def test_sortish_sampler_sorts_each_batch_descending():
bs = 4
n = bs*100
data = 2 * np.arange(n)
samp = list(SortishSampler(data, lambda i: data[i], bs))

# The sample is a permutation of the indices.
assert sorted(samp) == list(range(n))
# And that permutation is kind of reverse sorted.
assert all(
s1 > s2 or (i+1) % bs == 0 # don't check batch boundaries
for i, (s1, s2) in enumerate(zip(samp, samp[1:]))
)
assert samp[0] == max(samp)

0 comments on commit 4b46cf9

Please sign in to comment.