forked from bowang-lab/MedSAM
Showing 1 changed file with 363 additions and 0 deletions.
@@ -0,0 +1,363 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Fine-tune SAM on customized datasets (3D example)\n", | ||
"1. Prepare original 3D images `data/FLARE22Train/` (Download link:https://zenodo.org/record/7860267) \n", | ||
"2. Run `pre_CT.py` for pre-processing. Expected output: `./data/Npz_files/CT_Abd-Gallbladder_`\n", | ||
"3. Start this fine-tuning tutorial" | ||
] | ||
}, | ||
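{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Optional sanity check (not part of the original steps above): before fine-tuning, it can help to confirm that `pre_CT.py` produced the expected `.npz` files. The cell below is a minimal sketch that only lists the first few files and the array keys/shapes they contain; the path mirrors the training path used later in this notebook." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% optional sanity check: inspect the pre-processed npz files (a minimal sketch; assumes pre_CT.py has been run)\n", | ||
"import os\n", | ||
"import numpy as np\n", | ||
"npz_dir = 'data/Npz_files/CT_Abd-Gallbladder/train'  # same path as in the training cells below\n", | ||
"for f in sorted(os.listdir(npz_dir))[:3]:  # look at the first few files only\n", | ||
"    with np.load(os.path.join(npz_dir, f)) as d:\n", | ||
"        print(f, {k: d[k].shape for k in d.files})" | ||
] | ||
}, | ||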
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% set up environment\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import os\n", | ||
"join = os.path.join\n", | ||
"from tqdm import tqdm\n", | ||
"import torch\n", | ||
"from torch.utils.data import Dataset, DataLoader\n", | ||
"import monai\n", | ||
"from segment_anything import SamPredictor, sam_model_registry\n", | ||
"from segment_anything.utils.transforms import ResizeLongestSide\n", | ||
"from utils.SurfaceDice import compute_dice_coefficient\n", | ||
"# set seeds\n", | ||
"torch.manual_seed(2023)\n", | ||
"np.random.seed(2023)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% create a dataset class to load npz data and return back image embeddings and ground truth\n", | ||
"class NpzDataset(Dataset): \n", | ||
" def __init__(self, data_root):\n", | ||
" self.data_root = data_root\n", | ||
" self.npz_files = sorted(os.listdir(self.data_root)) \n", | ||
" self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files]\n", | ||
" # this implementation is ugly but it works (and is also fast for feeding data to GPU) if your server has enough RAM\n", | ||
" # as an alternative, you can also use a list of npy files and load them one by one\n", | ||
" self.ori_gts = np.vstack([d['gts'] for d in self.npz_data])\n", | ||
" self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data])\n", | ||
" print(f\"{self.img_embeddings.shape=}, {self.ori_gts.shape=}\")\n", | ||
" \n", | ||
" def __len__(self):\n", | ||
" return self.ori_gts.shape[0]\n", | ||
"\n", | ||
" def __getitem__(self, index):\n", | ||
" img_embed = self.img_embeddings[index]\n", | ||
" gt2D = self.ori_gts[index]\n", | ||
" y_indices, x_indices = np.where(gt2D > 0)\n", | ||
" x_min, x_max = np.min(x_indices), np.max(x_indices)\n", | ||
" y_min, y_max = np.min(y_indices), np.max(y_indices)\n", | ||
" # add perturbation to bounding box coordinates\n", | ||
" H, W = gt2D.shape\n", | ||
" x_min = max(0, x_min - np.random.randint(0, 20))\n", | ||
" x_max = min(W, x_max + np.random.randint(0, 20))\n", | ||
" y_min = max(0, y_min - np.random.randint(0, 20))\n", | ||
" y_max = min(H, y_max + np.random.randint(0, 20))\n", | ||
" bboxes = np.array([x_min, y_min, x_max, y_max])\n", | ||
" # convert img embedding, mask, bounding box to torch tensor\n", | ||
" return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float()\n" | ||
] | ||
}, | ||
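{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The comment in `NpzDataset` above mentions a lower-memory alternative: instead of stacking every array in RAM, each `.npz` file can be opened on demand in `__getitem__`. The cell below is a minimal sketch of that idea (the class name `LazyNpzDataset` is made up for this example); the bounding-box perturbation is copied from `NpzDataset`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% lower-memory alternative (sketch): load npz files on demand instead of stacking everything in RAM\n", | ||
"class LazyNpzDataset(Dataset):\n", | ||
"    def __init__(self, data_root):\n", | ||
"        self.data_root = data_root\n", | ||
"        self.npz_files = sorted(os.listdir(data_root))\n", | ||
"        # build a (file index, slice index) lookup table without keeping the arrays in memory\n", | ||
"        self.index = []\n", | ||
"        for fi, f in enumerate(self.npz_files):\n", | ||
"            with np.load(join(data_root, f)) as d:\n", | ||
"                self.index += [(fi, si) for si in range(d['gts'].shape[0])]\n", | ||
"\n", | ||
"    def __len__(self):\n", | ||
"        return len(self.index)\n", | ||
"\n", | ||
"    def __getitem__(self, idx):\n", | ||
"        fi, si = self.index[idx]\n", | ||
"        with np.load(join(self.data_root, self.npz_files[fi])) as d:\n", | ||
"            img_embed = d['img_embeddings'][si]\n", | ||
"            gt2D = d['gts'][si]\n", | ||
"        # same bounding-box perturbation as in NpzDataset above\n", | ||
"        y_indices, x_indices = np.where(gt2D > 0)\n", | ||
"        x_min, x_max = np.min(x_indices), np.max(x_indices)\n", | ||
"        y_min, y_max = np.min(y_indices), np.max(y_indices)\n", | ||
"        H, W = gt2D.shape\n", | ||
"        x_min = max(0, x_min - np.random.randint(0, 20))\n", | ||
"        x_max = min(W, x_max + np.random.randint(0, 20))\n", | ||
"        y_min = max(0, y_min - np.random.randint(0, 20))\n", | ||
"        y_max = min(H, y_max + np.random.randint(0, 20))\n", | ||
"        bboxes = np.array([x_min, y_min, x_max, y_max])\n", | ||
"        return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(bboxes).float()" | ||
] | ||
}, | ||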
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% test dataset class and dataloader\n", | ||
"npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n", | ||
"demo_dataset = NpzDataset(npz_tr_path)\n", | ||
"demo_dataloader = DataLoader(demo_dataset, batch_size=8, shuffle=True)\n", | ||
"for img_embed, gt2D, bboxes in demo_dataloader:\n", | ||
" # img_embed: (B, 256, 64, 64), gt2D: (B, 1, 256, 256), bboxes: (B, 4)\n", | ||
" print(f\"{img_embed.shape=}, {gt2D.shape=}, {bboxes.shape=}\")\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% set up model for fine-tuning \n", | ||
"# train data path\n", | ||
"npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n", | ||
"work_dir = './work_dir'\n", | ||
"task_name = 'CT_Abd-Gallbladder'\n", | ||
"# prepare SAM model\n", | ||
"model_type = 'vit_b'\n", | ||
"checkpoint = 'work_dir/SAM/sam_vit_b_01ec64.pth'\n", | ||
"device = 'cuda:0'\n", | ||
"model_save_path = join(work_dir, task_name)\n", | ||
"os.makedirs(model_save_path, exist_ok=True)\n", | ||
"sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n", | ||
"sam_model.train()\n", | ||
"\n", | ||
"# Set up the optimizer, hyperparameter tuning will improve performance here\n", | ||
"optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)\n", | ||
"seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')\n" | ||
] | ||
}, | ||
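{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Only `sam_model.mask_decoder.parameters()` is passed to the optimizer, so the image encoder and prompt encoder stay frozen in this tutorial. The optional cell below is a small sketch that prints how many parameters will actually be updated." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% optional: check how many parameters the optimizer will actually update\n", | ||
"decoder_params = sum(p.numel() for p in sam_model.mask_decoder.parameters())\n", | ||
"total_params = sum(p.numel() for p in sam_model.parameters())\n", | ||
"print(f\"trainable (mask decoder): {decoder_params:,} / total: {total_params:,}\")" | ||
] | ||
}, | ||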
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% train\n", | ||
"num_epochs = 100\n", | ||
"losses = []\n", | ||
"best_loss = 1e10\n", | ||
"train_dataset = NpzDataset(npz_tr_path)\n", | ||
"train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", | ||
"for epoch in range(num_epochs):\n", | ||
" epoch_loss = 0\n", | ||
" # train\n", | ||
" for step, (image_embedding, gt2D, boxes) in enumerate(tqdm(train_dataloader)):\n", | ||
" # do not compute gradients for image encoder and prompt encoder\n", | ||
" with torch.no_grad():\n", | ||
" # convert box to 1024x1024 grid\n", | ||
" box_np = boxes.numpy()\n", | ||
" sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)\n", | ||
" box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))\n", | ||
" box_torch = torch.as_tensor(box, dtype=torch.float, device=device)\n", | ||
" if len(box_torch.shape) == 2:\n", | ||
" box_torch = box_torch[:, None, :] # (B, 1, 4)\n", | ||
" # get prompt embeddings \n", | ||
" sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n", | ||
" points=None,\n", | ||
" boxes=box_torch,\n", | ||
" masks=None,\n", | ||
" )\n", | ||
" # predicted masks\n", | ||
" mask_predictions, _ = sam_model.mask_decoder(\n", | ||
" image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)\n", | ||
" image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n", | ||
" sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n", | ||
" dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n", | ||
" multimask_output=False,\n", | ||
" )\n", | ||
"\n", | ||
" loss = seg_loss(mask_predictions, gt2D.to(device))\n", | ||
" optimizer.zero_grad()\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
" epoch_loss += loss.item()\n", | ||
" \n", | ||
" epoch_loss /= step\n", | ||
" losses.append(epoch_loss)\n", | ||
" print(f'EPOCH: {epoch}, Loss: {epoch_loss}')\n", | ||
" # save the latest model checkpoint\n", | ||
" torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_latest.pth'))\n", | ||
" # save the best model\n", | ||
" if epoch_loss < best_loss:\n", | ||
" best_loss = epoch_loss\n", | ||
" torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_best.pth'))\n" | ||
] | ||
}, | ||
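{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"If training is interrupted, the checkpoints saved above can be reloaded before continuing. The cell below is a minimal sketch that assumes `sam_model_latest.pth` already exists in `model_save_path` (only the model weights are saved above, so the optimizer state starts fresh)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %% optional: resume fine-tuning from the latest checkpoint saved above (assumes the file already exists)\n", | ||
"ckpt_path = join(model_save_path, 'sam_model_latest.pth')\n", | ||
"sam_model.load_state_dict(torch.load(ckpt_path, map_location=device))\n", | ||
"sam_model.train()\n", | ||
"print(f'loaded checkpoint from {ckpt_path}')" | ||
] | ||
}, | ||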
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# plot loss\n", | ||
"plt.plot(losses)\n", | ||
"plt.title('Dice + Cross Entropy Loss')\n", | ||
"plt.xlabel('Epoch')\n", | ||
"plt.ylabel('Loss')\n", | ||
"plt.show() # comment this line if you are running on a server\n", | ||
"plt.savefig(join(model_save_path, 'train_loss.png'))\n", | ||
"plt.close()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% compare the segmentation results between the original SAM model and the fine-tuned model\n", | ||
"# load the original SAM model\n", | ||
"ori_sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n", | ||
"ori_sam_predictor = SamPredictor(ori_sam_model)\n", | ||
"npz_ts_path = 'data/Npz_files/CT_Abd-Gallbladder/test'\n", | ||
"test_npzs = sorted(os.listdir(npz_ts_path))\n", | ||
"# random select a test case\n", | ||
"npz_idx = np.random.randint(0, len(test_npzs))\n", | ||
"npz = np.load(join(npz_ts_path, test_npzs[npz_idx]))\n", | ||
"imgs = npz['imgs']\n", | ||
"gts = npz['gts']\n", | ||
"\n", | ||
"def get_bbox_from_mask(mask):\n", | ||
" '''Returns a bounding box from a mask'''\n", | ||
" y_indices, x_indices = np.where(mask > 0)\n", | ||
" x_min, x_max = np.min(x_indices), np.max(x_indices)\n", | ||
" y_min, y_max = np.min(y_indices), np.max(y_indices)\n", | ||
" # add perturbation to bounding box coordinates\n", | ||
" H, W = mask.shape\n", | ||
" x_min = max(0, x_min - np.random.randint(0, 20))\n", | ||
" x_max = min(W, x_max + np.random.randint(0, 20))\n", | ||
" y_min = max(0, y_min - np.random.randint(0, 20))\n", | ||
" y_max = min(H, y_max + np.random.randint(0, 20))\n", | ||
"\n", | ||
" return np.array([x_min, y_min, x_max, y_max])\n", | ||
"\n", | ||
"ori_sam_segs = []\n", | ||
"medsam_segs = []\n", | ||
"bboxes = []\n", | ||
"for img, gt in zip(imgs, gts):\n", | ||
" bbox = get_bbox_from_mask(gt)\n", | ||
" bboxes.append(bbox)\n", | ||
" # predict the segmentation mask using the original SAM model\n", | ||
" ori_sam_predictor.set_image(img)\n", | ||
" ori_sam_seg, _, _ = ori_sam_predictor.predict(point_coords=None, box=bbox, multimask_output=False)\n", | ||
" ori_sam_segs.append(ori_sam_seg[0])\n", | ||
" \n", | ||
" # predict the segmentation mask using the fine-tuned model\n", | ||
" H, W = img.shape[:2]\n", | ||
" resize_img = sam_trans.apply_image(img)\n", | ||
" resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)\n", | ||
" input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024)\n", | ||
" with torch.no_grad():\n", | ||
" image_embedding = sam_model.image_encoder(input_image.to(device)) # (1, 256, 64, 64)\n", | ||
" # convert box to 1024x1024 grid\n", | ||
" bbox = sam_trans.apply_boxes(bbox, (H, W))\n", | ||
" box_torch = torch.as_tensor(bbox, dtype=torch.float, device=device)\n", | ||
" if len(box_torch.shape) == 2:\n", | ||
" box_torch = box_torch[:, None, :] # (B, 1, 4)\n", | ||
" \n", | ||
" sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n", | ||
" points=None,\n", | ||
" boxes=box_torch,\n", | ||
" masks=None,\n", | ||
" )\n", | ||
" medsam_seg_prob, _ = sam_model.mask_decoder(\n", | ||
" image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)\n", | ||
" image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n", | ||
" sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n", | ||
" dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n", | ||
" multimask_output=False,\n", | ||
" )\n", | ||
" medsam_seg_prob = torch.sigmoid(medsam_seg_prob)\n", | ||
" # convert soft mask to hard mask\n", | ||
" medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()\n", | ||
" medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)\n", | ||
" medsam_segs.append(medsam_seg)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% compute the DSC score\n", | ||
"ori_sam_segs = np.stack(ori_sam_segs, axis=0)\n", | ||
"medsam_segs = np.stack(medsam_segs, axis=0)\n", | ||
"ori_sam_dsc = compute_dice_coefficient(gts>0, ori_sam_segs>0)\n", | ||
"medsam_dsc = compute_dice_coefficient(gts>0, medsam_segs>0)\n", | ||
"print('Original SAM DSC: {:.4f}'.format(ori_sam_dsc), 'MedSAM DSC: {:.4f}'.format(medsam_dsc))\n" | ||
] | ||
}, | ||
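{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The DSC above is computed over the whole 3D stack. A per-slice breakdown can help spot slices where the box prompt fails; the cell below is a small sketch using the same `compute_dice_coefficient` helper (it assumes every exported slice contains the target organ, as the pre-processing step implies)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% optional: per-slice DSC for the fine-tuned model\n", | ||
"slice_dscs = [compute_dice_coefficient(g > 0, s > 0) for g, s in zip(gts, medsam_segs)]\n", | ||
"print('per-slice MedSAM DSC: min={:.4f}, median={:.4f}, max={:.4f}'.format(np.min(slice_dscs), np.median(slice_dscs), np.max(slice_dscs)))" | ||
] | ||
}, | ||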
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#%% visualize the segmentation results of the middle slice\n", | ||
"# visualization functions\n", | ||
"# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\n", | ||
"# change color to avoid red and green\n", | ||
"def show_mask(mask, ax, random_color=False):\n", | ||
" if random_color:\n", | ||
" color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", | ||
" else:\n", | ||
" color = np.array([251/255, 252/255, 30/255, 0.6])\n", | ||
" h, w = mask.shape[-2:]\n", | ||
" mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", | ||
" ax.imshow(mask_image)\n", | ||
" \n", | ||
"def show_box(box, ax):\n", | ||
" x0, y0 = box[0], box[1]\n", | ||
" w, h = box[2] - box[0], box[3] - box[1]\n", | ||
" ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) \n", | ||
"\n", | ||
"\n", | ||
"img_id = int(imgs.shape[0]/2) # np.random.randint(imgs.shape[0])\n", | ||
"_, axs = plt.subplots(1, 3, figsize=(25, 25))\n", | ||
"axs[0].imshow(imgs[img_id])\n", | ||
"show_mask(gts[img_id], axs[0])\n", | ||
"# show_box(box_np[img_id], axs[0])\n", | ||
"# axs[0].set_title('Mask with Tuned Model', fontsize=20)\n", | ||
"axs[0].axis('off')\n", | ||
"\n", | ||
"axs[1].imshow(imgs[img_id])\n", | ||
"show_mask(ori_sam_segs[img_id], axs[1])\n", | ||
"show_box(bboxes[img_id], axs[1])\n", | ||
"# add text to image to show dice score\n", | ||
"axs[1].text(0.5, 0.5, 'SAM DSC: {:.4f}'.format(ori_sam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n", | ||
"# axs[1].set_title('Mask with Untuned Model', fontsize=20)\n", | ||
"axs[1].axis('off')\n", | ||
"\n", | ||
"axs[2].imshow(imgs[img_id])\n", | ||
"show_mask(medsam_segs[img_id], axs[2])\n", | ||
"show_box(bboxes[img_id], axs[2])\n", | ||
"# add text to image to show dice score\n", | ||
"axs[2].text(0.5, 0.5, 'MedSAM DSC: {:.4f}'.format(medsam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n", | ||
"# axs[2].set_title('Ground Truth', fontsize=20)\n", | ||
"axs[2].axis('off')\n", | ||
"plt.show() \n", | ||
"plt.subplots_adjust(wspace=0.01, hspace=0)\n", | ||
"# save plot\n", | ||
"# plt.savefig(join(model_save_path, test_npzs[npz_idx].split('.npz')[0] + str(img_id).zfill(3) + '.png'), bbox_inches='tight', dpi=300)\n", | ||
"plt.close()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "medsam-demo", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |