Add Swin backbone (huggingface#20769)
* Add Swin backbone

* Remove line

* Add code example

Co-authored-by: Niels Rogge <[email protected]>
NielsRogge and Niels Rogge authored Dec 14, 2022
1 parent 94f8e21 commit 67acb07
Showing 10 changed files with 256 additions and 41 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2078,6 +2078,7 @@
_import_structure["models.swin"].extend(
[
"SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwinBackbone",
"SwinForImageClassification",
"SwinForMaskedImageModeling",
"SwinModel",
@@ -5041,6 +5042,7 @@
)
from .models.swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
SwinModel,
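With these export entries in place, the new class is importable from the top level of the package. A minimal sketch of what this enables (the checkpoint name and `out_features` values are assumptions for illustration, based on the config change later in this commit):

```python
from transformers import SwinBackbone

# SwinBackbone is now exported at the top level of `transformers`.
# `out_features` selects which stages to return as feature maps.
model = SwinBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage4"]
)
```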
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -869,6 +869,7 @@
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
]
)

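Registering `"swin"` in the backbone mapping is what lets the `AutoBackbone` API resolve a Swin checkpoint to the new class. A quick sketch of the expected behaviour:

```python
from transformers import AutoBackbone

# The auto class looks up the "swin" model type in the backbone mapping
# added above and instantiates SwinBackbone.
backbone = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
print(type(backbone).__name__)  # SwinBackbone
```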
28 changes: 20 additions & 8 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -523,7 +523,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -585,7 +584,9 @@ def forward(
shortcut = hidden_states

hidden_states = self.layernorm_before(hidden_states)

hidden_states = hidden_states.view(batch_size, height, width, channels)

# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -677,14 +678,15 @@ def forward(

hidden_states = layer_outputs[0]

hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions)
hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
else:
output_dimensions = (height, width, height, width)

stage_outputs = (hidden_states, output_dimensions)
stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

if output_attentions:
stage_outputs += layer_outputs[1:]
@@ -722,9 +724,9 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, DonutSwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -755,12 +757,22 @@ def custom_forward(*inputs):
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[1]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]

input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)

if output_hidden_states:
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
# rearrange b (h w) c -> b c h w
# here we use the original (not downsampled) height and width
reshaped_hidden_state = hidden_states_before_downsampling.view(
batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -769,7 +781,7 @@ def custom_forward(*inputs):
all_reshaped_hidden_states += (reshaped_hidden_state,)

if output_attentions:
all_self_attentions += layer_outputs[2:]
all_self_attentions += layer_outputs[3:]

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
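The encoder now tracks the hidden states both before and after each patch-merging step, and the `b (h w) c -> b c h w` comment above describes a plain view-and-permute. A toy sketch of that tensor manipulation, with made-up dimensions:

```python
import torch

# Sequence layout produced by a stage: (batch, height * width, channels).
batch_size, height, width, hidden_size = 2, 8, 8, 96
hidden_states = torch.randn(batch_size, height * width, hidden_size)

# rearrange b (h w) c -> b c h w, using the pre-downsampling height/width
reshaped = hidden_states.view(batch_size, height, width, hidden_size)
reshaped = reshaped.permute(0, 3, 1, 2)

print(reshaped.shape)  # torch.Size([2, 96, 8, 8])
```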
2 changes: 2 additions & 0 deletions src/transformers/models/swin/__init__.py
@@ -36,6 +36,7 @@
"SwinForMaskedImageModeling",
"SwinModel",
"SwinPreTrainedModel",
"SwinBackbone",
]

try:
@@ -63,6 +64,7 @@
else:
from .modeling_swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
SwinModel,
14 changes: 14 additions & 0 deletions src/transformers/models/swin/configuration_swin.py
@@ -83,6 +83,9 @@ class SwinConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
encoder_stride (`int`, *optional*, defaults to 32):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example:
@@ -125,6 +128,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_stride=32,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
@@ -151,6 +155,16 @@ def __init__(
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features


class SwinOnnxConfig(OnnxConfig):
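The new `out_features` argument is validated against the derived `stage_names`. A short sketch of how the validation behaves with the default four-stage configuration (`depths=[2, 2, 6, 2]`):

```python
from transformers import SwinConfig

config = SwinConfig(out_features=["stage2", "stage4"])
print(config.stage_names)  # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']

# An unknown stage name is rejected by the check added above.
try:
    SwinConfig(out_features=["stage5"])
except ValueError as err:
    print(err)  # Feature stage5 is not a valid feature name. Valid names are [...]
```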
146 changes: 137 additions & 9 deletions src/transformers/models/swin/modeling_swin.py
@@ -26,7 +26,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -589,7 +590,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -651,7 +651,9 @@ def forward(
shortcut = hidden_states

hidden_states = self.layernorm_before(hidden_states)

hidden_states = hidden_states.view(batch_size, height, width, channels)

# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -742,14 +744,15 @@ def forward(

hidden_states = layer_outputs[0]

hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions)
hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
else:
output_dimensions = (height, width, height, width)

stage_outputs = (hidden_states, output_dimensions)
stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

if output_attentions:
stage_outputs += layer_outputs[1:]
@@ -786,9 +789,9 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, SwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -819,12 +822,22 @@ def custom_forward(*inputs):
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[1]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]

input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)

if output_hidden_states:
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
# rearrange b (h w) c -> b c h w
# here we use the original (not downsampled) height and width
reshaped_hidden_state = hidden_states_before_downsampling.view(
batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -833,7 +846,7 @@ def custom_forward(*inputs):
all_reshaped_hidden_states += (reshaped_hidden_state,)

if output_attentions:
all_self_attentions += layer_outputs[2:]
all_self_attentions += layer_outputs[3:]

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -1214,3 +1227,118 @@ def forward(
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)


@add_start_docstrings(
"""
Swin backbone, to be used with frameworks like DETR and MaskFormer.
""",
SWIN_START_DOCSTRING,
)
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
def __init__(self, config: SwinConfig):
super().__init__(config)

self.stage_names = config.stage_names

self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.out_feature_channels = {}
self.out_feature_channels["stem"] = config.embed_dim
for i, stage in enumerate(self.stage_names[1:]):
self.out_feature_channels[stage] = num_features[i]

# Add layer norms to hidden states of out_features
hidden_states_norms = dict()
for stage, num_channels in zip(self.out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings

@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]

def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = AutoBackbone.from_pretrained(
...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
... )

>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)

>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

embedding_output, input_dimensions = self.embeddings(pixel_values)

outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=None,
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
return_dict=True,
)

hidden_states = outputs.reshaped_hidden_states

feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
batch_size, num_channels, height, width = hidden_state.shape
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
hidden_state = self.hidden_states_norms[stage](hidden_state)
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)

if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output

return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
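
For the default tiny configuration (`embed_dim=96`, `depths=[2, 2, 6, 2]`), the per-stage channel counts double at each patch-merging step, which the `channels` property exposes. A minimal sketch with a randomly initialized model (the shapes below assume the default 224x224 input size and patch size 4):

```python
import torch

from transformers import SwinBackbone, SwinConfig

config = SwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
model = SwinBackbone(config)
print(model.channels)  # [96, 192, 384, 768]

# Spatial resolution halves per stage: 224 / 4 = 56, then 28, 14, 7.
pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values)
for stage, feature_map in zip(model.out_features, outputs.feature_maps):
    print(stage, list(feature_map.shape))
# stage1 [1, 96, 56, 56]
# stage2 [1, 192, 28, 28]
# stage3 [1, 384, 14, 14]
# stage4 [1, 768, 7, 7]
```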