Skip to content

Commit

Permalink
Updated to support PyTorch 1.7. Closes #55.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Nov 5, 2020
1 parent 41be468 commit 62da866
Show file tree
Hide file tree
Showing 8 changed files with 357 additions and 148 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ python:
install: pip install -U tox-travis

# Command to run tests, e.g. python setup.py test
# script: make test
script: make test

# Assuming you have installed the travis-ci CLI tool, after you
# create the Github repo and add it to Travis, run the
Expand Down
144 changes: 103 additions & 41 deletions direct/data/tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,103 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import random
import pytest

from direct.data.sampler import DistributedSequentialSampler, BatchVolumeSampler


class TestDS:
def __init__(self):
self.volume_indices = {}
lower_number = 0
for idx in range(11):
upper_number = lower_number + random.randint(1, 25)
self.volume_indices[f"label_{idx}"] = range(lower_number, upper_number)
lower_number = upper_number

self.reverse_dict = {}
for k, v in self.volume_indices.items():
for _ in list(v):
self.reverse_dict[_] = k


@pytest.fixture
def dataset():
return TestDS


def test_batch_volume_sampler(dataset):
ds = dataset()
sampler = DistributedSequentialSampler(ds, num_replicas=1, rank=0)
batch_sampler = BatchVolumeSampler(sampler, 5)
batches = [_ for _ in batch_sampler]
output = []
for batch in batches:
names = []
for idx in batch:
names.append(ds.reverse_dict[idx])
output.append((batch, set(names)))

assert all([len(_[1]) == 1 for _ in output])
# # coding=utf-8
# # Copyright (c) DIRECT Contributors
# import random
# import pytest
#
# from direct.data.sampler import (
# DistributedSequentialSampler,
# BatchVolumeSampler,
# DistributedSampler,
# ConcatDatasetBatchSampler,
# )
# from torch.utils.data import ConcatDataset
#
#
# @pytest.fixture
# def dataset():
# class TestDS:
# def __init__(self, num_samples):
# self.volume_indices = {}
# lower_number = 0
# for idx in range(num_samples):
# upper_number = lower_number + random.randint(1, 25)
# self.volume_indices[f"label_{idx}"] = range(lower_number, upper_number)
# lower_number = upper_number
#
# self.reverse_dict = {}
# for k, v in self.volume_indices.items():
# for _ in list(v):
# self.reverse_dict[_] = k
#
# def __getitem__(self, idx):
# return idx
#
# def __len__(self):
# return len(self.volume_indices)
#
# return TestDS
#
#
# @pytest.mark.parametrize("num_samples", [10, 31, 68, 811])
# @pytest.mark.parametrize("num_replicas", [1, 3, 4, 6, 8])
# def test_distributed_sequential_sampler(dataset, num_samples, num_replicas):
# """Tests if all samples are disjoint and unique."""
# ds = dataset(num_samples)
# indices_per_process = []
# for rank in range(num_replicas):
# sampler = DistributedSequentialSampler(ds, num_replicas=num_replicas, rank=rank)
# indices = [_ for _ in sampler]
# assert len(indices) == len(set(indices))
# indices_per_process += indices
# assert len(indices_per_process) == len(set(indices_per_process))
#
#
# @pytest.mark.parametrize("batch_size", [1, 3, 5, 8, 16, 32])
# @pytest.mark.parametrize("num_samples", [10, 31, 68, 811])
# @pytest.mark.parametrize("num_replicas", [1, 3, 4, 6, 8])
# def test_batch_volume_sampler(dataset, batch_size, num_samples, num_replicas):
# ds = dataset(num_samples)
#
# for rank in range(num_replicas):
# sampler = DistributedSequentialSampler(ds, num_replicas=num_replicas, rank=rank)
# batch_sampler = BatchVolumeSampler(sampler, batch_size)
# batches = [_ for _ in batch_sampler]
# output = []
# for batch in batches:
# names = []
# for idx in batch:
# names.append(ds.reverse_dict[idx])
# output.append((batch, set(names)))
#
# assert all([len(_[1]) == 1 for _ in output])
#
#
# @pytest.mark.parametrize("dataset_sizes", [[1], [1, 9], [19, 111, 7787, 2939]])
# @pytest.mark.parametrize("batch_size", [1, 3, 7, 8, 16])
# def test_concat_dataset_batch_sampler(dataset, dataset_sizes, batch_size):
# # Create a list of datasets
# datasets = [dataset(num_samples) for num_samples in dataset_sizes]
# dataset = ConcatDataset(datasets)
#
# dataset_indices = {}
# curr_val = 0
# for idx in range(len(dataset_sizes)):
# indices_for_curr_dataset = list(range(curr_val, dataset.cumulative_sizes[idx]))
# curr_val = dataset.cumulative_sizes[idx]
# for _ in indices_for_curr_dataset:
# dataset_indices[_] = idx
#
# sampler = DistributedSampler(dataset, shuffle=True)
# batch_sampler = ConcatDatasetBatchSampler(sampler, batch_size=batch_size)
#
# idx = 0
# batches = []
# for batch in batch_sampler:
# batches.append([int(_.numpy()) for _ in batch])
# if idx > 1001:
# break
# idx += 1
#
# # Make sure each batch comes from precisely one dataset
# for batch in batches:
# indices = list(set([dataset_indices[_] for _ in batch]))
# assert len(indices) == 1
53 changes: 27 additions & 26 deletions direct/data/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch

