From 98c06a342ab674dab877bc4482fe6c6fa1a06d9e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Wed, 2 Sep 2020 22:22:02 +0300 Subject: [PATCH] [Datumaro] Diff with exact annotation matching (#1989) * Add exact diff command * Update changelog * fix * fix merge * Add image matching, add test * Add point matching test * linter * Update CHANGELOG.md Co-authored-by: Nikita Manovich --- CHANGELOG.md | 1 + .../datumaro/cli/contexts/project/__init__.py | 103 +++++- .../datumaro/cli/contexts/project/diff.py | 2 +- datumaro/datumaro/components/comparator.py | 113 ------ datumaro/datumaro/components/extractor.py | 6 +- datumaro/datumaro/components/operations.py | 346 +++++++++++++++++- datumaro/datumaro/util/__init__.py | 3 + datumaro/datumaro/util/test_utils.py | 3 +- datumaro/tests/test_diff.py | 235 ++++++++---- 9 files changed, 611 insertions(+), 201 deletions(-) delete mode 100644 datumaro/datumaro/components/comparator.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a033a92ea26c..1b1b7ed29e18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added password reset functionality () - Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007) - Annotation in process outline color wheel () +- [Datumaro] CLI command for dataset equality comparison (https://github.com/opencv/cvat/pull/1989) ### Changed - UI models (like DEXTR) were redesigned to be more interactive () diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index 63c84076b9ff..8915086be35c 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -4,25 +4,26 @@ # SPDX-License-Identifier: MIT import argparse -from enum import Enum import json import logging as log import os import os.path as osp import shutil +from enum import Enum -from datumaro.components.project import Project, Environment, \ - 
PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG -from datumaro.components.comparator import Comparator +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset_filter import DatasetItemEncoder from datumaro.components.extractor import AnnotationType -from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.operations import \ - compute_image_statistics, compute_ann_statistics +from datumaro.components.operations import (DistanceComparator, + ExactComparator, compute_ann_statistics, compute_image_statistics, mean_std) +from datumaro.components.project import \ + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG +from datumaro.components.project import Environment, Project + +from ...util import (CliException, MultilineFormatter, add_subparser, + make_file_name) +from ...util.project import generate_next_file_name, load_project from .diff import DiffVisualizer -from ...util import add_subparser, CliException, MultilineFormatter, \ - make_file_name -from ...util.project import load_project, generate_next_file_name def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -503,12 +504,12 @@ def merge_command(args): def build_diff_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Compare projects", description=""" - Compares two projects.|n + Compares two projects, match annotations by distance.|n |n Examples:|n - - Compare two projects, consider bboxes matching if their IoU > 0.7,|n + - Compare two projects, match boxes if IoU > 0.7,|n |s|s|s|sprint results to Tensorboard: - |s|sdiff path/to/other/project -o diff/ -f tensorboard --iou-thresh 0.7 + |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 """, formatter_class=MultilineFormatter) @@ -516,7 +517,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): help="Directory of the second project to be compared") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, help="Directory to save comparison results 
(default: do not save)") - parser.add_argument('-f', '--format', + parser.add_argument('-v', '--visualizer', default=DiffVisualizer.DEFAULT_FORMAT, choices=[f.name for f in DiffVisualizer.Format], help="Output format (default: %(default)s)") @@ -536,9 +537,7 @@ def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) - comparator = Comparator( - iou_threshold=args.iou_thresh, - conf_threshold=args.conf_thresh) + comparator = DistanceComparator(iou_threshold=args.iou_thresh) dst_dir = args.dst_dir if dst_dir: @@ -556,7 +555,7 @@ def diff_command(args): dst_dir_existed = osp.exists(dst_dir) try: visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.format) + output_format=args.visualizer) visualizer.save_dataset_diff( first_project.make_dataset(), second_project.make_dataset()) @@ -567,6 +566,73 @@ def diff_command(args): return 0 +def build_ediff_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects for equality", + description=""" + Compares two projects for equality.|n + |n + Examples:|n + - Compare two projects, exclude annotation group |n + |s|s|sand the 'is_crowd' attribute from comparison:|n + |s|sediff other/project/ -if group -ia is_crowd + """, + formatter_class=MultilineFormatter) + + parser.add_argument('other_project_dir', + help="Directory of the second project to be compared") + parser.add_argument('-iia', '--ignore-item-attr', action='append', + help="Ignore item attribute (repeatable)") + parser.add_argument('-ia', '--ignore-attr', action='append', + help="Ignore annotation attribute (repeatable)") + parser.add_argument('-if', '--ignore-field', + action='append', default=['id', 'group'], + help="Ignore annotation field (repeatable, default: %(default)s)") + parser.add_argument('--match-images', action='store_true', + help='Match dataset items by images instead of ids') + parser.add_argument('--all', 
action='store_true', + help="Include matches in the output") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the first project to be compared (default: current dir)") + parser.set_defaults(command=ediff_command) + + return parser + +def ediff_command(args): + first_project = load_project(args.project_dir) + second_project = load_project(args.other_project_dir) + + comparator = ExactComparator( + match_images=args.match_images, + ignored_fields=args.ignore_field, + ignored_attrs=args.ignore_attr, + ignored_item_attrs=args.ignore_item_attr) + matches, mismatches, a_extra, b_extra, errors = \ + comparator.compare_datasets( + first_project.make_dataset(), second_project.make_dataset()) + output = { + "mismatches": mismatches, + "a_extra_items": sorted(a_extra), + "b_extra_items": sorted(b_extra), + "errors": errors, + } + if args.all: + output["matches"] = matches + + output_file = generate_next_file_name('diff', ext='.json') + with open(output_file, 'w') as f: + json.dump(output, f, indent=4, sort_keys=True) + + print("Found:") + print("The first project has %s unmatched items" % len(a_extra)) + print("The second project has %s unmatched items" % len(b_extra)) + print("%s item conflicts" % len(errors)) + print("%s matching annotations" % len(matches)) + print("%s mismatching annotations" % len(mismatches)) + + log.info("Output has been saved to '%s'" % output_file) + + return 0 + def build_transform_parser(parser_ctor=argparse.ArgumentParser): builtins = sorted(Environment().transforms.items) @@ -753,6 +819,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): add_subparser(subparsers, 'extract', build_extract_parser) add_subparser(subparsers, 'merge', build_merge_parser) add_subparser(subparsers, 'diff', build_diff_parser) + add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) 
add_subparser(subparsers, 'stats', build_stats_parser) diff --git a/datumaro/datumaro/cli/contexts/project/diff.py b/datumaro/datumaro/cli/contexts/project/diff.py index 785c6c8ecde7..571908f66794 100644 --- a/datumaro/datumaro/cli/contexts/project/diff.py +++ b/datumaro/datumaro/cli/contexts/project/diff.py @@ -217,7 +217,7 @@ def save_item_bbox_diff(self, item_a, item_b, diff): _, mispred, a_unmatched, b_unmatched = diff if 0 < len(a_unmatched) + len(b_unmatched) + len(mispred): - img_a = item_a.image.copy() + img_a = item_a.image.data.copy() img_b = img_a.copy() for a_bbox, b_bbox in mispred: self.draw_bbox(img_a, a_bbox, (0, 255, 0)) diff --git a/datumaro/datumaro/components/comparator.py b/datumaro/datumaro/components/comparator.py deleted file mode 100644 index 842a3963a989..000000000000 --- a/datumaro/datumaro/components/comparator.py +++ /dev/null @@ -1,113 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from itertools import zip_longest -import numpy as np - -from datumaro.components.extractor import AnnotationType, LabelCategories - - -class Comparator: - def __init__(self, - iou_threshold=0.5, conf_threshold=0.9): - self.iou_threshold = iou_threshold - self.conf_threshold = conf_threshold - - @staticmethod - def iou(box_a, box_b): - return box_a.iou(box_b) - - # pylint: disable=no-self-use - def compare_dataset_labels(self, extractor_a, extractor_b): - a_label_cat = extractor_a.categories().get(AnnotationType.label) - b_label_cat = extractor_b.categories().get(AnnotationType.label) - if not a_label_cat and not b_label_cat: - return None - if not a_label_cat: - a_label_cat = LabelCategories() - if not b_label_cat: - b_label_cat = LabelCategories() - - mismatches = [] - for a_label, b_label in zip_longest(a_label_cat.items, b_label_cat.items): - if a_label != b_label: - mismatches.append((a_label, b_label)) - return mismatches - # pylint: enable=no-self-use - - def compare_item_labels(self, item_a, item_b): - 
conf_threshold = self.conf_threshold - - a_labels = set([ann.label for ann in item_a.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - b_labels = set([ann.label for ann in item_b.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - matches = a_labels & b_labels - - return matches, a_unmatched, b_unmatched - - def compare_item_bboxes(self, item_a, item_b): - iou_threshold = self.iou_threshold - conf_threshold = self.conf_threshold - - a_boxes = [ann for ann in item_a.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - b_boxes = [ann for ann in item_b.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - a_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - b_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - - # a_matches: indices of b_boxes matched to a bboxes - # b_matches: indices of a_boxes matched to b bboxes - a_matches = -np.ones(len(a_boxes), dtype=int) - b_matches = -np.ones(len(b_boxes), dtype=int) - - iou_matrix = np.array([ - [self.iou(a, b) for b in b_boxes] for a in a_boxes - ]) - - # matches: boxes we succeeded to match completely - # mispred: boxes we succeeded to match, having label mismatch - matches = [] - mispred = [] - - for a_idx, a_bbox in enumerate(a_boxes): - if len(b_boxes) == 0: - break - matched_b = a_matches[a_idx] - iou_max = max(iou_matrix[a_idx, matched_b], iou_threshold) - for b_idx, b_bbox in enumerate(b_boxes): - if 0 <= b_matches[b_idx]: # assign a_bbox with max conf - continue - iou = iou_matrix[a_idx, b_idx] - if iou < iou_max: - continue - iou_max = iou - matched_b = b_idx - - if matched_b < 0: - continue - a_matches[a_idx] = matched_b - b_matches[matched_b] = a_idx - - b_bbox = b_boxes[matched_b] 
- - if a_bbox.label == b_bbox.label: - matches.append( (a_bbox, b_bbox) ) - else: - mispred.append( (a_bbox, b_bbox) ) - - # *_umatched: boxes of (*) we failed to match - a_unmatched = [a_boxes[i] for i, m in enumerate(a_matches) if m < 0] - b_unmatched = [b_boxes[i] for i, m in enumerate(b_matches) if m < 0] - - return matches, mispred, a_unmatched, b_unmatched diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index d7991cd121e0..0473a250b463 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -46,7 +46,7 @@ def wrap(item, **kwargs): @attrs class Categories: attributes = attrib(factory=set, validator=default_if_none(set), - kw_only=True) + kw_only=True, eq=False) @attrs class LabelCategories(Categories): @@ -137,6 +137,8 @@ def inverse_colormap(self): def __eq__(self, other): if not super().__eq__(other): return False + if not isinstance(other, __class__): + return False for label_id, my_color in self.colormap.items(): other_color = other.colormap.get(label_id) if not np.array_equal(my_color, other_color): @@ -179,6 +181,8 @@ def paint(self, colormap): def __eq__(self, other): if not super().__eq__(other): return False + if not isinstance(other, __class__): + return False return \ (self.label == other.label) and \ (self.z_order == other.z_order) and \ diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 9e63d3a7e84e..2e3a68136db2 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -5,18 +5,20 @@ from collections import OrderedDict from copy import deepcopy +import hashlib import logging as log import attr import cv2 import numpy as np from attr import attrib, attrs +from unittest import TestCase from datumaro.components.cli_plugin import CliPlugin from datumaro.components.extractor import AnnotationType, Bbox, Label from datumaro.components.project import 
Dataset -from datumaro.util import find -from datumaro.util.attrs_util import ensure_cls +from datumaro.util import find, filter_dict +from datumaro.util.attrs_util import ensure_cls, default_if_none from datumaro.util.annotation_util import (segment_iou, bbox_iou, mean_bbox, OKS, find_instances, max_bbox, smooth_line) @@ -585,7 +587,7 @@ class MaskMatcher(_ShapeMatcher): @attrs(kw_only=True) class PointsMatcher(_ShapeMatcher): - sigma = attrib(converter=list, default=None) + sigma = attrib(type=list, default=None) instance_map = attrib(converter=dict) def distance(self, a, b): @@ -1003,3 +1005,341 @@ def get_label(ann): } for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))] return stats + +@attrs +class DistanceComparator: + iou_threshold = attrib(converter=float, default=0.5) + + @staticmethod + def match_datasets(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + + @staticmethod + def match_classes(a, b): + a_label_cat = a.categories().get(AnnotationType.label, LabelCategories()) + b_label_cat = b.categories().get(AnnotationType.label, LabelCategories()) + + a_labels = set(c.name for c in a_label_cat) + b_labels = set(c.name for c in b_label_cat) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def match_annotations(self, item_a, item_b): + return { t: self._match_ann_type(t, item_a, item_b) } + + def _match_ann_type(self, t, *args): + # pylint: disable=no-value-for-parameter + if t == AnnotationType.label: + return self.match_labels(*args) + elif t == AnnotationType.bbox: + return self.match_boxes(*args) + elif t == AnnotationType.polygon: + return self.match_polygons(*args) + elif t == AnnotationType.mask: + return self.match_masks(*args) + 
elif t == AnnotationType.points: + return self.match_points(*args) + elif t == AnnotationType.polyline: + return self.match_lines(*args) + # pylint: enable=no-value-for-parameter + else: + raise NotImplementedError("Unexpected annotation type %s" % t) + + @staticmethod + def _get_ann_type(t, item): + return get_ann_type(item.annotations, t) + + def match_labels(self, item_a, item_b): + a_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_a)) + b_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_b)) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def _match_segments(self, t, item_a, item_b): + a_boxes = self._get_ann_type(t, item_a) + b_boxes = self._get_ann_type(t, item_b) + return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold) + + def match_polygons(self, item_a, item_b): + return self._match_segments(AnnotationType.polygon, item_a, item_b) + + def match_masks(self, item_a, item_b): + return self._match_segments(AnnotationType.mask, item_a, item_b) + + def match_boxes(self, item_a, item_b): + return self._match_segments(AnnotationType.bbox, item_a, item_b) + + def match_points(self, item_a, item_b): + a_points = self._get_ann_type(AnnotationType.points, item_a) + b_points = self._get_ann_type(AnnotationType.points, item_b) + + instance_map = {} + for s in [item_a.annotations, item_b.annotations]: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox(inst) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + matcher = PointsMatcher(instance_map=instance_map) + + return match_segments(a_points, b_points, + dist_thresh=self.iou_threshold, distance=matcher.distance) + + def match_lines(self, item_a, item_b): + a_lines = self._get_ann_type(AnnotationType.polyline, item_a) + b_lines = self._get_ann_type(AnnotationType.polyline, item_b) + + matcher = 
LineMatcher() + + return match_segments(a_lines, b_lines, + dist_thresh=self.iou_threshold, distance=matcher.distance) + +def match_items_by_id(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + matches = [([m], [m]) for m in matches] + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + +def match_items_by_image_hash(a, b): + def _hash(item): + if not item.image.has_data: + log.warning("Image (%s, %s) has no image " + "data, counted as unmatched", item.id, item.subset) + return None + return hashlib.md5(item.image.data.tobytes()).hexdigest() + + def _build_hashmap(source): + d = {} + for item in source: + h = _hash(item) + if h is None: + h = str(id(item)) # anything unique + d.setdefault(h, []).append((item.id, item.subset)) + return d + + a_hash = _build_hashmap(a) + b_hash = _build_hashmap(b) + + a_items = set(a_hash) + b_items = set(b_hash) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + + matches = [(a_hash[h], b_hash[h]) for h in matches] + a_unmatched = set(i for h in a_unmatched for i in a_hash[h]) + b_unmatched = set(i for h in b_unmatched for i in b_hash[h]) + + return matches, a_unmatched, b_unmatched + +@attrs +class ExactComparator: + match_images = attrib(kw_only=True, type=bool, default=False) + ignored_fields = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) + ignored_attrs = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) + ignored_item_attrs = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) + + _test = attrib(init=False, type=TestCase) + errors = attrib(init=False, type=list) + + def __attrs_post_init__(self): + self._test = TestCase() + self._test.maxDiff = None + + + def _match_items(self, a, b): + if self.match_images: + return match_items_by_image_hash(a, b) + else: + return 
match_items_by_id(a, b) + + def _compare_categories(self, a, b): + test = self._test + errors = self.errors + + try: + test.assertEqual( + sorted(a, key=lambda t: t.value), + sorted(b, key=lambda t: t.value) + ) + except AssertionError as e: + errors.append({'type': 'categories', 'message': str(e)}) + + if AnnotationType.label in a: + try: + test.assertEqual( + a[AnnotationType.label].items, + b[AnnotationType.label].items, + ) + except AssertionError as e: + errors.append({'type': 'labels', 'message': str(e)}) + if AnnotationType.mask in a: + try: + test.assertEqual( + a[AnnotationType.mask].colormap, + b[AnnotationType.mask].colormap, + ) + except AssertionError as e: + errors.append({'type': 'colormap', 'message': str(e)}) + if AnnotationType.points in a: + try: + test.assertEqual( + a[AnnotationType.points].items, + b[AnnotationType.points].items, + ) + except AssertionError as e: + errors.append({'type': 'points', 'message': str(e)}) + + def _compare_annotations(self, a, b): + ignored_fields = self.ignored_fields + ignored_attrs = self.ignored_attrs + + a_fields = { k: None for k in vars(a) if k in ignored_fields } + b_fields = { k: None for k in vars(b) if k in ignored_fields } + if 'attributes' not in ignored_fields: + a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs) + b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs) + + result = a.wrap(**a_fields) == b.wrap(**b_fields) + + return result + + def _compare_items(self, item_a, item_b): + test = self._test + + a_id = (item_a.id, item_a.subset) + b_id = (item_b.id, item_b.subset) + + matched = [] + unmatched = [] + errors = [] + + try: + test.assertEqual( + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs) + ) + except AssertionError as e: + errors.append({'type': 'item_attr', + 'a_item': a_id, 'b_item': b_id, 'message': str(e)}) + + b_annotations = item_b.annotations[:] + for ann_a in item_a.annotations: + 
ann_b_candidates = [x for x in item_b.annotations + if x.type == ann_a.type] + + ann_b = find(enumerate(self._compare_annotations(ann_a, x) + for x in ann_b_candidates), lambda x: x[1]) + if ann_b is None: + unmatched.append({ + 'item': a_id, 'source': 'a', 'ann': str(ann_a), + }) + continue + else: + ann_b = ann_b_candidates[ann_b[0]] + + b_annotations.remove(ann_b) # avoid repeats + matched.append({'a_item': a_id, 'b_item': b_id, + 'a': str(ann_a), 'b': str(ann_b)}) + + for ann_b in b_annotations: + unmatched.append({'item': b_id, 'source': 'b', 'ann': str(ann_b)}) + + return matched, unmatched, errors + + def compare_datasets(self, a, b): + self.errors = [] + errors = self.errors + + self._compare_categories(a.categories(), b.categories()) + + matched = [] + unmatched = [] + + matches, a_unmatched, b_unmatched = self._match_items(a, b) + + if a.categories().get(AnnotationType.label) != \ + b.categories().get(AnnotationType.label): + return matched, unmatched, a_unmatched, b_unmatched, errors + + _dist = lambda s: len(s[1]) + len(s[2]) + for a_ids, b_ids in matches: + # build distance matrix + match_status = {} # (a_id, b_id): [matched, unmatched, errors] + a_matches = { a_id: None for a_id in a_ids } + b_matches = { b_id: None for b_id in b_ids } + + for a_id in a_ids: + item_a = a.get(*a_id) + candidates = {} + + for b_id in b_ids: + item_b = b.get(*b_id) + + i_m, i_um, i_err = self._compare_items(item_a, item_b) + candidates[b_id] = [i_m, i_um, i_err] + + if len(i_um) == 0: + a_matches[a_id] = b_id + b_matches[b_id] = a_id + matched.extend(i_m) + errors.extend(i_err) + break + + match_status[a_id] = candidates + + # assign + for a_id in a_ids: + if len(b_ids) == 0: + break + + # find the closest, ignore already assigned + matched_b = a_matches[a_id] + if matched_b is not None: + continue + min_dist = -1 + for b_id in b_ids: + if b_matches[b_id] is not None: + continue + d = _dist(match_status[a_id][b_id]) + if d < min_dist and 0 <= min_dist: + continue + 
min_dist = d + matched_b = b_id + + if matched_b is None: + continue + a_matches[a_id] = matched_b + b_matches[matched_b] = a_id + + m = match_status[a_id][matched_b] + matched.extend(m[0]) + unmatched.extend(m[1]) + errors.extend(m[2]) + + a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m) + b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m) + + return matched, unmatched, a_unmatched, b_unmatched, errors \ No newline at end of file diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index 293bb5f62f34..010057d54c69 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -88,3 +88,6 @@ def str_to_bool(s): return False else: raise ValueError("Can't convert value '%s' to bool" % s) + +def filter_dict(d, exclude_keys): + return { k: v for k, v in d.items() if k not in exclude_keys } \ No newline at end of file diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index f93a74ce1b37..62973ca5a0ae 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -100,8 +100,7 @@ def compare_datasets(test, expected, actual, ignored_attrs=None): ann_b = find(ann_b_matches, lambda x: _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs)) if ann_b is None: - test.assertEqual(ann_a, ann_b, - 'ann %s, candidates %s' % (ann_a, ann_b_matches)) + test.fail('ann %s, candidates %s' % (ann_a, ann_b_matches)) item_b.annotations.remove(ann_b) # avoid repeats def compare_datasets_strict(test, expected, actual): diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 9ad9c1de6fdf..33dd79da0ff7 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,123 +1,96 @@ -from unittest import TestCase +import numpy as np + +from datumaro.components.extractor import (DatasetItem, Label, Bbox, + Caption, Mask, Points) +from datumaro.components.project import Dataset +from 
datumaro.components.operations import DistanceComparator, ExactComparator -from datumaro.components.extractor import DatasetItem, Label, Bbox -from datumaro.components.comparator import Comparator +from unittest import TestCase -class DiffTest(TestCase): +class DistanceComparatorTest(TestCase): def test_no_bbox_diff_with_same_item(self): detections = 3 anns = [ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ] item = DatasetItem(id=0, annotations=anns) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item, item) + result = comp.match_boxes(item, item) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) for a_bbox, b_bbox in matches: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual(a_bbox.label, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_bbox_with_wrong_label(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) ]) iou_thresh = 0.5 - 
conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(mispred)) + self.assertEqual(len(item1.annotations), len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) self.assertEqual(0, len(matches)) for a_bbox, b_bbox in mispred: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual((a_bbox.label + 1) % class_count, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_missing_boxes(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 0 + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) if i % 2 == 0 ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 1 + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) if i % 2 == 1 ]) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(a_greater)) - self.assertEqual(len([it for it in item2.annotations \ - if conf_thresh < 
it.attributes['score']]), - len(b_greater)) + self.assertEqual(len(item1.annotations), len(a_greater)) + self.assertEqual(len(item2.annotations), len(b_greater)) self.assertEqual(0, len(matches)) def test_no_label_diff_with_same_item(self): detections = 3 - anns = [ - Label(i, attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) - ] + anns = [ Label(i) for i in range(detections) ] item = DatasetItem(id=1, annotations=anns) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item, item) + result = DistanceComparator().match_labels(item, item) matches, a_greater, b_greater = result self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) def test_can_find_wrong_label(self): item1 = DatasetItem(id=1, annotations=[ @@ -131,12 +104,148 @@ def test_can_find_wrong_label(self): Label(4), ]) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item1, item2) + result = DistanceComparator().match_labels(item1, item2) matches, a_greater, b_greater = result self.assertEqual(2, len(a_greater)) self.assertEqual(2, len(b_greater)) - self.assertEqual(1, len(matches)) \ No newline at end of file + self.assertEqual(1, len(matches)) + + def test_can_match_points(self): + item1 = DatasetItem(id=1, annotations=[ + Points([1, 2, 2, 0, 1, 1], label=0), + + Points([3, 5, 5, 7, 5, 3], label=0), + ]) + item2 = DatasetItem(id=2, annotations=[ + Points([1.5, 2, 2, 0.5, 1, 1.5], label=0), + + Points([5, 7, 7, 7, 7, 5], label=0), + ]) + + result = DistanceComparator().match_points(item1, item2) + + matches, mismatches, a_greater, b_greater = result + self.assertEqual(1, len(a_greater)) + self.assertEqual(1, len(b_greater)) + self.assertEqual(1, len(matches)) + self.assertEqual(0, 
len(mismatches)) + +class ExactComparatorTest(TestCase): + def test_class_comparison(self): + a = Dataset.from_iterable([], categories=['a', 'b', 'c']) + b = Dataset.from_iterable([], categories=['b', 'c']) + + comp = ExactComparator() + _, _, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(1, len(errors), errors) + + def test_item_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, subset='train'), + DatasetItem(id=2, subset='test', attributes={'x': 1}), + ], categories=['a', 'b', 'c']) + + b = Dataset.from_iterable([ + DatasetItem(id=2, subset='test'), + DatasetItem(id=3), + ], categories=['a', 'b', 'c']) + + comp = ExactComparator() + _, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b) + + self.assertEqual({('1', 'train')}, a_extra_items) + self.assertEqual({('3', '')}, b_extra_items) + self.assertEqual(1, len(errors), errors) + + def test_annotation_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('hello'), # unmatched + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Bbox(5, 6, 7, 8, group=5), # unmatched + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator() + matched, unmatched, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(6, len(matched), matched) + self.assertEqual(2, len(unmatched), unmatched) + self.assertEqual(0, 
len(errors), errors) + + def test_image_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=11, image=np.ones((5, 4, 3)), annotations=[ + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=12, image=np.ones((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=13, image=np.ones((5, 4, 3)), annotations=[ + Bbox(9, 10, 11, 12), # mismatch + ]), + + DatasetItem(id=14, image=np.zeros((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ], attributes={ 'a': 1 }), + + DatasetItem(id=15, image=np.zeros((5, 5, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=21, image=np.ones((5, 4, 3)), annotations=[ + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=22, image=np.ones((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=23, image=np.ones((5, 4, 3)), annotations=[ + Bbox(10, 10, 11, 12), # mismatch + ]), + + DatasetItem(id=24, image=np.zeros((5, 4, 3)), annotations=[ + Bbox(6, 6, 7, 8), # 1 ann missing, mismatch + ], attributes={ 'a': 2 }), + + DatasetItem(id=25, image=np.zeros((4, 4, 3)), annotations=[ + Bbox(6, 6, 7, 8), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator(match_images=True) + matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = \ + comp.compare_datasets(a, b) + + self.assertEqual(3, len(matched_ann), matched_ann) + self.assertEqual(5, len(unmatched_ann), unmatched_ann) + self.assertEqual(1, len(a_unmatched), a_unmatched) + self.assertEqual(1, len(b_unmatched), b_unmatched) + self.assertEqual(1, len(errors), errors) \ No newline at end of file