Skip to content

Commit

Permalink
Merge branch 'master' into add_pose
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengLai authored Feb 7, 2021
2 parents 022e8c2 + 39c3159 commit 495b98b
Show file tree
Hide file tree
Showing 9 changed files with 781 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Welcome to TensorLayer

**Documentation Version:** |release|

**Jun 2020** `Deep Reinforcement Learning Book Is Coming <http://deepreinforcementlearningbook.org>`__.
**Jun 2020** `Deep Reinforcement Learning Book Is Released <http://deepreinforcementlearningbook.org>`__.

**Good News:** We won the **Best Open Source Software Award** `@ACM Multimedia (MM) 2017 <http://www.acmmm.org/2017/mm-2017-awardees/>`_.

Expand Down
4 changes: 4 additions & 0 deletions docs/modules/visualize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ to visualize the model, activations etc. Here we provide more functions for data
frame
images2d
tsne_embedding
draw_boxes_and_labels_to_image_with_json


Save and read images
Expand All @@ -44,6 +45,9 @@ Save image for object detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: draw_boxes_and_labels_to_image

Save image for object detection with json
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: draw_boxes_and_labels_to_image_with_json

Save image for pose estimation (MPII)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
80 changes: 80 additions & 0 deletions examples/app_tutorials/model/coco.names
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed
dining table
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
541 changes: 541 additions & 0 deletions examples/app_tutorials/model/yolov4_config.txt

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions examples/app_tutorials/tutorial_object_detection_yolov4_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Tutorial: detect objects in a single image with YOLOv4 and display the result."""

from tensorlayer.app import computer_vision
from tensorlayer import visualize
from tensorlayer.app.computer_vision_object_detection.common import read_class_names
import numpy as np
import cv2
from PIL import Image
INPUT_SIZE = 416
image_path = './data/kite.jpg'

class_names = read_class_names('./model/coco.names')

# cv2.imread returns None instead of raising when the file cannot be read,
# so fail here with a clear message rather than inside cvtColor.
original_image = cv2.imread(image_path)
if original_image is None:
    raise ValueError("Read Image Failed! Check the path: %s" % image_path)

# OpenCV loads images as BGR; convert once to RGB for both inference and display.
image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

net = computer_vision.object_detection('yolo4-mscoco')
# Feed the RGB image, the same channel order the video tutorial passes to the network.
json_result = net(image)
print(type(json_result))

# Draw the detections on the RGB image and show it with PIL.
image = visualize.draw_boxes_and_labels_to_image_with_json(image, json_result, class_names)
image = Image.fromarray(image.astype(np.uint8))
image.show()
38 changes: 38 additions & 0 deletions examples/app_tutorials/tutorial_object_detection_yolov4_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Tutorial: detect objects in a video with YOLOv4 and show the annotated frames."""

from tensorlayer.app import computer_vision
from tensorlayer import visualize
from tensorlayer.app.computer_vision_object_detection.common import read_class_names
import cv2
INPUT_SIZE = 416
video_path = './data/road.mp4'

class_names = read_class_names('./model/coco.names')
vid = cv2.VideoCapture(video_path)
# To read from a camera instead, pass the index of the camera on your device:
# vid = cv2.VideoCapture(0)

if not vid.isOpened():
    raise ValueError("Read Video Failed!")
net = computer_vision.object_detection('yolo4-mscoco')

# The window only needs to be created once, not on every frame.
cv2.namedWindow("result", cv2.WINDOW_AUTOSIZE)
frame_id = 0
try:
    while True:
        return_value, frame = vid.read()
        if not return_value:
            # A failed read after the last frame is the normal end of the video;
            # a failed read anywhere else means the stream could not be decoded.
            if frame_id == vid.get(cv2.CAP_PROP_FRAME_COUNT):
                print("Video processing complete")
                break
            raise ValueError("No image! Try with another video format")

        # OpenCV decodes frames as BGR; the network and drawing code work on RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        json_result = net(frame)
        image = visualize.draw_boxes_and_labels_to_image_with_json(frame, json_result, class_names)
        # Convert back to BGR because cv2.imshow expects BGR.
        result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        cv2.imshow("result", result)
        # Press 'q' to stop early.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        frame_id += 1
finally:
    # Always free the capture handle and close the display window,
    # including when an exception interrupts the loop.
    vid.release()
    cv2.destroyAllWindows()
28 changes: 28 additions & 0 deletions tensorlayer/app/computer_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __call__(self, input_data):
output = yolo4_output_processing(feature_maps)
elif self.model_name == 'lcn':
output = self.model(input_data)
pred_bbox = yolo4_output_processing(feature_maps)
output = result_to_json(input_data, pred_bbox)
else:
raise NotImplementedError

Expand Down Expand Up @@ -118,3 +120,29 @@ def yolo4_output_processing(feature_maps):
)
output = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
return output


def result_to_json(image, pred_bbox, class_names_path='model/coco.names'):
    """Convert post-processed detections into a list of JSON-style dicts.

    Parameters
    ----------
    image : numpy.array
        The image the detections refer to, [height, width, channel]. Only its
        shape is used, to scale the box coordinates to pixels.
    pred_bbox : sequence
        ``(boxes, scores, classes, valid_detections)``. Only index 0 of each
        entry is read. Each box is read as ``[y1, x1, y2, x2]``, where the y
        values are multiplied by the image height and the x values by the
        image width. The arrays are not modified.
    class_names_path : str
        Text file with one class name per line; the number of lines is the
        number of valid class IDs. Defaults to ``'model/coco.names'``.

    Returns
    -------
    list of dict
        One dict per kept detection with keys ``'image'`` (always None),
        ``'category_id'`` (int), ``'bbox'`` (``[x1, y1, x2, y2]`` in pixels,
        truncated to whole numbers) and ``'score'`` (float).

    """
    image_h, image_w, _ = image.shape
    out_boxes, out_scores, out_classes, num_boxes = pred_bbox

    # Only the number of classes is needed to validate the predicted IDs.
    with open(class_names_path, 'r') as data:
        nums_class = sum(1 for _ in data)

    json_result = []
    for i in range(int(num_boxes[0])):
        class_ind = int(out_classes[0][i])
        # Valid IDs are 0 .. nums_class - 1; an ID equal to nums_class is out of range.
        if class_ind < 0 or class_ind >= nums_class:
            continue

        # Scale into local variables so the caller's box array is left untouched.
        coor = out_boxes[0][i]
        y1 = int(coor[0] * image_h)
        x1 = int(coor[1] * image_w)
        y2 = int(coor[2] * image_h)
        x2 = int(coor[3] * image_w)

        score = float(out_scores[0][i])
        bbox = [float(x1), float(y1), float(x2), float(y2)]  # [x1, y1, x2, y2]
        json_result.append({'image': None, 'category_id': class_ind, 'bbox': bbox, 'score': score})

    return json_result
2 changes: 1 addition & 1 deletion tensorlayer/prepro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ def zoom_multi(x, zoom_range=(0.9, 1.1), flags=None, border_mode='constant'):
h, w = x.shape[0], x.shape[1]
transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w)
results.append(affine_transform_cv2(x, transform_matrix, flags=flags, border_mode=border_mode))
return results
return np.asarray(results)


# image = tf.image.random_brightness(image, max_delta=32. / 255.)
Expand Down
80 changes: 67 additions & 13 deletions tensorlayer/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import imageio
import numpy as np

import tensorlayer as tl
from tensorlayer.lazy_imports import LazyImport
import colorsys, random

cv2 = LazyImport("cv2")

Expand All @@ -16,18 +16,9 @@
# matplotlib.use('Agg')

__all__ = [
'read_image',
'read_images',
'save_image',
'save_images',
'draw_boxes_and_labels_to_image',
'draw_mpii_people_to_image',
'frame',
'CNN2d',
'images2d',
'tsne_embedding',
'draw_weights',
'W',
'read_image', 'read_images', 'save_image', 'save_images', 'draw_boxes_and_labels_to_image',
'draw_mpii_people_to_image', 'frame', 'CNN2d', 'images2d', 'tsne_embedding', 'draw_weights', 'W',
'draw_boxes_and_labels_to_image_with_json'
]


Expand Down Expand Up @@ -662,3 +653,66 @@ def draw_weights(W=None, second=10, saveable=True, shape=None, name='mnist', fig


W = draw_weights


def draw_boxes_and_labels_to_image_with_json(image, json_result, class_list, save_name=None):
    """Draw bboxes and class labels on image. Return the image with bboxes.

    Parameters
    -----------
    image : numpy.array
        The RGB image [height, width, channel]. It is drawn on in place.
    json_result : list of dict
        The object detection result with json format. Each dict provides
        ``'category_id'``, ``'bbox'`` as ``[x1, y1, x2, y2]`` and ``'score'``.
    class_list : list of str
        For converting ID to string on image.
    save_name : None or str
        The name of image file (i.e. image.png), if None, not to save image.

    Returns
    -------
    numpy.array
        The input image with the boxes and labels drawn on it.

    References
    -----------
    - OpenCV rectangle and putText.
    - `scikit-image <http://scikit-image.org/docs/dev/api/skimage.draw.html#skimage.draw.rectangle>`__.

    """
    image_h, image_w, _ = image.shape
    num_classes = len(class_list)

    # One evenly spaced hue per class, converted to 0-255 RGB tuples.
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = [colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples]
    colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
    # Shuffle with a private, fixed-seed generator: the class colours stay the
    # same on every call and the caller's global random state is not touched.
    random.Random(0).shuffle(colors)

    bbox_thick = int(0.6 * (image_h + image_w) / 600)
    fontScale = 0.5

    for bbox_info in json_result:
        category_id = bbox_info['category_id']
        # Valid IDs are 0 .. num_classes - 1; skip anything else instead of
        # failing on the colour / name lookup below.
        if category_id < 0 or category_id >= num_classes:
            continue
        bbox = bbox_info['bbox']  # the order of coordinates is [x1, y1, x2, y2]
        score = bbox_info['score']

        bbox_color = colors[category_id]
        c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)

        # Filled rectangle above the top-left corner as background for the label.
        bbox_mess = '%s: %.2f' % (class_list[category_id], score)
        t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0]
        # Keep pixel coordinates as plain ints, which every OpenCV version accepts.
        c3 = (int(c1[0] + t_size[0]), int(c1[1] - t_size[1] - 3))
        cv2.rectangle(image, c1, c3, bbox_color, -1)

        cv2.putText(
            image, bbox_mess, (c1[0], c1[1] - 2), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (0, 0, 0),
            bbox_thick // 2, lineType=cv2.LINE_AA
        )

    if save_name is not None:
        save_image(image, save_name)

    return image

0 comments on commit 495b98b

Please sign in to comment.