forked from bowang-lab/MedSAM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_multi_gpus.py
executable file
·505 lines (459 loc) · 17.8 KB
/
train_multi_gpus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
# -*- coding: utf-8 -*-
"""
Train the image encoder and mask decoder;
freeze the prompt encoder.
"""
# %% setup environment
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import monai
from segment_anything import sam_model_registry
import torch.nn.functional as F
import argparse
import random
from datetime import datetime
import shutil
import glob
# Seed the global RNGs. torch's seed covers torch-side randomness (e.g. model
# initialisation). NpyDataset and the dataset sanity check draw from Python's
# `random`, and show_mask(random_color=True) uses NumPy's global RNG, so seed
# those too; previously only torch was seeded despite the "set seeds" intent.
torch.manual_seed(2023)
random.seed(2023)
np.random.seed(2023)
# Release cached-but-unused memory held by the CUDA caching allocator.
torch.cuda.empty_cache()
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (H, W); it is reshaped to
            (H, W, 1) and multiplied by the RGBA color.
        ax: object with an ``imshow`` method (a matplotlib Axes).
        random_color: if True, use a random RGB color; otherwise a fixed
            yellow. Alpha is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax):
    """Outline ``box`` = (x_min, y_min, x_max, y_max) on ``ax`` in blue, unfilled."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor="blue",
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
class NpyDataset(Dataset):
    """Image / ground-truth ``.npy`` pairs with a randomly jittered box prompt.

    Layout: ground truths are found recursively under ``data_root/gts``
    (``**/*.npy``), and images are looked up as ``data_root/imgs/<basename>``.
    A ground-truth file is kept only when an image with the same basename
    exists directly in ``imgs``.

    Args:
        data_root: directory containing the ``gts`` and ``imgs`` subfolders.
        bbox_shift: maximum random outward shift, in pixels, applied
            independently to each edge of the bounding box.
    """

    def __init__(self, data_root, bbox_shift=20):
        self.data_root = data_root
        self.gt_path = join(data_root, "gts")
        self.img_path = join(data_root, "imgs")
        self.gt_path_files = sorted(
            glob.glob(join(self.gt_path, "**/*.npy"), recursive=True)
        )
        # Drop ground truths that have no matching image file.
        self.gt_path_files = [
            file
            for file in self.gt_path_files
            if os.path.isfile(join(self.img_path, os.path.basename(file)))
        ]
        self.bbox_shift = bbox_shift
        print(f"number of images: {len(self.gt_path_files)}")

    def __len__(self):
        return len(self.gt_path_files)

    def __getitem__(self, index):
        """Return ``(image, mask, box, name)`` for one randomly chosen foreground label.

        - image: float tensor (3, H, W); asserted to lie in [0, 1].
        - mask: long tensor (1, H, W); 1 where the ground truth equals the
          chosen label, 0 elsewhere.
        - box: float tensor ``[x_min, y_min, x_max, y_max]`` enclosing the
          mask, each edge pushed outward by ``randint(0, bbox_shift)`` pixels
          and clamped to [0, W] / [0, H].
        - name: the file basename shared by the image and ground truth.

        Raises:
            ValueError: if the ground truth contains no non-zero label.
        """
        gt_file = self.gt_path_files[index]
        img_name = os.path.basename(gt_file)
        # Both arrays are opened memory-mapped ("r").
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only use this dataset on trusted files.
        img_1024 = np.load(
            join(self.img_path, img_name), "r", allow_pickle=True
        )  # (H, W, 3), e.g. (1024, 1024, 3)
        img_1024 = np.transpose(img_1024, (2, 0, 1))  # -> (3, H, W)
        assert (
            np.max(img_1024) <= 1.0 and np.min(img_1024) >= 0.0
        ), "image should be normalized to [0, 1]"
        gt = np.load(gt_file, "r", allow_pickle=True)  # integer label map, 0 = background

        # Exclude background explicitly instead of assuming 0 is the smallest
        # unique value: `np.unique(gt)[1:]` silently dropped a real label when
        # no background pixel was present.
        label_ids = np.unique(gt)
        label_ids = label_ids[label_ids != 0]
        if label_ids.size == 0:
            raise ValueError(f"ground truth has no foreground label: {gt_file}")
        # Binary mask for one randomly chosen label.
        gt2D = np.uint8(gt == random.choice(label_ids.tolist()))
        assert np.max(gt2D) == 1 and np.min(gt2D) == 0.0, "ground truth should be 0, 1"

        y_indices, x_indices = np.where(gt2D > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # Perturb the tight box outward so the model does not rely on exact boxes.
        H, W = gt2D.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])
        return (
            torch.tensor(img_1024).float(),
            torch.tensor(gt2D[None, :, :]).long(),
            torch.tensor(bboxes).float(),
            img_name,
        )
