Skip to content

Commit

Permalink
ensure that SortishSampler always puts the largest sample first
Browse files: browse the repository at this point in the history
  • Loading branch information
mcskinner committed Apr 29, 2018
1 parent f29e102 commit 9f21874
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 2 additions & 0 deletions fastai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def __iter__(self):
sort_idx = sum([sorted(s, key=self.key, reverse=True) for s in ck_idx], [])
sz = self.bs
ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([ck[0] for ck in ck_idx])
ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:]))
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,4 @@ def test_sortish_sampler_sorts_each_batch_descending():
s1 > s2 or (i+1) % bs == 0 # don't check batch boundaries
for i, (s1, s2) in enumerate(zip(samp, samp[1:]))
)
# Not always true, though the class comment implies it should be.
# assert samp[0] == max(samp)
assert samp[0] == max(samp)

0 comments on commit 9f21874

Please sign in to comment.