Skip to content

Commit

Permalink
Fix weights not properly initialized due to shape mismatch (huggingfa…
Browse files Browse the repository at this point in the history
…ce#28122)

* fix

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Dec 20, 2023
1 parent 769a954 commit 7938c8c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3957,13 +3957,14 @@ def _fix_key(key):

# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
if not ignore_mismatched_sizes:
if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
# This will only initialize submodules that are not marked as initialized by the line above.
model.apply(model._initialize_weights)

Expand Down
104 changes: 104 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,110 @@ def test_load_with_mismatched_shapes(self):
else:
new_model_without_prefix(input_ids)

def test_mismatched_shapes_have_properly_initialized_weights(self):
if not self.test_mismatched_shapes:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

configs_no_init = _config_zero_init(config)

for model_class in self.all_model_classes:
if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
continue

with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(configs_no_init)
model.save_pretrained(tmp_dir)

# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(RuntimeError):
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)

logger = logging.get_logger("transformers.modeling_utils")

with CaptureLogger(logger) as cl:
new_model = AutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)

for name, param in new_model.named_parameters():
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
class MyClass(PreTrainedModel):
config_class = PretrainedConfig

def __init__(self, config=None):
super().__init__(config if config is not None else PretrainedConfig())
self.linear = nn.Linear(10, config.num_labels, bias=True)
self.embedding = nn.Embedding(10, 10)
self.std = 1

def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
if module.bias is not None:
module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)

# Used to make sure the weights with matched shape are loaded correctly
config = PretrainedConfig()
config.num_labels = 3
model = MyClass(config=config)

# Used to make sure the weights with mismatched shape are properly initialized
set_seed(0)
config = PretrainedConfig()
config.num_labels = 4
# not to init. the weights during the creation: to match the logic in `from_pretrained`, so we can keep the
# same sequence of random ops in the execution path to allow us to compare `target_model` and `new_model` below
# for `linear` part.
with ContextManagers([no_init_weights(True)]):
target_model = MyClass(config=config)
target_model.apply(target_model._initialize_weights)

with tempfile.TemporaryDirectory() as tmpdirname:
state_dict = model.state_dict()
del state_dict["linear.weight"]

model.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

set_seed(0)
new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)

for key in new_model.state_dict().keys():
# check weight values for weights with matched shapes are identical
# (i.e. correctly loaded from the checkpoint)
if key not in ["linear.weight", "linear.bias"]:
max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
self.assertLessEqual(
max_diff.item(),
1e-6,
msg=f"the weight values for `{key}` in `new_model` and `model` are not identical",
)
else:
# check we have some mismatched shapes
self.assertNotEqual(
model.state_dict()[key].shape,
new_model.state_dict()[key].shape,
msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
)
# check the weights with mismatched shape are properly initialized
max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
self.assertLessEqual(
max_diff.item(),
1e-6,
msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
)

def test_model_is_small(self):
# Just a consistency check to make sure we are not running tests on 80M parameter models.
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down

0 comments on commit 7938c8c

Please sign in to comment.