-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathluna_p_local.py
67 lines (56 loc) · 3.33 KB
/
luna_p_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import data_transforms
import data_iterators
import pathfinder
from functools import partial
import utils
# Left as None in this configuration.
# NOTE(review): presumably the training script resumes from a saved
# checkpoint when this is set to something else - confirm against the caller.
restart_from_save = None

# Fixed seed (42) so every run draws the same random sequence; this generator
# is the one handed to the training data iterator further down.
rng = np.random.RandomState(42)
# transformations
# Geometry of the extracted patch. These values are forwarded unchanged to
# data_transforms.transform_patch3d by data_prep_function.
# NOTE(review): 'mm_patch_size' / 'pixel_spacing' look like millimetre units
# - confirm against data_transforms.
p_transform = dict(
    patch_size=(64, 64, 64),
    mm_patch_size=(64, 64, 64),
    pixel_spacing=(1., 1., 1.),
)

# Augmentation ranges, one [low, high] pair per axis (z, y, x). Only the
# training data preparation function receives these; validation gets None.
# NOTE(review): units are presumably voxels (or mm) for translation and
# degrees for rotation - confirm against data_transforms.
p_transform_augment = dict(
    translation_range_z=[-16, 16],
    translation_range_y=[-16, 16],
    translation_range_x=[-16, 16],
    rotation_range_z=[-180, 180],
    rotation_range_y=[-180, 180],
    rotation_range_x=[-180, 180],
)
# data preparation function
def data_prep_function(data, patch_center, luna_annotations, pixel_spacing, luna_origin, p_transform,
                       p_transform_augment, **kwargs):
    """Build one (input, target) pair from a scan and its annotations.

    The scan is passed through ``data_transforms.transform_patch3d`` together
    with the patch centre, annotations, spacing/origin and the transform
    parameter dicts; the resulting patch is normalised with
    ``data_transforms.pixelnormHU`` and a mask with the same shape as the
    normalised patch is generated from the transformed annotations using
    ``shape='sphere'``.

    Args:
        data: scan volume handed to ``transform_patch3d``.
        patch_center: centre of the patch to extract.
        luna_annotations: annotations for the scan.
        pixel_spacing: spacing of ``data``.
        luna_origin: origin of ``data``.
        p_transform: patch geometry parameters.
        p_transform_augment: augmentation parameters; the validation variant
            of this function passes ``None`` here.
        **kwargs: accepted so callers may pass extra fields; ignored.

    Returns:
        Tuple ``(x, y)``: the normalised patch and the mask built from the
        transformed annotations.
    """
    transformed = data_transforms.transform_patch3d(
        data=data,
        luna_annotations=luna_annotations,
        patch_center=patch_center,
        p_transform=p_transform,
        p_transform_augment=p_transform_augment,
        pixel_spacing=pixel_spacing,
        luna_origin=luna_origin,
    )
    # The middle element (the transformed patch annotation) is not needed here.
    patch, _unused_patch_annotation, transformed_annotations = transformed

    normalised_patch = data_transforms.pixelnormHU(patch)
    target_mask = data_transforms.make_3d_mask_from_annotations(
        img_shape=normalised_patch.shape,
        annotations=transformed_annotations,
        shape='sphere',
    )
    return normalised_patch, target_mask
# Training variant: the augmentation ranges are bound in.
data_prep_function_train = partial(
    data_prep_function,
    p_transform_augment=p_transform_augment,
    p_transform=p_transform,
)
# Validation variant: same patch geometry, but None is bound for the
# augmentation parameters.
data_prep_function_valid = partial(
    data_prep_function,
    p_transform_augment=None,
    p_transform=p_transform,
)
# data iterators
batch_size = 4
nbatches_chunk = 8
# The training iterator is asked for a whole chunk (nbatches_chunk batches)
# per step, hence batch_size=chunk_size below.
chunk_size = batch_size * nbatches_chunk

# The pickled split is indexed by 'train' and 'valid' to get the patient ids
# for each subset.
train_valid_ids = utils.load_pkl(pathfinder.LUNA_VALIDATION_SPLIT_PATH)
train_pids, valid_pids = train_valid_ids['train'], train_valid_ids['valid']

# Training data: uses the augmenting prep function and the seeded rng.
train_data_iterator = data_iterators.PatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    batch_size=chunk_size,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_train,
    rng=rng,
    patient_ids=train_pids,
    full_batch=True,
    random=True,
    infinite=True,
)

# Validation data: uses the prep function with augmentation set to None.
valid_data_iterator = data_iterators.ValidPatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_valid,
    patient_ids=valid_pids,
)

# Number of whole chunks in one pass over the training samples. Floor division
# gives the same result as `/` on Python 2 integers and keeps the count
# integral on Python 3, where `/` would produce a float.
# NOTE(review): assumes nsamples is an int - confirm in data_iterators.
nchunks_per_epoch = train_data_iterator.nsamples // chunk_size