-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathget_dr_txt.py
117 lines (97 loc) · 4.86 KB
/
get_dr_txt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
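#----------------------------------------------------#
#   Write the model's detection results for the test set
#   (one txt file per image) for mAP evaluation.
#   A video tutorial is available at
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#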
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
from retinanet import RetinaNet
from utils.utils import (bbox_iou, decodebox, letterbox_image,
non_max_suppression, retinanet_correct_boxes)
def preprocess_input(image):
    """Normalise a float image array in place and return the same array.

    Values are first scaled by 1/255 and then standardised per channel
    (subtract mean, divide by std) using the statistics below. The mean/std
    tuples broadcast over the last axis, so that axis must hold 3 channels.
    Because every step is an in-place operator, the input array is modified
    and keeps its dtype (e.g. float32 stays float32).
    """
    # Alternative statistics, kept for reference but not used:
    # mean=(0.406, 0.456, 0.485)
    # std=(0.225, 0.224, 0.229)
    channel_mean = (0.204, 0.301, 0.316)
    channel_std = (0.201, 0.175, 0.157)

    image /= 255
    image -= channel_mean
    image /= channel_std
    return image
class mAP_RetinaNet(RetinaNet):
    """RetinaNet subclass that writes per-image detections to txt files for mAP evaluation."""

    #---------------------------------------------------#
    #   Detect objects in one image and write the results
    #---------------------------------------------------#
    def detect_image(self, image_id, image):
        """Run detection on ``image`` and write the boxes to a per-image txt file.

        The file ``./input/detection-results/<image_id>.txt`` is always created.
        It is left empty when nothing is detected. Each kept detection is written
        as one line::

            <class_name> <score> <left> <top> <right> <bottom>

        The score is the string form of the confidence truncated to 6
        characters. The coordinates are truncated to ``int`` after
        ``retinanet_correct_boxes`` removes the letterbox padding.

        Side effect: sets ``self.confidence`` to 0.01 and ``self.iou`` to 0.5
        before running. Presumably the low threshold keeps low-score boxes for
        the precision/recall curve.

        Args:
            image_id: base name of the output txt file.
            image: input image (the caller passes a PIL image). The first two
                ``np.shape`` dimensions are passed to ``retinanet_correct_boxes``
                as the original image shape.

        Returns:
            None.
        """
        self.confidence = 0.01
        self.iou = 0.5
        # Open the output first, via ``with``, so the file is created even when
        # nothing is detected and is closed on every exit path (including the
        # early return below).
        with open("./input/detection-results/" + image_id + ".txt", "w") as f:
            image_shape = np.array(np.shape(image)[0:2])
            #---------------------------------------------------------#
            #   Letterbox the image (pad with grey bars) so it is
            #   resized without distortion
            #---------------------------------------------------------#
            crop_img = np.array(letterbox_image(image, [self.input_shape[1], self.input_shape[0]]))
            photo = np.array(crop_img, dtype=np.float32)
            # Move the last (channel) axis to the front for the network.
            photo = np.transpose(preprocess_input(photo), (2, 0, 1))

            with torch.no_grad():
                images = torch.from_numpy(np.asarray([photo]))
                if self.cuda:
                    images = images.cuda()
                #---------------------------------------------------------#
                #   Forward pass through the network
                #---------------------------------------------------------#
                _, regression, classification, anchors = self.net(images)
                #-----------------------------------------------------------#
                #   Decode the regression outputs into boxes
                #-----------------------------------------------------------#
                regression = decodebox(regression, anchors, images)
                detection = torch.cat([regression, classification], dim=-1)
                batch_detections = non_max_suppression(detection, len(self.class_names),
                                                       conf_thres=self.confidence,
                                                       nms_thres=self.iou)

            #--------------------------------------#
            #   No detections: leave the file empty and return None
            #--------------------------------------#
            try:
                batch_detections = batch_detections[0].cpu().numpy()
            except (AttributeError, IndexError, TypeError):
                # NOTE(review): assumes non_max_suppression reports "no boxes"
                # for this image as None / an empty result; confirm in utils.utils.
                return

            #-----------------------------------------------------------#
            #   Keep boxes whose score (column 4) exceeds confidence
            #-----------------------------------------------------------#
            top_index = batch_detections[:, 4] > self.confidence
            top_conf = batch_detections[top_index, 4]
            top_label = np.array(batch_detections[top_index, -1], np.int32)
            top_bboxes = np.array(batch_detections[top_index, :4])
            # (k, 1) column vectors for each coordinate.
            top_xmin = top_bboxes[:, 0:1]
            top_ymin = top_bboxes[:, 1:2]
            top_xmax = top_bboxes[:, 2:3]
            top_ymax = top_bboxes[:, 3:4]
            #-----------------------------------------------------------#
            #   Remove the grey letterbox padding
            #-----------------------------------------------------------#
            boxes = retinanet_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                            np.array([self.input_shape[0], self.input_shape[1]]),
                                            image_shape)
            for i, c in enumerate(top_label):
                predicted_class = self.class_names[c]
                score = str(top_conf[i])
                top, left, bottom, right = boxes[i]
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)),
                                                 str(int(top)), str(int(right)), str(int(bottom))))
        return
#---------------------------------------------------------#
#   Script body: write one detection-result file per test image
#---------------------------------------------------------#
retinanet = mAP_RetinaNet()
# Read the whitespace-separated test-split image ids; ``with`` closes the file.
with open('../../Datasets/airbus-ship-medium/ImageSets/Main/test.txt') as id_file:
    image_ids = id_file.read().strip().split()

# exist_ok=True replaces the separate exists-check and tolerates re-runs.
os.makedirs("./input", exist_ok=True)
os.makedirs("./input/detection-results", exist_ok=True)
os.makedirs("./input/images-optional", exist_ok=True)

for image_id in tqdm(image_ids):
    image_path = "../../Datasets/airbus-ship-medium/JPEGImages/" + image_id + ".jpg"
    # Use the image as a context manager so its file handle is released after
    # each iteration instead of accumulating until garbage collection.
    with Image.open(image_path) as image:
        # Uncomment to save the image so the later mAP step can visualise results
        # image.save("./input/images-optional/"+image_id+".jpg")
        retinanet.detect_image(image_id, image)
print("Conversion completed!")