forked from WZMIAOMIAO/deep-learning-for-image-processing
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
wz
authored and
wz
committed
Jun 29, 2020
1 parent
c667ba6
commit c1a8602
Showing
9 changed files
with
269 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,45 @@ | ||
# 代码完善中,敬请期待... | ||
# SSD: Single Shot MultiBox Detector | ||
|
||
## 环境配置: | ||
* Python 3.6或者3.7 | ||
* Pytorch 1.5(注意:是1.5) | ||
* pycocotools(Linux: pip install pycocotools; | ||
Windows:pip install pycocotools-windows(不需要额外安装vs)) | ||
* Ubuntu或Centos(不建议Windows) | ||
* 最好使用GPU训练 | ||
|
||
## 文件结构: | ||
``` | ||
├── src: 实现SSD模型的相关模块 | ||
│ ├── resnet50_backbone.py 使用resnet50网络作为SSD的backbone | ||
│ ├── ssd_model.py SSD网络结构文件 | ||
│ └── utils.py 训练过程中使用到的一些功能实现 | ||
├── train_utils: 训练验证相关模块(包括cocotools) | ||
├── my_dataset.py: 自定义dataset用于读取VOC数据集 | ||
├── train_ssd300.py: 以resnet50做为backbone的SSD网络进行训练 | ||
├── train_multi_GPU.py: 针对使用多GPU的用户使用 | ||
├── predict_test.py: 简易的预测脚本,使用训练好的权重进行预测测试 | ||
├── pascal_voc_classes.json: pascal_voc标签文件 | ||
├── plot_curve.py: 用于绘制训练过程的损失以及验证集的mAP | ||
``` | ||
|
||
## 预训练权重下载地址(下载后放入src文件夹中): | ||
* ResNet50+SSD: https://ngc.nvidia.com/catalog/models | ||
`搜索ssd -> 找到SSD for PyTorch(FP32) -> download FP32 -> 解压文件` | ||
|
||
## 数据集,本例程使用的是PASCAL VOC2012数据集(下载后放入项目当前文件夹中) | ||
* Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar | ||
* Pascal VOC2007 test数据集请参考:http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar | ||
* 如果不了解数据集或者想使用自己的数据集进行训练,请参考我的bilibili:https://b23.tv/F1kSCK | ||
|
||
## 训练方法 | ||
* 确保提前准备好数据集 | ||
* 确保提前下载好对应预训练模型权重 | ||
* 单GPU训练或CPU,直接使用train_ssd300.py训练脚本 | ||
* 若要使用多GPU训练,使用 "python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py" 指令,nproc_per_node参数为使用GPU数量 | ||
|
||
## 如果对SSD算法原理不是很理解可参考我的bilibili | ||
* https://b23.tv/GJnkOD | ||
|
||
## 进一步了解该项目,以及对SSD算法代码的分析可参考我的bilibili | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import collections | ||
import PIL.ImageDraw as ImageDraw | ||
import PIL.ImageFont as ImageFont | ||
import numpy as np | ||
|
||
STANDARD_COLORS = [ | ||
'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', | ||
'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', | ||
'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', | ||
'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', | ||
'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', | ||
'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', | ||
'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', | ||
'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', | ||
'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', | ||
'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', | ||
'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', | ||
'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', | ||
'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', | ||
'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', | ||
'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', | ||
'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', | ||
'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', | ||
'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', | ||
'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', | ||
'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', | ||
'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', | ||
'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', | ||
'WhiteSmoke', 'Yellow', 'YellowGreen' | ||
] | ||
|
||
|
||
def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map): | ||
for i in range(boxes.shape[0]): | ||
if scores[i] > thresh: | ||
box = tuple(boxes[i].tolist()) # numpy -> list -> tuple | ||
if classes[i] in category_index.keys(): | ||
class_name = category_index[classes[i]] | ||
else: | ||
class_name = 'N/A' | ||
display_str = str(class_name) | ||
display_str = '{}: {}%'.format(display_str, int(100 * scores[i])) | ||
box_to_display_str_map[box].append(display_str) | ||
box_to_color_map[box] = STANDARD_COLORS[ | ||
classes[i] % len(STANDARD_COLORS)] | ||
else: | ||
break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足 | ||
|
||
|
||
def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color): | ||
try: | ||
font = ImageFont.truetype('arial.ttf', 24) | ||
except IOError: | ||
font = ImageFont.load_default() | ||
|
||
# If the total height of the display strings added to the top of the bounding | ||
# box exceeds the top of the image, stack the strings below the bounding box | ||
# instead of above. | ||
display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]] | ||
# Each display_str has a top and bottom margin of 0.05x. | ||
total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) | ||
|
||
if top > total_display_str_height: | ||
text_bottom = top | ||
else: | ||
text_bottom = bottom + total_display_str_height | ||
# Reverse list and print from bottom to top. | ||
for display_str in box_to_display_str_map[box][::-1]: | ||
text_width, text_height = font.getsize(display_str) | ||
margin = np.ceil(0.05 * text_height) | ||
draw.rectangle([(left, text_bottom - text_height - 2 * margin), | ||
(left + text_width, text_bottom)], fill=color) | ||
draw.text((left + margin, text_bottom - text_height - margin), | ||
display_str, | ||
fill='black', | ||
font=font) | ||
text_bottom -= text_height - 2 * margin | ||
|
||
|
||
def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8): | ||
box_to_display_str_map = collections.defaultdict(list) | ||
box_to_color_map = collections.defaultdict(str) | ||
|
||
filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map) | ||
|
||
# Draw all boxes onto image. | ||
draw = ImageDraw.Draw(image) | ||
im_width, im_height = image.size | ||
for box, color in box_to_color_map.items(): | ||
xmin, ymin, xmax, ymax = box | ||
(left, right, top, bottom) = (xmin * 1, xmax * 1, | ||
ymin * 1, ymax * 1) | ||
draw.line([(left, top), (left, bottom), (right, bottom), | ||
(right, top), (left, top)], width=line_thickness, fill=color) | ||
draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
from draw_box_utils import draw_box | ||
from PIL import Image | ||
import json | ||
import matplotlib.pyplot as plt | ||
from src.ssd_model import SSD300, Backbone | ||
import transform | ||
|
||
|
||
def create_model(num_classes): | ||
backbone = Backbone() | ||
model = SSD300(backbone=backbone, num_classes=num_classes) | ||
|
||
return model | ||
|
||
|
||
# get devices | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
print(device) | ||
|
||
# create model | ||
model = create_model(num_classes=21) | ||
|
||
# load train weights | ||
train_weights = "./save_weights/ssd300-15.pth" | ||
train_weights_dict = torch.load(train_weights, map_location=device)['model'] | ||
|
||
model.load_state_dict(train_weights_dict, strict=False) | ||
model.to(device) | ||
|
||
# read class_indict | ||
category_index = {} | ||
try: | ||
json_file = open('./pascal_voc_classes.json', 'r') | ||
class_dict = json.load(json_file) | ||
category_index = {v: k for k, v in class_dict.items()} | ||
except Exception as e: | ||
print(e) | ||
exit(-1) | ||
|
||
# load image | ||
original_img = Image.open("./test.jpg") | ||
|
||
# from pil image to tensor, do not normalize image | ||
data_transform = transform.Compose([transform.Resize(), | ||
transform.ToTensor(), | ||
transform.Normalization()]) | ||
img, _ = data_transform(original_img) | ||
# expand batch dimension | ||
img = torch.unsqueeze(img, dim=0) | ||
|
||
model.eval() | ||
with torch.no_grad(): | ||
predictions = model(img.to(device))[0] # bboxes_out, labels_out, scores_out | ||
predict_boxes = predictions[0].to("cpu").numpy() | ||
predict_boxes[:, [0, 2]] = predict_boxes[:, [0, 2]] * original_img.size[0] | ||
predict_boxes[:, [1, 3]] = predict_boxes[:, [1, 3]] * original_img.size[1] | ||
predict_classes = predictions[1].to("cpu").numpy() | ||
predict_scores = predictions[2].to("cpu").numpy() | ||
|
||
if len(predict_boxes) == 0: | ||
print("没有检测到任何目标!") | ||
|
||
draw_box(original_img, | ||
predict_boxes, | ||
predict_classes, | ||
predict_scores, | ||
category_index, | ||
thresh=0.5, | ||
line_thickness=5) | ||
plt.imshow(original_img) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters