Skip to content

Commit

Permalink
[DETA] fix backbone freeze/unfreeze function (huggingface#27843)
Browse files Browse the repository at this point in the history
* [DETA] fix freeze/unfreeze function

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <[email protected]>

* add freeze/unfreeze test case in DETA

* fix type

* fix typo 2

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
SangbumChoi and ArthurZucker authored Dec 11, 2023
1 parent df5c5c6 commit 235be08
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,14 +1414,12 @@ def get_encoder(self):
def get_decoder(self):
return self.decoder

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(False)

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
def unfreeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(True)

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
Expand Down
28 changes: 28 additions & 0 deletions tests/models/deta/test_modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ def create_and_check_deta_model(self, config, pixel_values, pixel_mask, labels):

self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))

def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()

model.freeze_backbone()

for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(False, param.requires_grad)

def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()

model.unfreeze_backbone()

for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(True, param.requires_grad)

def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DetaForObjectDetection(config=config)
model.to(torch_device)
Expand Down Expand Up @@ -250,6 +270,14 @@ def test_deta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_model(*config_and_inputs)

def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)

def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)

def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
Expand Down

0 comments on commit 235be08

Please sign in to comment.