Skip to content

Commit

Permalink
Az/fix no dump default attrs (cvat-ai#656)
Browse files Browse the repository at this point in the history
* fill absent attributes by default values during annotation save
* fill absent attributes by default values during init from db
* fixed tests
* updated changelog, added some coments, minor fixes
  • Loading branch information
azhavoro authored and nmanovic committed Aug 24, 2019
1 parent 7fb7ba1 commit fc2b9c9
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 39 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Installation of CVAT with OpenVINO on the Windows platform
- Background color was always black in utils/mask/converter.py
- Exception in attribute annotation mode when a label are switched to a value without any attributes
- Handling of wrong labelamp json file in auto annotation (https://github.com/opencv/cvat/issues/554)
- Handling of wrong labelamp json file in auto annotation (<https://github.com/opencv/cvat/issues/554>)
- No default attributes in dumped annotation (<https://github.com/opencv/cvat/issues/601>)

### Security
-
Expand Down
53 changes: 45 additions & 8 deletions cvat/apps/engine/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
from enum import Enum
from collections import OrderedDict
from django.utils import timezone
from PIL import Image

Expand Down Expand Up @@ -192,9 +193,21 @@ def __init__(self, pk, user):
self.logger = slogger.job[self.db_job.id]
self.db_labels = {db_label.id:db_label
for db_label in db_segment.task.label_set.all()}
self.db_attributes = {db_attr.id:db_attr
for db_attr in models.AttributeSpec.objects.filter(
label__task__id=db_segment.task.id)}

self.db_attributes = {}
for db_label in self.db_labels.values():
self.db_attributes[db_label.id] = {
"mutable": OrderedDict(),
"immutable": OrderedDict(),
"all": OrderedDict(),
}
for db_attr in db_label.attributespec_set.all():
if db_attr.mutable:
self.db_attributes[db_label.id]["mutable"][db_attr.id] = db_attr
else:
self.db_attributes[db_label.id]["immutable"][db_attr.id] = db_attr

self.db_attributes[db_label.id]["all"][db_attr.id] = db_attr

def reset(self):
self.ir_data.reset()
Expand All @@ -214,7 +227,7 @@ def _save_tracks_to_db(self, tracks):

for attr in track_attributes:
db_attrval = models.LabeledTrackAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.track_id = len(db_tracks)
db_track_attrvals.append(db_attrval)
Expand All @@ -228,7 +241,7 @@ def _save_tracks_to_db(self, tracks):

for attr in shape_attributes:
db_attrval = models.TrackedShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_shape_attrvals.append(db_attrval)
Expand Down Expand Up @@ -295,8 +308,9 @@ def _save_shapes_to_db(self, shapes):

for attr in attributes:
db_attrval = models.LabeledShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))

db_attrval.shape_id = len(db_shapes)
db_attrvals.append(db_attrval)

Expand Down Expand Up @@ -335,7 +349,7 @@ def _save_tags_to_db(self, tags):

for attr in attributes:
db_attrval = models.LabeledImageAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.tag_id = len(db_tags)
db_attrvals.append(db_attrval)
Expand All @@ -350,7 +364,7 @@ def _save_tags_to_db(self, tags):
)

for db_attrval in db_attrvals:
db_attrval.tag_id = db_tags[db_attrval.tag_id].id
db_attrval.image_id = db_tags[db_attrval.tag_id].id

bulk_create(
db_model=models.LabeledImageAttributeVal,
Expand Down Expand Up @@ -436,6 +450,16 @@ def delete(self, data=None):
self._delete(data)
self._commit()

@staticmethod
def _extend_attributes(attributeval_set, attribute_specs):
shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
for db_attr_spec in attribute_specs:
if db_attr_spec.id not in shape_attribute_specs_set:
attributeval_set.append(OrderedDict([
('spec_id', db_attr_spec.id),
('value', db_attr_spec.default_value),
]))

def _init_tags_from_db(self):
db_tags = self.db_job.labeledimage_set.prefetch_related(
"label",
Expand All @@ -461,6 +485,11 @@ def _init_tags_from_db(self):
},
field_id='id',
)

for db_tag in db_tags:
self._extend_attributes(db_tag.labeledimageattributeval_set,
self.db_attributes[db_tag.label_id]["all"].values())

serializer = serializers.LabeledImageSerializer(db_tags, many=True)
self.ir_data.tags = serializer.data

Expand Down Expand Up @@ -493,6 +522,9 @@ def _init_shapes_from_db(self):
},
field_id='id',
)
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self.db_attributes[db_shape.label_id]["all"].values())

