Skip to content

Commit

Permalink
Image transforms functionality used instead (huggingface#20278)
Browse files Browse the repository at this point in the history
* Image transforms functionality used instead

* Import torch

* Import rather than copy

* Update src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py
  • Loading branch information
amyeroberts authored Nov 17, 2022
1 parent 3fad6ae commit 3a780cc
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 262 deletions.
17 changes: 9 additions & 8 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -38,13 +38,14 @@
)


if TYPE_CHECKING:
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp
if is_torch_available():
import torch

if is_tf_available():
import tensorflow as tf

if is_flax_available():
import jax.numpy as jnp


def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
from ...image_transforms import center_to_corners_format, corners_to_center_format, rgb_to_id
from ...image_utils import ImageFeatureExtractionMixin
from ...utils import TensorType, is_torch_available, is_torch_tensor, logging


if is_torch_available():
Expand All @@ -36,29 +37,6 @@
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]


# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
def corners_to_center_format(x):
"""
Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
y_1) to center format (center_x, center_y, width, height).
"""
x_transposed = x.T
x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return np.stack(b, axis=-1)


# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
def masks_to_boxes(masks):
"""
Expand Down Expand Up @@ -93,15 +71,6 @@ def masks_to_boxes(masks):
return np.stack([x_min, y_min, x_max, y_max], 1)


# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


# Copied from transformers.models.detr.feature_extraction_detr.binary_mask_to_rle
def binary_mask_to_rle(mask):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
Expand All @@ -46,6 +47,9 @@
if is_timm_available():
from timm import create_model

if is_vision_available():
from transformers.image_transforms import center_to_corners_format

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ConditionalDetrConfig"
Expand Down Expand Up @@ -2596,17 +2600,6 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area


# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
from ...image_transforms import center_to_corners_format, corners_to_center_format, rgb_to_id
from ...image_utils import ImageFeatureExtractionMixin
from ...utils import TensorType, is_torch_available, is_torch_tensor, logging


if is_torch_available():
Expand All @@ -36,29 +37,6 @@
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]


# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
def corners_to_center_format(x):
"""
Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
y_1) to center format (center_x, center_y, width, height).
"""
x_transposed = x.T
x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return np.stack(b, axis=-1)


# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
def masks_to_boxes(masks):
"""
Expand Down Expand Up @@ -93,32 +71,6 @@ def masks_to_boxes(masks):
return np.stack([x_min, y_min, x_max, y_max], 1)


# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


# Copied from transformers.models.detr.feature_extraction_detr.id_to_rgb
def id_to_rgb(id_map):
if isinstance(id_map, np.ndarray):
id_map_copy = id_map.copy()
rgb_shape = tuple(list(id_map.shape) + [3])
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
for i in range(3):
rgb_map[..., i] = id_map_copy % 256
id_map_copy //= 256
return rgb_map
color = []
for _ in range(3):
color.append(id_map % 256)
id_map //= 256
return color


class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a Deformable DETR feature extractor. Differs only in the postprocessing of object detection compared to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
requires_backends,
)
Expand All @@ -58,6 +59,9 @@
else:
MultiScaleDeformableAttention = None

if is_vision_available():
from transformers.image_transforms import center_to_corners_format


class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
Expand Down Expand Up @@ -2417,17 +2421,6 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area


# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
Expand Down
54 changes: 3 additions & 51 deletions src/transformers/models/detr/feature_extraction_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb, rgb_to_id
from ...image_utils import ImageFeatureExtractionMixin
from ...utils import TensorType, is_torch_available, is_torch_tensor, logging


if is_torch_available():
Expand All @@ -38,28 +39,6 @@
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]


# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


def corners_to_center_format(x):
"""
Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
y_1) to center format (center_x, center_y, width, height).
"""
x_transposed = x.T
x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return np.stack(b, axis=-1)


def masks_to_boxes(masks):
"""
Compute the bounding boxes around the provided panoptic segmentation masks.
Expand Down Expand Up @@ -93,33 +72,6 @@ def masks_to_boxes(masks):
return np.stack([x_min, y_min, x_max, y_max], 1)


# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
# Copyright (c) 2018, Alexander Kirillov
# All rights reserved.
def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


def id_to_rgb(id_map):
if isinstance(id_map, np.ndarray):
id_map_copy = id_map.copy()
rgb_shape = tuple(list(id_map.shape) + [3])
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
for i in range(3):
rgb_map[..., i] = id_map_copy % 256
id_map_copy //= 256
return rgb_map
color = []
for _ in range(3):
color.append(id_map % 256)
id_map //= 256
return color


def binary_mask_to_rle(mask):
"""
Args:
Expand Down
15 changes: 4 additions & 11 deletions src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
Expand All @@ -46,6 +47,9 @@
if is_timm_available():
from timm import create_model

if is_vision_available():
from transformers.image_transforms import center_to_corners_format

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DetrConfig"
Expand Down Expand Up @@ -2284,17 +2288,6 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area


# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306


Expand Down
16 changes: 3 additions & 13 deletions src/transformers/models/owlvit/feature_extraction_owlvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from transformers.image_utils import PILImageResampling

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageFeatureExtractionMixin
from ...utils import TensorType, is_torch_available, is_torch_tensor, logging


if is_torch_available():
Expand All @@ -32,17 +33,6 @@
logger = logging.get_logger(__name__)


# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)


# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
Expand Down
Loading

0 comments on commit 3a780cc

Please sign in to comment.