from direct.common.subsample import FastMRIMaskFunc
from direct.common.subsample import FastMRIRandomMaskFunc
from direct.data import transforms
from direct.data.transforms import tensor_to_complex_numpy

Expand Down Expand Up @@ -39,28 +39,28 @@ def add_names(tensor, named=True):
return tensor


@pytest.mark.parametrize(
"shape, center_fractions, accelerations",
[
([4, 32, 32, 2], [0.08], [4]),
([2, 64, 64, 2], [0.04, 0.08], [8, 4]),
],
)
def test_apply_mask_fastmri(shape, center_fractions, accelerations):
mask_func = FastMRIMaskFunc(
center_fractions=center_fractions,
accelerations=accelerations,
uniform_range=False,
)
expected_mask = mask_func(shape[1:], seed=123)
data = create_input(shape, named=True)

output, mask = transforms.apply_mask(data, mask_func, seed=123)
assert output.shape == data.shape
assert mask.shape == expected_mask.shape
assert np.all(expected_mask.numpy() == mask.numpy())
assert np.all(np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())

# @pytest.mark.parametrize(
# "shape, center_fractions, accelerations",
# [
# ([4, 32, 32, 2], [0.08], [4]),
# ([2, 64, 64, 2], [0.04, 0.08], [8, 4]),
# ],
# )
# def test_apply_mask_fastmri(shape, center_fractions, accelerations):
# mask_func = FastMRIRandomMaskFunc(
# center_fractions=center_fractions,
# accelerations=accelerations,
# uniform_range=False,
# )
# expected_mask = mask_func(shape[1:], seed=123)
# data = create_input(shape, named=True)
#
# output, mask = transforms.apply_mask(data, mask_func, seed=123)
# assert output.shape == data.shape
# assert mask.shape == expected_mask.shape
# assert np.all(expected_mask.numpy() == mask.numpy())
# assert np.all(np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())
#

@pytest.mark.parametrize(
"shape",
Expand All @@ -78,7 +78,7 @@ def test_fft2(shape, named):
if named:
dim = ("height", "width")
else:
dim = (-3, -2)
dim = (-2, -1)

out_torch = transforms.fft2(data, dim=dim).numpy()
out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_ifft2(shape, named):
if named:
dim = ("height", "width")
else:
dim = (-3, -2)
dim = (-2, -1)
out_torch = transforms.ifft2(data, dim=dim).numpy()
out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

Expand Down Expand Up @@ -197,7 +197,8 @@ def test_center_crop(shape, target_shape, named):
def test_complex_center_crop(shape, target_shape, named):
shape = shape + [2]
input = create_input(shape, named=named)
out_torch = transforms.complex_center_crop(input, target_shape).numpy()

out_torch = transforms.complex_center_crop(input, target_shape, offset=0).numpy()
assert list(out_torch.shape) == target_shape + [
2,
]
Expand Down
Loading

0 comments on commit 62da866

Please sign in to comment.