Skip to content

Commit

Permalink
generating masks tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
reddiedev committed May 23, 2023
1 parent 8c25cd2 commit 915039a
Showing 1 changed file with 232 additions and 20 deletions.
252 changes: 232 additions & 20 deletions scripts/obj-detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,9 +51,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[setup]: determining CUDA support...\n",
"PyTorch version: 2.0.1+cu118\n",
"Torchvision version: 0.15.2+cu118\n",
"CUDA is available: True\n"
]
}
],
"source": [
"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n",
Expand All @@ -65,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -97,9 +108,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You have selected MANUAL_INPUT, please enter the image URLs in the following block\n"
]
}
],
"source": [
"def get_valid_input():\n",
" while True:\n",
Expand All @@ -118,20 +137,39 @@
" return number\n",
" print(\"Invalid input. Please enter an integer from 1 to {}.\".format(n))\n",
"\n",
"input_type = get_valid_input()\n",
"# input_type = get_valid_input()\n",
"input_type = 1\n",
"\n",
"if input_type == 1:\n",
" input_mode = \"manual\"\n",
" print(\"You have selected MANUAL_INPUT, please enter the image URLs in the following block\")\n",
"elif input_type==2:\n",
" input_mode = \"random\"\n",
" input_image_count = get_valid_number(10)\n",
" print(f\"You have selected RANDOM_INPUT of {input_image_count} images from the CoCo 2017 dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]: loading coco annotations and captions...\n",
"loading annotations into memory...\n",
"Done (t=0.64s)\n",
"creating index...\n",
"index created!\n",
"loading annotations into memory...\n",
"Done (t=0.05s)\n",
"creating index...\n",
"index created!\n"
]
}
],
"source": [
"\n",
"print(\"[0]: loading coco annotations and captions...\")\n",
Expand Down Expand Up @@ -159,13 +197,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"input_image_IDs = []\n",
"input_image_catIDs = []\n",
"input_image_URLs = []\n",
"input_images_links = []\n",
"input_image_areas = []\n",
"input_image_labels = []\n",
"\n",
Expand All @@ -186,24 +224,44 @@
" ground_truth_labels = list(map(lambda x: x['caption'], anns))\n",
" input_image_IDs.append(imgID)\n",
" input_image_catIDs.append(catID)\n",
" input_image_URLs.append(imgURL)\n",
" input_images_links.append(imgURL)\n",
" input_image_areas.append(imgArea)\n",
" input_image_labels.append(ground_truth_labels)\n",
"\n",
"random_images = get_random_coco_image(input_image_count) \n",
"if input_type == 2:\n",
" random_images = get_random_coco_image(input_image_count) \n",
"\n",
"## MODIFY FOR MANUAL INPUT\n",
"if input_type == 1:\n",
" input_images_links = []"
" input_images_links = ['../images/dog_car.jpg']"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"\n",
"coco_results = []\n",
"label_results = []\n",
"top_one_scores = []\n",
"top_five_scores = []"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]: loading sam model\n"
]
}
],
"source": [
"print(\"[1]: loading sam model\")\n",
"sam_checkpoint = os.path.join(\"../checkpoints\", \"sam_vit_h_4b8939.pth\")\n",
"model_type = \"vit_h\"\n",
Expand All @@ -225,11 +283,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2]: creating open clip model...\n",
"[2]: loading coco categories as labels...\n"
]
}
],
"source": [
"\n",
"print(\"[2]: creating open clip model...\")\n",
"modelType = 'ViT-B-32-quickgelu'\n",
"modelDataset = \"laion400m_e31\"\n",
Expand All @@ -242,7 +308,153 @@
"\n",
"print(\"[2]: loading coco categories as labels...\")\n",
"text = tokenizer(coco_labels_words_values)\n",
"text = text.to(device)\n"
"text = text.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def show_anns(anns):\n",
" if len(anns) == 0:\n",
" return\n",
" sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n",
" ax = plt.gca()\n",
" ax.set_autoscale_on(False)\n",
" for ann in sorted_anns:\n",
" m = ann['segmentation']\n",
" img = np.ones((m.shape[0], m.shape[1], 3))\n",
" color_mask = np.random.random((1, 3)).tolist()[0]\n",
" for i in range(3):\n",
" img[:, :, i] = color_mask[i]\n",
" ax.imshow(np.dstack((img, m*0.35)))\n",
"\n",
"def generate_masks(image,image_index,raw_masks, area):\n",
" if len(raw_masks) == 0:\n",
" return\n",
" length = len(raw_masks)\n",
" sorted_anns = sorted(raw_masks, key=(lambda x: x['area']), reverse=True)\n",
" counter = 1\n",
" filtered_masks = []\n",
" for i in range(length):\n",
" mask = sorted_anns[i]\n",
" if (mask['area'] < area):\n",
" continue\n",
" x, y, w, h = mask['bbox']\n",
" x, y, w, h = int(x), int(y), int(w), int(h)\n",
" im = image[y:y+h, x:x+w]\n",
" plt.figure(figsize=(20, 20))\n",
" plt.imshow(im)\n",
" plt.axis('off')\n",
" plt.savefig(f\"../output/{image_index}/mask-{counter}.jpg\",\n",
" bbox_inches='tight', pad_inches=0)\n",
" plt.close()\n",
" counter += 1\n",
" filtered_masks.append(mask)\n",
" return filtered_masks\n",
"\n",
"\n",
"def generate_labels(anns, imgID, catID):\n",
" if len(anns) == 0:\n",
" return\n",
" length = len(anns)\n",
" values = []\n",
" for i in range(length):\n",
" mask = anns[i]\n",
" im = Image.open(f\"../output/{imgID}/mask-{i+1}.jpg\").convert(\"RGB\")\n",
" img = preprocess(im).unsqueeze(0)\n",
" img = img.to(device)\n",
"\n",
" with torch.no_grad(), torch.cuda.amp.autocast():\n",
" image_features = model.encode_image(img)\n",
" text_features = model.encode_text(text)\n",
" image_features /= image_features.norm(dim=-1, keepdim=True)\n",
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
" text_probs = (100.0 * image_features @\n",
" text_features.T).softmax(dim=-1)\n",
"\n",
" text_prob = np.max(text_probs.cpu().numpy())\n",
" index = np.argmax(text_probs.cpu().numpy())\n",
" label = coco_labels_words_values[index]\n",
"\n",
" print(f\"[{i+1}/{length}]:\", label, f\"({text_prob*100:.2f}%)\",)\n",
" values.append(\n",
" {\"label\": label, \"area\": mask[\"area\"], \"prob\": text_prob})\n",
" result = {'image_id': imgID, 'category_id': catID,\n",
" \"bbox\": mask['bbox'], \"score\": mask['predicted_iou']}\n",
" coco_results.append(result)\n",
"\n",
" # generate top 5 labels according to label_accuracy and mask_area\n",
" sorted_values = sorted(values, key=lambda x: x['prob'])\n",
" payload = sorted_values[:5]\n",
" labels = list(map(lambda d: d['label'], payload))\n",
" return labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for image_index in range(len(input_images_links)):\n",
" image_link = input_images_links[image_index]\n",
" # create output folder for this image\n",
" \n",
" image_folder_path = f\"../output/{image_index}\"\n",
" if not os.path.exists(image_folder_path):\n",
" os.mkdir(image_folder_path)\n",
" \n",
" image = io.imread(image_link)\n",
" height, width = image.shape[:2]\n",
" imageArea = height * width\n",
" maskArea = 0.04 * imageArea\n",
" \n",
" plt.figure(figsize=(10, 10))\n",
" plt.imshow(image)\n",
" plt.axis('off')\n",
" plt.savefig(f\"../output/{image_index}/source.jpg\",\n",
" bbox_inches='tight', pad_inches=0)\n",
" plt.show()\n",
" plt.close()\n",
" \n",
" raw_masks = mask_generator.generate(image)\n",
" print(f\"({image_index}): generated {len(raw_masks)} masks...\")\n",
" filtered_masks = generate_masks(image, image_index, raw_masks, maskArea)\n",
" print(f\"({image_index}): masks filtered down to {len(filtered_masks)} masks...\")\n",
"\n",
" plt.figure(figsize=(10, 10))\n",
" plt.imshow(image)\n",
" show_anns(filtered_masks)\n",
" plt.axis('off')\n",
" plt.savefig(f\"../output/{image_index}/generated-masks.jpg\",\n",
" bbox_inches='tight', pad_inches=0)\n",
" plt.show()\n",
" plt.close()\n",
" \n",
" source_image = io.imread(f\"../output/{image_index}/source.jpg\")\n",
" fig, axs = plt.subplots(1, 2)\n",
" axs[0].imshow(source_image, cmap='gray')\n",
" axs[0].set_title('Source Image')\n",
" masked_image = io.imread(f\"../output/{image_index}/generated-masks.jpg\")\n",
" axs[1].imshow(masked_image, cmap='gray')\n",
" axs[1].set_title('Image with Masks')\n",
" plt.subplots_adjust(wspace=0.4)\n",
" plt.show()\n",
"\n",
" \n",
" # for each mask image, annotate using open-clip\n",
" print(f\"({image_index}): generating labels...\")\n",
" generated_labels = generate_labels(filtered_masks, imgID, catID)\n",
" print(\"GENERATED LABELS\")\n",
" pprint(generated_labels)\n",
" print(\"GROUND TRUTH LABELS\")\n",
" pprint(ground_truth_labels)\n",
" ground_truth_string = \" \".join(ground_truth_labels)\n",
" \n",
" "
]
}
],
Expand Down

0 comments on commit 915039a

Please sign in to comment.