# %% sanity test of dataset class
# Run only in the launching process. torch.multiprocessing.spawn (used by main)
# starts workers that re-import this module under a name other than
# "__main__", so without this guard every worker would redo the check and
# overwrite data_sanitycheck.png.
# NOTE(review): the data path is hard-coded and does not follow -i/--tr_npy_path.
if __name__ == "__main__":
    tr_dataset = NpyDataset("data/npy/CT_Abd")
    tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
    for step, (image, gt, bboxes, names_temp) in enumerate(tr_dataloader):
        print(image.shape, gt.shape, bboxes.shape)
        # Show two random examples from the first batch side by side.
        _, axs = plt.subplots(1, 2, figsize=(25, 25))
        for ax in axs:
            # Bound by the actual batch size: the only batch may hold fewer
            # than 8 samples (the old randint(0, 7) could index out of range).
            idx = random.randint(0, image.shape[0] - 1)
            ax.imshow(image[idx].cpu().permute(1, 2, 0).numpy())
            show_mask(gt[idx].cpu().numpy(), ax)
            show_box(bboxes[idx].numpy(), ax)
            ax.axis("off")
            ax.set_title(names_temp[idx])
        # plt.show()
        plt.subplots_adjust(wspace=0.01, hspace=0)
        plt.savefig("./data_sanitycheck.png", bbox_inches="tight", dpi=300)
        plt.close()
        break
