add tutorial on 3D dataset
JunMa11 authored May 1, 2023
1 parent 21d0cab commit b5cd094
Showing 1 changed file with 363 additions and 0 deletions.
363 changes: 363 additions & 0 deletions finetune_and_inference_tutorial_3D_dataset.ipynb
@@ -0,0 +1,363 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune SAM on customized datasets (3D example)\n",
"1. Prepare original 3D images `data/FLARE22Train/` (Download link:https://zenodo.org/record/7860267) \n",
"2. Run `pre_CT.py` for pre-processing. Expected output: `./data/Npz_files/CT_Abd-Gallbladder_`\n",
"3. Start this fine-tuning tutorial"
]
},
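{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch, assuming the default output folder used later in this tutorial): each pre-processed training `.npz` file is expected to provide `img_embeddings` and `gts` (the test files used later in this notebook also contain the raw `imgs`). The next cell lists the files and prints the array shapes so you can confirm the pre-processing step worked."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% optional sanity check of the pre-processed npz files (sketch; adjust the path if you changed the pre_CT.py output folder)\n",
"import os\n",
"import numpy as np\n",
"npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train' # assumed default output of pre_CT.py\n",
"npz_names = sorted(os.listdir(npz_tr_path))\n",
"print(f'Found {len(npz_names)} training npz files')\n",
"demo_npz = np.load(os.path.join(npz_tr_path, npz_names[0]))\n",
"for key in demo_npz.files:\n",
"    print(key, demo_npz[key].shape, demo_npz[key].dtype)"
]
},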
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% set up environment\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"join = os.path.join\n",
"from tqdm import tqdm\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import monai\n",
"from segment_anything import SamPredictor, sam_model_registry\n",
"from segment_anything.utils.transforms import ResizeLongestSide\n",
"from utils.SurfaceDice import compute_dice_coefficient\n",
"# set seeds\n",
"torch.manual_seed(2023)\n",
"np.random.seed(2023)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%% create a dataset class to load npz data and return back image embeddings and ground truth\n",
"class NpzDataset(Dataset): \n",
" def __init__(self, data_root):\n",
" self.data_root = data_root\n",
" self.npz_files = sorted(os.listdir(self.data_root)) \n",
" self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files]\n",
" # this implementation is ugly but it works (and is also fast for feeding data to GPU) if your server has enough RAM\n",
" # as an alternative, you can also use a list of npy files and load them one by one\n",
" self.ori_gts = np.vstack([d['gts'] for d in self.npz_data])\n",
" self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data])\n",
" print(f\"{self.img_embeddings.shape=}, {self.ori_gts.shape=}\")\n",
" \n",
" def __len__(self):\n",
" return self.ori_gts.shape[0]\n",
"\n",
" def __getitem__(self, index):\n",
" img_embed = self.img_embeddings[index]\n",
" gt2D = self.ori_gts[index]\n",
" y_indices, x_indices = np.where(gt2D > 0)\n",
" x_min, x_max = np.min(x_indices), np.max(x_indices)\n",
" y_min, y_max = np.min(y_indices), np.max(y_indices)\n",
" # add perturbation to bounding box coordinates\n",
" H, W = gt2D.shape\n",
" x_min = max(0, x_min - np.random.randint(0, 20))\n",
" x_max = min(W, x_max + np.random.randint(0, 20))\n",
" y_min = max(0, y_min - np.random.randint(0, 20))\n",
" y_max = min(H, y_max + np.random.randint(0, 20))\n",
" bboxes = np.array([x_min, y_min, x_max, y_max])\n",
" # convert img embedding, mask, bounding box to torch tensor\n",
" return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float()\n"
]
},
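{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As noted in the comment above, keeping every embedding in RAM may not be feasible on every machine. Below is a rough, hypothetical sketch of the lazy-loading alternative: it assumes each slice has been exported as a pair of per-slice `.npy` files (one for the image embedding, one for the ground-truth mask) with matching file names, which `pre_CT.py` does not produce by default, so treat it as a starting point rather than a drop-in replacement."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% hypothetical lazy-loading dataset (sketch): assumes per-slice npy pairs, which are NOT produced by pre_CT.py by default\n",
"class NpyLazyDataset(Dataset):\n",
"    def __init__(self, embed_dir, gt_dir):\n",
"        self.embed_dir, self.gt_dir = embed_dir, gt_dir\n",
"        self.names = sorted(os.listdir(embed_dir)) # one npy file per slice, same names in both folders\n",
"\n",
"    def __len__(self):\n",
"        return len(self.names)\n",
"\n",
"    def __getitem__(self, index):\n",
"        img_embed = np.load(join(self.embed_dir, self.names[index])) # (256, 64, 64)\n",
"        gt2D = np.load(join(self.gt_dir, self.names[index])) # (256, 256)\n",
"        # the bounding-box perturbation would be identical to NpzDataset.__getitem__ above;\n",
"        # a tight box is used here only to keep the sketch short\n",
"        y_indices, x_indices = np.where(gt2D > 0)\n",
"        bboxes = np.array([x_indices.min(), y_indices.min(), x_indices.max(), y_indices.max()])\n",
"        return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(bboxes).float()"
]
},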
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% test dataset class and dataloader\n",
"npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n",
"demo_dataset = NpzDataset(npz_tr_path)\n",
"demo_dataloader = DataLoader(demo_dataset, batch_size=8, shuffle=True)\n",
"for img_embed, gt2D, bboxes in demo_dataloader:\n",
" # img_embed: (B, 256, 64, 64), gt2D: (B, 1, 256, 256), bboxes: (B, 4)\n",
" print(f\"{img_embed.shape=}, {gt2D.shape=}, {bboxes.shape=}\")\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% set up model for fine-tuning \n",
"# train data path\n",
"npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n",
"work_dir = './work_dir'\n",
"task_name = 'CT_Abd-Gallbladder'\n",
"# prepare SAM model\n",
"model_type = 'vit_b'\n",
"checkpoint = 'work_dir/SAM/sam_vit_b_01ec64.pth'\n",
"device = 'cuda:0'\n",
"model_save_path = join(work_dir, task_name)\n",
"os.makedirs(model_save_path, exist_ok=True)\n",
"sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n",
"sam_model.train()\n",
"\n",
"# Set up the optimizer, hyperparameter tuning will improve performance here\n",
"optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)\n",
"seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')\n"
]
},
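{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As the comment above notes, these optimizer settings are a reasonable default rather than a tuned choice. One possible variation (an untested sketch, not the configuration used in this tutorial) is AdamW with a small weight decay plus a cosine learning-rate schedule; if you enable it, also call `scheduler.step()` once per epoch in the training loop below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% optional: an alternative optimizer setup (untested sketch; uncomment and tune for your own data)\n",
"# optimizer = torch.optim.AdamW(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=1e-4)\n",
"# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) # T_max = num_epochs\n",
"# remember to call scheduler.step() once per epoch in the training loop below"
]
},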
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%% train\n",
"num_epochs = 100\n",
"losses = []\n",
"best_loss = 1e10\n",
"train_dataset = NpzDataset(npz_tr_path)\n",
"train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"for epoch in range(num_epochs):\n",
" epoch_loss = 0\n",
" # train\n",
" for step, (image_embedding, gt2D, boxes) in enumerate(tqdm(train_dataloader)):\n",
" # do not compute gradients for image encoder and prompt encoder\n",
" with torch.no_grad():\n",
" # convert box to 1024x1024 grid\n",
" box_np = boxes.numpy()\n",
" sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)\n",
" box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))\n",
" box_torch = torch.as_tensor(box, dtype=torch.float, device=device)\n",
" if len(box_torch.shape) == 2:\n",
" box_torch = box_torch[:, None, :] # (B, 1, 4)\n",
" # get prompt embeddings \n",
" sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n",
" points=None,\n",
" boxes=box_torch,\n",
" masks=None,\n",
" )\n",
" # predicted masks\n",
" mask_predictions, _ = sam_model.mask_decoder(\n",
" image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)\n",
" image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n",
" sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n",
" dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n",
" multimask_output=False,\n",
" )\n",
"\n",
" loss = seg_loss(mask_predictions, gt2D.to(device))\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_loss += loss.item()\n",
" \n",
" epoch_loss /= step\n",
" losses.append(epoch_loss)\n",
" print(f'EPOCH: {epoch}, Loss: {epoch_loss}')\n",
" # save the latest model checkpoint\n",
" torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_latest.pth'))\n",
" # save the best model\n",
" if epoch_loss < best_loss:\n",
" best_loss = epoch_loss\n",
" torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_best.pth'))\n"
]
},
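{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The best checkpoint is written to `work_dir/CT_Abd-Gallbladder/sam_model_best.pth`. If you restart the kernel later, you can restore it into a fresh SAM model before running the comparison below (a minimal sketch; the following cells keep using the in-memory `sam_model`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% optional: reload the fine-tuned weights in a new session (sketch; uncomment if you restarted the kernel)\n",
"# sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n",
"# sam_model.load_state_dict(torch.load(join(model_save_path, 'sam_model_best.pth'), map_location=device))\n",
"# sam_model.eval() # use .train() instead if you want to continue fine-tuning"
]
},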
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot loss\n",
"plt.plot(losses)\n",
"plt.title('Dice + Cross Entropy Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.show() # comment this line if you are running on a server\n",
"plt.savefig(join(model_save_path, 'train_loss.png'))\n",
"plt.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%% compare the segmentation results between the original SAM model and the fine-tuned model\n",
"# load the original SAM model\n",
"ori_sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n",
"ori_sam_predictor = SamPredictor(ori_sam_model)\n",
"npz_ts_path = 'data/Npz_files/CT_Abd-Gallbladder/test'\n",
"test_npzs = sorted(os.listdir(npz_ts_path))\n",
"# random select a test case\n",
"npz_idx = np.random.randint(0, len(test_npzs))\n",
"npz = np.load(join(npz_ts_path, test_npzs[npz_idx]))\n",
"imgs = npz['imgs']\n",
"gts = npz['gts']\n",
"\n",
"def get_bbox_from_mask(mask):\n",
" '''Returns a bounding box from a mask'''\n",
" y_indices, x_indices = np.where(mask > 0)\n",
" x_min, x_max = np.min(x_indices), np.max(x_indices)\n",
" y_min, y_max = np.min(y_indices), np.max(y_indices)\n",
" # add perturbation to bounding box coordinates\n",
" H, W = mask.shape\n",
" x_min = max(0, x_min - np.random.randint(0, 20))\n",
" x_max = min(W, x_max + np.random.randint(0, 20))\n",
" y_min = max(0, y_min - np.random.randint(0, 20))\n",
" y_max = min(H, y_max + np.random.randint(0, 20))\n",
"\n",
" return np.array([x_min, y_min, x_max, y_max])\n",
"\n",
"ori_sam_segs = []\n",
"medsam_segs = []\n",
"bboxes = []\n",
"for img, gt in zip(imgs, gts):\n",
" bbox = get_bbox_from_mask(gt)\n",
" bboxes.append(bbox)\n",
" # predict the segmentation mask using the original SAM model\n",
" ori_sam_predictor.set_image(img)\n",
" ori_sam_seg, _, _ = ori_sam_predictor.predict(point_coords=None, box=bbox, multimask_output=False)\n",
" ori_sam_segs.append(ori_sam_seg[0])\n",
" \n",
" # predict the segmentation mask using the fine-tuned model\n",
" H, W = img.shape[:2]\n",
" resize_img = sam_trans.apply_image(img)\n",
" resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)\n",
" input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024)\n",
" with torch.no_grad():\n",
" image_embedding = sam_model.image_encoder(input_image.to(device)) # (1, 256, 64, 64)\n",
" # convert box to 1024x1024 grid\n",
" bbox = sam_trans.apply_boxes(bbox, (H, W))\n",
" box_torch = torch.as_tensor(bbox, dtype=torch.float, device=device)\n",
" if len(box_torch.shape) == 2:\n",
" box_torch = box_torch[:, None, :] # (B, 1, 4)\n",
" \n",
" sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n",
" points=None,\n",
" boxes=box_torch,\n",
" masks=None,\n",
" )\n",
" medsam_seg_prob, _ = sam_model.mask_decoder(\n",
" image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)\n",
" image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n",
" sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n",
" dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n",
" multimask_output=False,\n",
" )\n",
" medsam_seg_prob = torch.sigmoid(medsam_seg_prob)\n",
" # convert soft mask to hard mask\n",
" medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()\n",
" medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)\n",
" medsam_segs.append(medsam_seg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%% compute the DSC score\n",
"ori_sam_segs = np.stack(ori_sam_segs, axis=0)\n",
"medsam_segs = np.stack(medsam_segs, axis=0)\n",
"ori_sam_dsc = compute_dice_coefficient(gts>0, ori_sam_segs>0)\n",
"medsam_dsc = compute_dice_coefficient(gts>0, medsam_segs>0)\n",
"print('Original SAM DSC: {:.4f}'.format(ori_sam_dsc), 'MedSAM DSC: {:.4f}'.format(medsam_dsc))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%% visualize the segmentation results of the middle slice\n",
"# visualization functions\n",
"# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\n",
"# change color to avoid red and green\n",
"def show_mask(mask, ax, random_color=False):\n",
" if random_color:\n",
" color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
" else:\n",
" color = np.array([251/255, 252/255, 30/255, 0.6])\n",
" h, w = mask.shape[-2:]\n",
" mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
" ax.imshow(mask_image)\n",
" \n",
"def show_box(box, ax):\n",
" x0, y0 = box[0], box[1]\n",
" w, h = box[2] - box[0], box[3] - box[1]\n",
" ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) \n",
"\n",
"\n",
"img_id = int(imgs.shape[0]/2) # np.random.randint(imgs.shape[0])\n",
"_, axs = plt.subplots(1, 3, figsize=(25, 25))\n",
"axs[0].imshow(imgs[img_id])\n",
"show_mask(gts[img_id], axs[0])\n",
"# show_box(box_np[img_id], axs[0])\n",
"# axs[0].set_title('Mask with Tuned Model', fontsize=20)\n",
"axs[0].axis('off')\n",
"\n",
"axs[1].imshow(imgs[img_id])\n",
"show_mask(ori_sam_segs[img_id], axs[1])\n",
"show_box(bboxes[img_id], axs[1])\n",
"# add text to image to show dice score\n",
"axs[1].text(0.5, 0.5, 'SAM DSC: {:.4f}'.format(ori_sam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n",
"# axs[1].set_title('Mask with Untuned Model', fontsize=20)\n",
"axs[1].axis('off')\n",
"\n",
"axs[2].imshow(imgs[img_id])\n",
"show_mask(medsam_segs[img_id], axs[2])\n",
"show_box(bboxes[img_id], axs[2])\n",
"# add text to image to show dice score\n",
"axs[2].text(0.5, 0.5, 'MedSAM DSC: {:.4f}'.format(medsam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n",
"# axs[2].set_title('Ground Truth', fontsize=20)\n",
"axs[2].axis('off')\n",
"plt.show() \n",
"plt.subplots_adjust(wspace=0.01, hspace=0)\n",
"# save plot\n",
"# plt.savefig(join(model_save_path, test_npzs[npz_idx].split('.npz')[0] + str(img_id).zfill(3) + '.png'), bbox_inches='tight', dpi=300)\n",
"plt.close()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "medsam-demo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