serializer = serializers.LabeledShapeSerializer(db_shapes, many=True)
self.ir_data.shapes = serializer.data
Expand Down Expand Up @@ -558,10 +590,15 @@ def _init_tracks_from_db(self):
# A result table can consist many equal rows for track/shape attributes
# We need filter unique attributes manually
db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
self._extend_attributes(db_track.labeledtrackattributeval_set,
self.db_attributes[db_track.label_id]["immutable"].values())

for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"])
)
self._extend_attributes(db_shape["trackedshapeattributeval_set"],
self.db_attributes[db_track.label_id]["mutable"].values())

serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
self.ir_data.tracks = serializer.data
Expand Down
101 changes: 71 additions & 30 deletions cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ def _create_task(self, owner, assignee):
"mutable": False,
"input_type": "select",
"default_value": "mazda",
"values": ["bmw", "mazda", "reno"]
"values": ["bmw", "mazda", "renault"]
},
{
"name": "parked",
Expand Down Expand Up @@ -1212,6 +1212,27 @@ def _create_task(self, owner, assignee):

return (task, jobs)

@staticmethod
def _get_default_attr_values(task):
default_attr_values = {}
for label in task["labels"]:
default_attr_values[label["id"]] = {
"mutable": [],
"immutable": [],
"all": [],
}
for attr in label["attributes"]:
default_value = {
"spec_id": attr["id"],
"value": attr["default_value"],
}
if attr["mutable"]:
default_attr_values[label["id"]]["mutable"].append(default_value)
else:
default_attr_values[label["id"]]["immutable"].append(default_value)
default_attr_values[label["id"]]["all"].append(default_value)
return default_attr_values

def _put_api_v1_jobs_id_data(self, jid, user, data):
with ForceLogin(user, self.client):
response = self.client.put("/api/v1/jobs/{}/annotations".format(jid),
Expand Down Expand Up @@ -1288,7 +1309,7 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
],
"points": [1.0, 2.1, 100, 300.222],
Expand All @@ -1310,7 +1331,12 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
Expand All @@ -1319,14 +1345,10 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
}
"value": task["labels"][0]["attributes"][1]["default_value"]
},
]
},
{
Expand Down Expand Up @@ -1357,13 +1379,18 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
},
]
}

default_attr_values = self._get_default_attr_values(task)
response = self._put_api_v1_jobs_id_data(job["id"], annotator, data)
data["version"] += 1 # need to update the version
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)

response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data)

response = self._delete_api_v1_jobs_id_data(job["id"], annotator)
Expand Down Expand Up @@ -1402,7 +1429,7 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
],
"points": [1.0, 2.1, 100, 300.222],
Expand All @@ -1424,7 +1451,12 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
Expand All @@ -1433,14 +1465,10 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
}
"value": task["labels"][0]["attributes"][1]["default_value"]
},
]
},
{
Expand Down Expand Up @@ -1479,6 +1507,9 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):

response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data)

data = response.data
Expand Down Expand Up @@ -1576,7 +1607,7 @@ def _run_api_v1_jobs_id_annotations(self, owner, assignee, annotator):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
Expand Down Expand Up @@ -1733,7 +1764,12 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
Expand All @@ -1742,13 +1778,9 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
Expand Down Expand Up @@ -1782,10 +1814,15 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
}
response = self._put_api_v1_tasks_id_annotations(task["id"], annotator, data)
data["version"] += 1

self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)

default_attr_values = self._get_default_attr_values(task)
response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)

Expand Down Expand Up @@ -1847,7 +1884,12 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
Expand All @@ -1856,13 +1898,9 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
Expand Down Expand Up @@ -1901,6 +1939,9 @@ def _run_api_v1_tasks_id_annotations(self, owner, assignee, annotator):
self._check_response(response, data)

response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)

Expand Down

0 comments on commit fc2b9c9

Please sign in to comment.