# %% set up parser
def _str2bool(value):
    """Convert a command-line word such as "true"/"False"/"1"/"0" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``-use_wandb False`` would enable wandb. This parses
    the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--tr_npy_path",
    type=str,
    default="data/npy/CT_Abd",
    help="path to training npy files; two subfolders: gts and imgs",
)
parser.add_argument("-task_name", type=str, default="MedSAM-ViT-B")
parser.add_argument("-model_type", type=str, default="vit_b")
parser.add_argument(
    "-checkpoint", type=str, default="work_dir/SAM/sam_vit_b_01ec64.pth"
)
# parser.add_argument('-device', type=str, default='cuda:0')
# NOTE(review): args.load_pretrain and args.pretrain_model_path are not read
# anywhere in this script; SAM weights come from -checkpoint.
parser.add_argument(
    "--load_pretrain", type=_str2bool, default=True, help="load pretrained weights"
)
parser.add_argument("-pretrain_model_path", type=str, default="")
parser.add_argument("-work_dir", type=str, default="./work_dir")
# train
parser.add_argument("-num_epochs", type=int, default=1000)
parser.add_argument("-batch_size", type=int, default=8)
parser.add_argument("-num_workers", type=int, default=8)
# Optimizer parameters
parser.add_argument(
    "-weight_decay", type=float, default=0.01, help="weight decay (default: 0.01)"
)
parser.add_argument(
    "-lr", type=float, default=0.0001, metavar="LR", help="learning rate (absolute lr)"
)
parser.add_argument(
    "-use_wandb", type=_str2bool, default=False, help="use wandb to monitor training"
)
parser.add_argument("-use_amp", action="store_true", default=False, help="use amp")
## Distributed training args
parser.add_argument("--world_size", type=int, help="world size")
parser.add_argument("--node_rank", type=int, default=0, help="Node rank")
parser.add_argument(
    "--bucket_cap_mb",
    type=int,
    default=25,
    help="The amount of memory in Mb that DDP will accumulate before firing off gradient communication for the bucket (need to tune)",
)
parser.add_argument(
    "--grad_acc_steps",
    type=int,
    default=1,
    help="Gradient accumulation steps before syncing gradients for backprop",
)
parser.add_argument(
    "--resume", type=str, default="", help="Resuming training from checkpoint"
)
parser.add_argument("--init_method", type=str, default="env://")
args = parser.parse_args()
if args.use_wandb:
    # Imported here so wandb is only required when -use_wandb is enabled.
    import wandb

    wandb.login()
    # NOTE(review): torch.multiprocessing.spawn re-imports this module in each
    # worker, so this presumably also runs once per rank — confirm intended.
    wandb.init(
        project=args.task_name,
        config=dict(
            lr=args.lr,
            batch_size=args.batch_size,
            data_path=args.tr_npy_path,
            model_type=args.model_type,
        ),
    )
# %% set up model for fine-tuning
# device = args.device
# Minute-resolution timestamp naming this run's output directory.
# NOTE(review): recomputed whenever the module is imported (including by
# spawned workers); only rank 0 writes into the resulting directory.
run_id = datetime.now().strftime("%Y%m%d-%H%M")
# Rank 0 saves checkpoints, the loss plot and a copy of this script here.
model_save_path = join(args.work_dir, f"{args.task_name}-{run_id}")
# %% set up model
class MedSAM(nn.Module):
    """SAM wrapper that predicts one mask per box prompt at the input resolution.

    The image encoder and mask decoder stay trainable. The prompt encoder is
    frozen: its parameters have ``requires_grad`` disabled and it runs under
    ``torch.no_grad``.
    """

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
    ):
        super().__init__()
        # Register submodules in this order (it fixes parameter/state_dict order).
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # Freeze every prompt-encoder parameter.
        self.prompt_encoder.requires_grad_(False)

    def forward(self, image, box):
        """Return mask logits of shape (B, 1, H, W) for ``image`` (B, C, H, W).

        ``box`` is anything ``torch.as_tensor`` accepts (e.g. a NumPy array),
        shaped (B, 4) or already (B, 1, 4).
        """
        embeddings = self.image_encoder(image)  # (B, 256, 64, 64) for SAM encoders

        # No gradients flow through the frozen prompt encoder.
        with torch.no_grad():
            boxes = torch.as_tensor(box, dtype=torch.float32, device=image.device)
            if boxes.ndim == 2:
                boxes = boxes.unsqueeze(1)  # (B, 4) -> (B, 1, 4)
            sparse, dense = self.prompt_encoder(points=None, boxes=boxes, masks=None)

        low_res_logits, _ = self.mask_decoder(
            image_embeddings=embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
        target_hw = (image.shape[2], image.shape[3])
        return F.interpolate(
            low_res_logits, size=target_hw, mode="bilinear", align_corners=False
        )
def main():
    """Spawn one ``main_worker`` process per visible CUDA device on this node.

    If ``--world_size`` was not given, it defaults to the number of local GPUs
    (single-node training). Multi-node jobs must still pass it explicitly.

    Raises:
        RuntimeError: if no CUDA device is visible (mp.spawn would otherwise
            start zero workers and silently do nothing).
    """
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node == 0:
        raise RuntimeError(
            "No CUDA device found; multi-GPU training needs at least one GPU"
        )
    if args.world_size is None:
        # Passing world_size=None to init_process_group is not a valid size.
        args.world_size = ngpus_per_node
    print("Spawning processes")
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def _report_cuda_mem(rank, gpu, label="", show_used=False):
    """Print this GPU's total and free (optionally also used) CUDA memory in GiB.

    ``label`` is inserted after "CUDA memory" in each message, e.g.
    ``" before DDP initialised"``.
    """
    free_bytes, total_bytes = torch.cuda.mem_get_info(gpu)
    free_cuda_mem = free_bytes / (1024**3)
    total_cuda_mem = total_bytes / (1024**3)
    print(f"[RANK {rank}: GPU {gpu}] Total CUDA memory{label}: {total_cuda_mem} Gb")
    print(f"[RANK {rank}: GPU {gpu}] Free CUDA memory{label}: {free_cuda_mem} Gb")
    if show_used:
        print(
            f"[RANK {rank}: GPU {gpu}] Used CUDA memory{label}: "
            f"{total_cuda_mem - free_cuda_mem} Gb"
        )


def _backward(loss, scaler):
    """Backpropagate ``loss``, through the AMP ``scaler`` when one is given."""
    if scaler is None:
        loss.backward()
    else:
        scaler.scale(loss).backward()


def main_worker(gpu, ngpus_per_node, args):
    """Train MedSAM on one GPU as one rank of a DistributedDataParallel job.

    Started once per local GPU by ``main`` through ``mp.spawn``. Global rank 0
    creates ``model_save_path`` and writes checkpoints and the loss plot there.

    Args:
        gpu: local GPU index on this node (the process index given by mp.spawn).
        ngpus_per_node: number of processes spawned on this node.
        args: parsed command-line arguments.
    """
    node_rank = int(args.node_rank)
    rank = node_rank * ngpus_per_node + gpu
    world_size = args.world_size
    print(f"[Rank {rank}]: Use GPU: {gpu} for training")
    is_main_host = rank == 0
    if is_main_host:
        os.makedirs(model_save_path, exist_ok=True)
        shutil.copyfile(
            __file__, join(model_save_path, run_id + "_" + os.path.basename(__file__))
        )
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(
        backend="nccl", init_method=args.init_method, rank=rank, world_size=world_size
    )

    # %% model
    sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    medsam_model = MedSAM(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
    ).cuda()
    _report_cuda_mem(rank, gpu, " before DDP initialised")
    if rank % ngpus_per_node == 0:
        print("Before DDP initialization:")
        os.system("nvidia-smi")

    medsam_model = nn.parallel.DistributedDataParallel(
        medsam_model,
        device_ids=[gpu],
        output_device=gpu,
        gradient_as_bucket_view=True,
        find_unused_parameters=True,
        # Too large -> less overlap of communication with computation;
        # too small -> many small all-reduces.
        bucket_cap_mb=args.bucket_cap_mb,
    )
    _report_cuda_mem(rank, gpu, " after DDP initialised")
    if rank % ngpus_per_node == 0:
        print("After DDP initialization:")
        os.system("nvidia-smi")

    medsam_model.train()
    print(
        "Number of total parameters: ",
        sum(p.numel() for p in medsam_model.parameters()),
    )
    print(
        "Number of trainable parameters: ",
        sum(p.numel() for p in medsam_model.parameters() if p.requires_grad),
    )

    # %% optimiser and losses: only the image encoder and mask decoder are
    # optimised; the prompt encoder is frozen inside MedSAM.
    img_mask_encdec_params = list(
        medsam_model.module.image_encoder.parameters()
    ) + list(medsam_model.module.mask_decoder.parameters())
    optimizer = torch.optim.AdamW(
        img_mask_encdec_params, lr=args.lr, weight_decay=args.weight_decay
    )
    print(
        "Number of image encoder and mask decoder parameters: ",
        sum(p.numel() for p in img_mask_encdec_params if p.requires_grad),
    )
    seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")
    ce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    # %% data
    num_epochs = args.num_epochs
    grad_acc_steps = max(1, args.grad_acc_steps)  # guard against 0 / negative
    iter_num = 0
    losses = []
    best_loss = 1e10
    train_dataset = NpyDataset(args.tr_npy_path)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    print("Number of training samples: ", len(train_dataset))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # the DistributedSampler already shuffles
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    # %% optional resume
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print(rank, "=> loading checkpoint '{}'".format(args.resume))
            # Map the checkpoint onto this rank's GPU.
            loc = "cuda:{}".format(gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
            start_epoch = checkpoint["epoch"] + 1
            medsam_model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                rank,
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                ),
            )
        else:
            print(rank, "=> no checkpoint found at '{}'".format(args.resume))
    torch.distributed.barrier()

    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    if scaler is not None:
        print(f"[RANK {rank}: GPU {gpu}] Using AMP for training")

    # %% train
    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0
        num_batches = 0
        train_dataloader.sampler.set_epoch(epoch)
        # Start each epoch with clean gradients. Micro-batches left over from an
        # incomplete accumulation window were accumulated under no_sync (never
        # all-reduced), so they are discarded rather than stepped on.
        optimizer.zero_grad()
        for step, (image, gt2D, boxes, _) in enumerate(
            tqdm(train_dataloader, desc=f"[RANK {rank}: GPU {gpu}]")
        ):
            boxes_np = boxes.detach().cpu().numpy()
            image, gt2D = image.cuda(), gt2D.cuda()
            # With enabled=False autocast is inactive, so the fp32 path is unchanged.
            with torch.autocast(
                device_type="cuda", dtype=torch.float16, enabled=args.use_amp
            ):
                medsam_pred = medsam_model(image, boxes_np)
                loss = seg_loss(medsam_pred, gt2D) + ce_loss(
                    medsam_pred, gt2D.float()
                )

            # Normalise so the gradient summed over the window is a mean.
            scaled_loss = loss / grad_acc_steps
            if (step + 1) % grad_acc_steps == 0:
                # Last micro-batch of the window: this backward all-reduces
                # gradients across ranks, then the optimiser steps.
                _backward(scaled_loss, scaler)
                if scaler is None:
                    optimizer.step()
                else:
                    scaler.step(optimizer)
                    scaler.update()
                # Zero only after a step; zeroing at the start of every step
                # (as before) wiped the accumulated gradients.
                optimizer.zero_grad()
            else:
                # Accumulate gradients locally without the DDP all-reduce.
                with medsam_model.no_sync():
                    _backward(scaled_loss, scaler)

            if step > 10 and step % 100 == 0 and is_main_host:
                checkpoint = {
                    "model": medsam_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }
                torch.save(
                    checkpoint,
                    join(model_save_path, "medsam_model_latest_step.pth"),
                )
            # Track the unscaled loss so epoch losses do not depend on grad_acc_steps.
            epoch_loss += loss.item()
            iter_num += 1
            num_batches += 1

        # CUDA memory usage once per epoch
        print("\n")
        _report_cuda_mem(rank, gpu, show_used=True)
        print("\n")

        # Mean over batches (previously divided by the last 0-based step index,
        # which was off by one and crashed with a single batch).
        epoch_loss /= max(num_batches, 1)
        losses.append(epoch_loss)
        if args.use_wandb:
            wandb.log({"epoch_loss": epoch_loss})
        print(
            f'Time: {datetime.now().strftime("%Y%m%d-%H%M")}, Epoch: {epoch}, Loss: {epoch_loss}'
        )
        # Only rank 0 writes files: it is the only rank that created model_save_path.
        if is_main_host:
            checkpoint = {
                "model": medsam_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
            }
            torch.save(checkpoint, join(model_save_path, "medsam_model_latest.pth"))
            # Best model is judged on rank 0's local epoch loss.
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(checkpoint, join(model_save_path, "medsam_model_best.pth"))
            # %% plot loss
            plt.plot(losses)
            plt.title("Dice + Cross Entropy Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            # plt.show() # comment this line if you are running on a server
            plt.savefig(join(model_save_path, args.task_name + "train_loss.png"))
            plt.close()
        torch.distributed.barrier()

    torch.distributed.destroy_process_group()
# Script entry point. main() uses mp.spawn, whose worker processes re-import this
# module rather than running it as "__main__", so this guard keeps them from
# calling main() again.
if __name__ == "__main__":
    main()