Improve OWL-ViT postprocessing (huggingface#20980)
* add post_process_object_detection method

* style changes
alaradirik authored Jan 3, 2023
1 parent e901914 commit cd24578
Showing 6 changed files with 87 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/owlvit.mdx
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTImageProcessor
    - preprocess
    - post_process
    - post_process_object_detection
    - post_process_image_guided_detection

## OwlViTFeatureExtractor
64 changes: 63 additions & 1 deletion src/transformers/models/owlvit/image_processing_owlvit.py
@@ -14,7 +14,8 @@
# limitations under the License.
"""Image processor class for OwlViT"""

from typing import Dict, List, Optional, Union
import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

@@ -344,6 +345,12 @@ def post_process(self, outputs, target_sizes):
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        warnings.warn(
            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_object_detection`",
            FutureWarning,
        )

        logits, boxes = outputs.logits, outputs.pred_boxes

        if len(logits) != len(target_sizes):
@@ -367,6 +374,61 @@

        return results

    def post_process_object_detection(
        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
    ):
        """
        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        logits, boxes = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)
        labels = probs.indices

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results

    # TODO: (Amy) Make compatible with other frameworks
    def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
        """
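For reference, a minimal usage sketch of the new method via the processor (illustrative only, not part of the commit; the "google/owlvit-base-patch32" checkpoint is an assumption, and the image URL is the COCO sample used in the test file below):

import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Assumed checkpoint; any OWL-ViT detection checkpoint should work the same way
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (height, width) of each original image, used to rescale the normalized boxes
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
# Each entry holds the already-filtered "scores", "labels" and corner-format "boxes"
print(results[0]["boxes"])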
22 changes: 11 additions & 11 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
            bounding boxes.
            possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
        target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual target image in the batch
            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
            unnormalized bounding boxes.
            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
            retrieve the unnormalized bounding boxes.
        query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual query image in the batch
            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
            unnormalized bounding boxes.
            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
            retrieve the unnormalized bounding boxes.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
            image embeddings for each patch.
@@ -1644,18 +1644,18 @@ def forward(
        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.Tensor([image.size[::-1]])
        >>> # Convert outputs (bounding boxes and class logits) to COCO API
        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
        >>> results = processor.post_process_object_detection(
        ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
        ... )
        >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
        >>> text = texts[i]
        >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
        >>> score_threshold = 0.1
        >>> for box, score, label in zip(boxes, scores, labels):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     if score >= score_threshold:
        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
        ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
        ```"""
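To make the box post-processing concrete, here is a toy sketch of the coordinate math (assumed values, not from the commit), mirroring what `center_to_corners_format` and the `scale_fct` rescaling inside `post_process_object_detection` compute:

import torch

# One box in normalized (center_x, center_y, width, height) format
boxes = torch.tensor([[[0.5, 0.5, 0.2, 0.4]]])  # shape (batch, num_boxes, 4)

# Center format -> corner format (x0, y0, x1, y1), still in [0, 1]
cx, cy, w, h = boxes.unbind(-1)
corners = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# Rescale by (width, height, width, height) for an assumed 480x640 (h, w) image
img_h, img_w = 480.0, 640.0
scale_fct = torch.tensor([img_w, img_h, img_w, img_h])
print(corners * scale_fct)  # tensor([[[256., 144., 384., 336.]]])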
7 changes: 7 additions & 0 deletions src/transformers/models/owlvit/processing_owlvit.py
@@ -179,6 +179,13 @@ def post_process(self, *args, **kwargs):
"""
return self.image_processor.post_process(*args, **kwargs)

def post_process_object_detection(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
to the docstring of this method for more information.
"""
return self.image_processor.post_process_object_detection(*args, **kwargs)

def post_process_image_guided_detection(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
7 changes: 3 additions & 4 deletions src/transformers/pipelines/zero_shot_object_detection.py
@@ -173,12 +173,11 @@ def postprocess(self, model_outputs, threshold=0.1, top_k=None):
        for model_output in model_outputs:
            label = model_output["candidate_label"]
            model_output = BaseModelOutput(model_output)
            outputs = self.feature_extractor.post_process(
                outputs=model_output, target_sizes=model_output["target_size"]
            outputs = self.feature_extractor.post_process_object_detection(
                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
            )[0]
            keep = outputs["scores"] >= threshold

            for index in keep.nonzero():
            for index in outputs["scores"].nonzero():
                score = outputs["scores"][index].item()
                box = self._get_bounding_box(outputs["boxes"][index][0])

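A quick sketch of the pipeline path this hunk changes (illustrative; `threshold` is now forwarded to `post_process_object_detection`, which does the score filtering itself):

from transformers import pipeline

# The default zero-shot-object-detection pipeline is OWL-ViT based
detector = pipeline("zero-shot-object-detection")
outputs = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote", "couch"],
    threshold=0.1,
)
print(outputs)  # a list of {"score", "label", "box"} dicts, one per detection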
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_zero_shot_object_detection.py
@@ -131,7 +131,8 @@ def test_large_model_pt(self):
        object_detector = pipeline("zero-shot-object-detection")

        outputs = object_detector(
            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            candidate_labels=["cat", "remote", "couch"],
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
