Skip to content

Commit

Permalink
Add a coco_index function that speeds up target loading for pycocotools
Browse files Browse the repository at this point in the history
  • Loading branch information
wz authored and wz committed Oct 28, 2020
1 parent 767b544 commit dd1929a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
50 changes: 49 additions & 1 deletion pytorch_object_detection/faster_rcnn/my_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_height_and_width(self, idx):
def parse_xml_to_dict(self, xml):
"""
将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
Args
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree
Returns:
Expand All @@ -113,6 +113,54 @@ def parse_xml_to_dict(self, xml):
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}

def coco_index(self, idx):
    """
    Return the image size and annotation target for one sample without
    opening the image file.

    Written for gathering label statistics for pycocotools: only the XML
    annotation is read and nothing is transformed, so iterating the whole
    dataset is much cheaper than going through image loading.

    Args:
        idx: index into self.xml_list of the annotation to read

    Returns:
        A tuple ((height, width), target). height/width come from the
        annotation's "size" entry. target is a dict of tensors:
            "boxes":    float32, shape (N, 4), rows are [xmin, ymin, xmax, ymax]
            "labels":   int64, shape (N,), looked up in self.class_dict
            "image_id": one-element tensor holding idx
            "area":     float32, shape (N,), box height * box width
            "iscrowd":  int64, shape (N,), each object's "difficult" flag
                        (0 when the flag is absent)
        N is 0 when the annotation lists no objects.
    """
    # read xml
    xml_path = self.xml_list[idx]
    with open(xml_path) as fid:
        xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = self.parse_xml_to_dict(xml)["annotation"]
    data_height = int(data["size"]["height"])
    data_width = int(data["size"]["width"])

    boxes = []
    labels = []
    iscrowd = []
    # An annotation with no <object> element is expected to have no "object"
    # key at all, so fall back to an empty list instead of raising KeyError.
    for obj in data.get("object", []):
        bndbox = obj["bndbox"]
        xmin = float(bndbox["xmin"])
        xmax = float(bndbox["xmax"])
        ymin = float(bndbox["ymin"])
        ymax = float(bndbox["ymax"])
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(self.class_dict[obj["name"]])
        # "difficult" is optional in some annotation files; treat a missing
        # flag as "not difficult".
        iscrowd.append(int(obj.get("difficult", 0)))

    # convert everything into a torch.Tensor
    # reshape keeps boxes 2-D even when the list is empty, so the column
    # slicing below (and in callers) works for images without objects.
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    target = {
        "boxes": boxes,
        "labels": labels,
        "image_id": image_id,
        "area": area,
        "iscrowd": iscrowd,
    }

    return (data_height, data_width), target

@staticmethod
def collate_fn(batch):
    """
    Transpose a batch of samples into per-field tuples.

    Args:
        batch: iterable of equally structured samples, e.g. (image, target) pairs

    Returns:
        A tuple with one entry per sample field; entry i is a tuple holding
        field i of every sample. An empty batch gives an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def convert_to_coco_api(ds):
categories = set()
for img_idx in range(len(ds)):
# find better way to get target
img, targets = ds[img_idx]
hw, targets = ds.coco_index(img_idx)
image_id = targets["image_id"].item()
img_dict = {}
img_dict['id'] = image_id
img_dict['height'] = img.shape[-2]
img_dict['width'] = img.shape[-1]
img_dict['height'] = hw[0]
img_dict['width'] = hw[1]
dataset['images'].append(img_dict)
bboxes = targets["boxes"]
bboxes[:, 2:] -= bboxes[:, :2]
Expand Down

0 comments on commit dd1929a

Please sign in to comment.