UNet_Train-v2.py
# This script trains a 3D U-Net, defined in the custom-written custom_unet.py,
# on the BraTS 2018 dataset.
import os
import glob
from tqdm import tqdm
import pickle
import random
import numpy as np
import pandas as pd
import nibabel as nib
import scipy.ndimage as sci
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()  # reused below to min-max normalize each volume
from skimage.transform import resize
import tensorflow as tf
from tensorflow import keras
devices = tf.config.list_physical_devices('GPU')
print(devices)  # sanity check that a GPU is visible
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
# from keras_preprocessing.image import ImageDataGenerator
def crop_3D(img, new_size):
    """Center-crop a 3D volume down to new_size (a 3-tuple)."""
    img_shape = img.shape
    x_mid = int(img_shape[0] / 2)
    y_mid = int(img_shape[1] / 2)
    z_mid = int(img_shape[2] / 2)
    # start each axis half the target size below the volume centre; the
    # original offset abs(new_size - mid) is only centered when the input
    # is exactly three times the crop size
    x_start = x_mid - int(new_size[0] / 2)
    y_start = y_mid - int(new_size[1] / 2)
    z_start = z_mid - int(new_size[2] / 2)
    tmp_img = img[x_start:x_start + new_size[0],
                  y_start:y_start + new_size[1],
                  z_start:z_start + new_size[2]]
    return tmp_img
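# Example (sketch): crop_3D is not called below, but on a full BraTS volume of
# shape (240, 240, 155) a call such as
#   cropped = crop_3D(vol, (160, 160, 96))
# would return the central (160, 160, 96) block of `vol`.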
def generate_brats_batch(prefix,
                         contrasts,
                         batch_size=32,
                         tumour='*',
                         patient_ids='*',
                         resample_size=(None, None, None),
                         augment_size=None,
                         infinite=True):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the
    contrast is treated like a colour channel.
    Example:
        x_batch shape: (32, 240, 240, 155, 4)
        y_batch shape: (32, 240, 240, 155, 4)
    augment_size must be less than or equal to batch_size; if None, no
    augmentation is performed.
    """
    file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
    while True:
        n_classes = 4
        # get the list of filenames for every contrast available
        keys = dict(prefix=prefix, tumour=tumour)
        filenames_by_contrast = {}
        for contrast in contrasts:
            if patient_ids == '*':
                filenames_by_contrast[contrast] = glob.glob(
                    file_pattern.format(contrast=contrast, patient_id='*', **keys))
            else:
                contrast_files = []
                for patient_id in patient_ids:
                    contrast_files.extend(glob.glob(
                        file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
                filenames_by_contrast[contrast] = contrast_files
        # get the shape of one 3D volume and initialize the batch arrays
        arbitrary_contrast = contrasts[0]
        shape = (nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape
                 if resample_size == (None, None, None) else resample_size)
        x_batch = np.empty((batch_size,) + shape + (len(contrasts),))
        y_batch = np.empty((batch_size,) + shape + (n_classes,))
        num_images = len(filenames_by_contrast[arbitrary_contrast])
        np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
        for bindex in range(0, num_images, batch_size):
            filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
            for findex, filename in enumerate(filenames):
                for cindex, contrast in enumerate(contrasts):
                    # load each contrast, optionally resample, and min-max
                    # normalize the intensities
                    tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                    if resample_size != (None, None, None):
                        tmp_img = resize(tmp_img, resample_size, mode='edge')
                    tmp_img = scaler.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                    x_batch[findex, ..., cindex] = tmp_img
                # load the mask; BraTS labels are {0, 1, 2, 4}, so remap 4 to 3
                # for a contiguous class range
                tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                tmp_mask[tmp_mask == 4] = 3
                if resample_size != (None, None, None):
                    # nearest-neighbour (order=0) resampling keeps the labels
                    # discrete; interpolating a one-hot mask would not
                    tmp_mask = resize(tmp_mask, resample_size, order=0,
                                      preserve_range=True, anti_aliasing=False)
                y_batch[findex] = to_categorical(tmp_mask, num_classes=n_classes)
            if bindex + batch_size > num_images:
                # final partial batch: drop the unused slots
                x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
            if augment_size is not None:
                # augmentation is not implemented yet; the original call is
                # kept for reference
                # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
                raise NotImplementedError('augmentation is not implemented')
            yield x_batch, y_batch
        if not infinite:
            break
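# Example (sketch, assuming the BraTS 2018 layout under the given prefix):
#   gen = generate_brats_batch('/path/to/brats', ['t1', 't2'], batch_size=2,
#                              resample_size=(80, 80, 80))
#   x, y = next(gen)
#   # x.shape == (2, 80, 80, 80, 2); y.shape == (2, 80, 80, 80, 4)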
if __name__ == '__main__':
    # prefix = '/Users/jasonfung/Documents/EECE571' # Jason's Macbook
    # prefix = 'C:/Users/Fungj/Documents/EECE_571F' # Jason's Desktop
    prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
    brats_dir = '/MICCAI_BraTS_2018_Data_Training/'
    # patient_id = 'Brats18_TCIA09_620_1'
    contrasts = ['t1', 't1ce', 'flair', 't2']
    tumours = ['LGG', 'HGG']
    data_list_LGG = os.listdir(os.path.join(prefix + brats_dir, tumours[0]))
    data_list_HGG = os.listdir(os.path.join(prefix + brats_dir, tumours[1]))
    dataset_file_list = data_list_HGG + data_list_LGG
    # shuffle and split the patient ids into train and test sets
    random.seed(42)
    file_list_shuffled = dataset_file_list.copy()
    random.shuffle(file_list_shuffled)
    test_ratio = 0.2
    split = int(len(file_list_shuffled) * (1 - test_ratio))
    train_file, test_file = file_list_shuffled[:split], file_list_shuffled[split:]
    # drop macOS Finder artifacts picked up by os.listdir
    while '.DS_Store' in train_file:
        train_file.remove('.DS_Store')
    while '.DS_Store' in test_file:
        test_file.remove('.DS_Store')
    # data parameters
    x_size = 80
    y_size = 80
    z_size = 80
    contrast_channels = 4
    input_shape = (x_size, y_size, z_size, contrast_channels)
    n_classes = 4
    batch_size = 2
    # the generator builds its own file pattern internally, so it takes the
    # dataset prefix, not the pattern itself
    train_datagen = generate_brats_batch(prefix,
                                         contrasts,
                                         batch_size=batch_size,
                                         patient_ids=train_file,
                                         resample_size=(x_size, y_size, z_size))
    test_datagen = generate_brats_batch(prefix,
                                        contrasts,
                                        batch_size=batch_size,
                                        patient_ids=test_file,
                                        resample_size=(x_size, y_size, z_size))
    from custom_unet import *
    import segmentation_models_3D as sm
    sm.set_framework('tf.keras')
    # define hyperparameters
    LR = 0.0005
    activation = 'softmax'
    optim = tf.keras.optimizers.Adam(LR)
    class_weights = [0.25, 0.25, 0.25, 0.25]  # only used if DiceLoss is re-enabled
    # limit memory growth
    # gpus = tf.config.experimental.list_physical_devices('GPU')
    # tf.config.experimental.set_memory_growth(gpus[0], True)
    # define loss functions
    # dice_loss = sm.losses.DiceLoss(class_weights=class_weights)
    # dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.CategoricalFocalLoss()
    total_loss = 1 * focal_loss
    metrics = [sm.metrics.IOUScore(threshold=0.5)]
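    # If the Dice term above were re-enabled, segmentation_models losses
    # compose with + and *, so a weighted combination would be one line
    # (sketch; the weighting is illustrative, not from the original run):
    # total_loss = sm.losses.DiceLoss() + (1 * focal_loss)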
    # define the model being used, in this case the 3D U-Net from custom_unet.py
    model = unet_model((x_size, y_size, z_size, contrast_channels),
                       n_classes,
                       max_pooling=True)
    model.compile(optimizer=optim, loss=total_loss, metrics=metrics)
    steps_per_epoch = len(train_file) // batch_size
    val_steps_per_epoch = len(test_file) // batch_size
    my_callbacks = [tf.keras.callbacks.EarlyStopping(patience=5),
                    tf.keras.callbacks.TensorBoard(log_dir=prefix + '/models/unet/logs'),
                    tf.keras.callbacks.ModelCheckpoint(filepath=prefix + '/unet_model_20220419.h5',
                                                       monitor='val_loss',
                                                       save_best_only=True)]
    # train on the first visible GPU
    with tf.device('/device:GPU:0'):
        history = model.fit(train_datagen,
                            steps_per_epoch=steps_per_epoch,
                            epochs=30,
                            verbose=1,
                            validation_data=test_datagen,
                            validation_steps=val_steps_per_epoch,
                            callbacks=my_callbacks)
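    # Follow-up sketch (assumption, not part of the original script): pickle
    # and matplotlib are imported above but never used; a natural next step is
    # persisting and plotting the training curves.
    # with open(prefix + '/unet_history_20220419.pkl', 'wb') as f:
    #     pickle.dump(history.history, f)
    # plt.plot(history.history['loss'], label='train loss')
    # plt.plot(history.history['val_loss'], label='val loss')
    # plt.xlabel('epoch')
    # plt.legend()
    # plt.show()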