Skip to content

Commit

Permalink
update API
Browse files Browse the repository at this point in the history
  • Loading branch information
rentainhe committed Sep 29, 2021
1 parent dc9d582 commit 3e6c0ec
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion debug.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from visualize import run_grid_attention_example
run_grid_attention_example()
run_grid_attention_example(version=2, quality=200)
24 changes: 18 additions & 6 deletions visualize/example/grid_attention_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@
import numpy as np


def run_grid_attention_example():
img_path = 'test_data/test_image.jpg'
random_attention = np.random.randn(14, 14)
save_path = 'test/'
visulize_grid_attention_v2(img_path=img_path, save_path=save_path, attention_mask=random_attention, save_image=True,
save_original_image=True)
def run_grid_attention_example(img_path="test_data/test_image.jpg", save_path="test/", attention_mask=None, version=2, quality=100):
if not attention_mask:
attention_mask = np.random.randn(14, 14)
assert version in [1, 2], "We only support two version of attention visualization example"
if version == 1:
visulize_grid_attention(img_path=img_path,
save_path=save_path,
attention_mask=attention_mask,
save_image=True,
save_original_image=True,
quality=quality)
elif version == 2:
visulize_grid_attention_v2(img_path=img_path,
save_path=save_path,
attention_mask=attention_mask,
save_image=True,
save_original_image=True,
quality=quality)


14 changes: 9 additions & 5 deletions visualize/visualize_attention_map/visualize_attention_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
import os


def visulize_grid_attention(img_path, save_path, attention_mask, save_image=True, save_original_image=True):
def visulize_grid_attention(img_path, save_path, attention_mask, ratio=0.5, save_image=True, save_original_image=True, quality=100):
"""
img_path: where to load the image
save_path: where to save the image
attention_mask: the 2-D attention mask on your image, e.g: np.array (h, w) or (w, h)
ratio: scaling factor to scale the output h and w
quality: save image quality
"""
print("load image from: " + img_path)
img = Image.open(img_path)
img_h, img_w = img.size[0], img.size[1]

# set the background
plt.subplots(nrows=1, ncols=1, figsize=(0.02 * img_h, 0.02 * img_w))

# scale the image
img_h, img_w = int(img.size[0] * ratio), int(img.size[1] * ratio)
img = img.resize((img_h, img_w))
plt.imshow(img, alpha=1)
plt.axis('off')

Expand All @@ -43,11 +47,11 @@ def visulize_grid_attention(img_path, save_path, attention_mask, save_image=True
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.savefig(img_with_attention_save_path, dpi=100)
plt.savefig(img_with_attention_save_path, dpi=quality)

# save original image
if save_original_image:
print("save original image at the same time")
img_name = img_path.split('/')[-1].split('.')[0] + "_original.jpg"
original_image_save_path = os.path.join(save_path, img_name)
img.save(original_image_save_path, quality=100)
img.save(original_image_save_path, quality=quality)
14 changes: 8 additions & 6 deletions visualize/visualize_attention_map/visualize_attention_map_V2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import os


def visulize_grid_attention_v2(img_path, save_path, attention_mask, ratio=0.5, cmap="jet", save_image=False,
save_original_image=False):
def visulize_grid_attention_v2(img_path, save_path, attention_mask, ratio=1, cmap="jet", save_image=False,
save_original_image=False, quality=200):
"""
img_path: image file path to load
save_path: image file path to save
attention_mask: 2-D attention map with np.array type, e.g, (h, w) or (w, h)
attention_mask: 2-D attention map with np.array type, e.g, (h, w) or (w, h)
ratio: scaling factor to scale the output h and w
cmap: attention style, default: "jet"
cmap: attention style, default: "jet"
quality: saved image quality
"""
print("load image from: ", img_path)
img = Image.open(img_path, mode='r')
Expand All @@ -29,6 +30,7 @@ def visulize_grid_attention_v2(img_path, save_path, attention_mask, ratio=0.5, c
plt.imshow(img, alpha=1)
plt.axis('off')

# normalize the attention map
mask = cv2.resize(attention_mask, (img_h, img_w))
normed_mask = mask / mask.max()
normed_mask = (normed_mask * 255).astype('uint8')
Expand All @@ -46,7 +48,7 @@ def visulize_grid_attention_v2(img_path, save_path, attention_mask, ratio=0.5, c
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.savefig(img_with_attention_save_path, dpi=100)
plt.savefig(img_with_attention_save_path, dpi=quality)

if save_original_image:
# build save path
Expand All @@ -57,4 +59,4 @@ def visulize_grid_attention_v2(img_path, save_path, attention_mask, ratio=0.5, c
print("save original image at the same time")
img_name = img_path.split('/')[-1].split('.')[0] + "_original.jpg"
original_image_save_path = os.path.join(save_path, img_name)
img.save(original_image_save_path, quality=100)
img.save(original_image_save_path, quality=quality)

0 comments on commit 3e6c0ec

Please sign in to comment.