Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing authored Dec 30, 2020
1 parent 77c3ef7 commit d4a4038
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
25 changes: 25 additions & 0 deletions get_gt_txt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
import glob
import xml.etree.ElementTree as ET

'''
!!!!!!!!!!!!!注意事项!!!!!!!!!!!!!
# 这一部分是当xml有无关的类的时候,下方有代码可以进行筛选!
'''
#---------------------------------------------------#
# 获得类
#---------------------------------------------------#
def get_classes(classes_path):
'''loads the classes'''
with open(classes_path) as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names

image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()

if not os.path.exists("./input"):
Expand All @@ -25,11 +39,22 @@
if int(difficult)==1:
difficult_flag = True
obj_name = obj.find('name').text
'''
!!!!!!!!!!!!注意事项!!!!!!!!!!!!
# 这一部分是当xml有无关的类的时候,可以取消下面代码的注释
# 利用对应的classes.txt来进行筛选!!!!!!!!!!!!
'''
# classes_path = 'model_data/voc_classes.txt'
# class_names = get_classes(classes_path)
# if obj_name not in class_names:
# continue

bndbox = obj.find('bndbox')
left = bndbox.find('xmin').text
top = bndbox.find('ymin').text
right = bndbox.find('xmax').text
bottom = bndbox.find('ymax').text

if difficult_flag:
new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
else:
Expand Down
50 changes: 36 additions & 14 deletions get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,28 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
is_difficult = True
else:
class_name, left, top, right, bottom = line.split()
except ValueError:
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
error_msg += " Received: " + line
error_msg += "\n\nIf you have a <class_name> with spaces between words you should remove them\n"
error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
error(error_msg)

except:
if "difficult" in line:
line_split = line.split()
_difficult = line_split[-1]
bottom = line_split[-2]
right = line_split[-3]
top = line_split[-4]
left = line_split[-5]
class_name = ""
for name in line_split[:-5]:
class_name += name
is_difficult = True
else:
line_split = line.split()
bottom = line_split[-1]
right = line_split[-2]
top = line_split[-3]
left = line_split[-4]
class_name = ""
for name in line_split[:-4]:
class_name += name
# check if class is in the ignore list, if yes skip
if class_name in args.ignore:
continue
Expand Down Expand Up @@ -481,11 +496,17 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
for line in lines:
try:
tmp_class_name, confidence, left, top, right, bottom = line.split()
except ValueError:
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
error_msg += " Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n"
error_msg += " Received: " + line
error(error_msg)
except:
line_split = line.split()
bottom = line_split[-1]
right = line_split[-2]
top = line_split[-3]
left = line_split[-4]
confidence = line_split[-5]
tmp_class_name = ""
for name in line_split[:-5]:
tmp_class_name += name

if tmp_class_name == class_name:
#print("match")
bbox = left + " " + top + " " + right + " " +bottom
Expand Down Expand Up @@ -702,8 +723,9 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
rounded_rec = [ '%.2f' % elem for elem in rec ]
results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
if not args.quiet:
print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
+ " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
if(len(rec)!=0):
print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
+ " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
ap_dictionary[class_name] = ap

n_images = counter_images_per_class[class_name]
Expand Down

0 comments on commit d4a4038

Please sign in to comment.