Skip to content

Commit

Permalink
update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
WZMIAOMIAO committed Nov 24, 2020
1 parent 8061c53 commit 45659bf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
5 changes: 3 additions & 2 deletions pytorch_object_detection/yolov3_spp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ def forward(self, p):


class Darknet(nn.Module):
# YOLOv3 object detection model

"""
YOLOv3 spp object detection model
"""
def __init__(self, cfg, img_size=(416, 416), verbose=False):
super(Darknet, self).__init__()
# 这里传入的img_size只在导出ONNX模型时起作用
Expand Down
21 changes: 15 additions & 6 deletions pytorch_object_detection/yolov3_spp/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,22 @@ def __init__(self,

# Read image shapes (wh)
# 查看data文件下是否缓存有对应数据集的.shapes文件,里面存储了每张图像的width, height
sp = path.replace(".txt", "") + ".shapes" # shapefile path
sp = path.replace(".txt", ".shapes") # shapefile path
try:
with open(sp, "r") as f: # read existing shapefile
s = [x.split() for x in f.read().splitlines()]
# 判断现有的shape文件中的行数(图像个数)是否与当前数据集中图像个数相等
# 如果不相等则认为是不同的数据集,故重新生成shape文件
assert len(s) == n, "shapefile out of aync"
except Exception as e:
print("read {} failed [{}], rebuild {}.".format(sp, e, sp))
# print("read {} failed [{}], rebuild {}.".format(sp, e, sp))
# tqdm库会显示处理的进度
# 读取每张图片的size信息
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc="Reading image shapes")]
if rank in [-1, 0]:
image_files = tqdm(self.img_files, desc="Reading image shapes")
else:
image_files = self.img_files
s = [exif_size(Image.open(f)) for f in image_files]
# 将所有图片的shape信息保存在.shape文件中
np.savetxt(sp, s, fmt="%g") # overwrite existing (if any)

Expand Down Expand Up @@ -163,7 +167,7 @@ def __init__(self,
# label: [class, x, y, w, h] 其中的xywh都为相对值
self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
extract_bounding_boxes, labels_loaded = False, False
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number mission, found, empty, datasunset, duplicate
nm, nf, ne, nd = 0, 0, 0, 0 # number mission, found, empty, duplicate
# 这里分别命名是为了防止出现rect为False/True时混用导致计算的mAP错误
# 当rect为True时会对self.images和self.labels进行从新排序
if rect is True:
Expand Down Expand Up @@ -258,12 +262,17 @@ def __init__(self,
# Cache images into memory for faster training (Warning: large datasets may exceed system RAM)
if cache_images: # if training
gb = 0 # Gigabytes of cached images 用于记录缓存图像占用RAM大小
pbar = tqdm(range(len(self.img_files)), desc="Caching images")
if rank in [-1, 0]:
pbar = tqdm(range(len(self.img_files)), desc="Caching images")
else:
pbar = range(len(self.img_files))

self.img_hw0, self.img_hw = [None] * n, [None] * n
for i in pbar: # max 10k images
self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i) # img, hw_original, hw_resized
gb += self.imgs[i].nbytes # 用于记录缓存图像占用RAM大小
pbar.desc = "Caching images (%.1fGB)" % (gb / 1E9)
if rank in [-1, 0]:
pbar.desc = "Caching images (%.1fGB)" % (gb / 1E9)

# Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
detect_corrupted_images = False
Expand Down

0 comments on commit 45659bf

Please sign in to comment.