Improve TF weight loading, especially PT crossloading (huggingface#21792)

* First commit for the improved PT-TF weight loading

* Remove workarounds from TFEncoderDecoder tests

* Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder

* make fixup

* First attempt at visionencoderdecoder

* Disable tensorfloat32 in tests to get consistent outputs

* Quick fix to tf_vision_encoder_decoder tests

* make fixup

* Update Blenderbot tests

* Remove unused arg in modeling_tf_opt

* load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False.

* Support prefixes when loading sharded TF checkpoints

* make fixup

* Add test to load sharded models with a weight prefix

* Fix sharded weight loading test

* Add a test for transfer from a sharded checkpoint

* make fixup

* Add test to check that crossloading from PT with a prefix works

* Refactor from_pretrained in the encoderdecoder classes

* Refactor from_pretrained in the encoderdecoder classes

* missmatched -> mismatched

* Explicitly check for None

* No comments showing my very impressive and attractive knowledge of Py3.9+

* Disable TF32 across all TF tests
Rocketknight1 authored Feb 28, 2023
1 parent 871c31a commit acfb714
Showing 7 changed files with 153 additions and 154 deletions.
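The user-facing addition is a new `tf_to_pt_weight_rename` keyword on `TFPreTrainedModel.from_pretrained`, threaded through the PT-to-TF crossloading path below. A minimal usage sketch (the checkpoint path and rename rule are illustrative, not taken from this commit):

    from transformers import TFBertModel

    def rename(pt_key):
        # Hypothetical rule: the hook receives each TF weight name already converted
        # to PyTorch style and returns the key to look up in the PyTorch state dict.
        return pt_key.replace("my_prefix.", "", 1)

    model = TFBertModel.from_pretrained(
        "path/to/pytorch-checkpoint", from_pt=True, tf_to_pt_weight_rename=rename
    )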
58 changes: 50 additions & 8 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -39,7 +39,9 @@ class TransposeType(ExplicitEnum):
     CONV2D = "conv2d"
 
 
-def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
+def convert_tf_weight_name_to_pt_weight_name(
+    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
+):
     """
     Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -54,6 +56,14 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
         - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
           transposed with regards to each other
     """
+    if name_scope is not None:
+        if not tf_name.startswith(name_scope):
+            raise ValueError(
+                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
+                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
+            )
+        tf_name = tf_name[len(name_scope) :]
+        tf_name = tf_name.lstrip("/")
     tf_name = tf_name.replace(":0", "")  # device ids
     tf_name = re.sub(
         r"/[^/]*___([^/]*)/", r"/\1/", tf_name
@@ -144,7 +154,13 @@ def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf
 
 
 def load_pytorch_checkpoint_in_tf2_model(
-    tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pytorch_checkpoint_path,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch checkpoints in a TF 2.0 model"""
     try:
Expand Down Expand Up @@ -176,6 +192,8 @@ def load_pytorch_checkpoint_in_tf2_model(
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=output_loading_info,
_prefix=_prefix,
tf_to_pt_weight_rename=tf_to_pt_weight_rename,
)


@@ -189,7 +207,13 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
 
 
 def load_pytorch_weights_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch state_dict in a TF 2.0 model."""
     try:
@@ -209,11 +233,19 @@ def load_pytorch_weights_in_tf2_model(
         tf_inputs=tf_inputs,
         allow_missing_keys=allow_missing_keys,
         output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
     )
 
 
 def load_pytorch_state_dict_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load a pytorch state_dict in a TF 2.0 model."""
     import tensorflow as tf
@@ -227,8 +259,11 @@ def load_pytorch_state_dict_in_tf2_model(
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs
 
+    if _prefix is None:
+        _prefix = ""
     if tf_inputs is not None:
-        tf_model(tf_inputs, training=False)  # Make sure model is built
+        with tf.name_scope(_prefix):
+            tf_model(tf_inputs, training=False)  # Make sure model is built
     # Adapt state dict - TODO remove this and update the AWS weights files instead
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -249,8 +284,10 @@ def load_pytorch_state_dict_in_tf2_model(
     for old_key, new_key in zip(old_keys, new_keys):
         pt_state_dict[new_key] = pt_state_dict.pop(old_key)
 
-    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
-    # TF models always have a prefix, some of PyTorch models (base ones) don't
+    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
+    # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model
+    # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one
+    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
     start_prefix_to_remove = ""
     if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
         start_prefix_to_remove = tf_model.base_model_prefix + "."
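For illustration, the naming mismatch that comment describes, assuming `bert` as the `base_model_prefix` (all names are illustrative):

    # PyTorch base model key:       "embeddings.word_embeddings.weight"
    # PyTorch model-with-head key:  "bert.embeddings.word_embeddings.weight"
    # TF key (base or with head):   "tf_bert_model/bert/embeddings/word_embeddings/weight:0"
    pt_state_dict = {"embeddings.word_embeddings.weight": ...}  # a PT *base* checkpoint
    base_model_prefix = "bert"
    start_prefix_to_remove = ""
    if not any(key.startswith(base_model_prefix) for key in pt_state_dict):
        # PT keys carry no "bert." stem, so strip it from the converted TF names too
        start_prefix_to_remove = base_model_prefix + "."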
@@ -263,8 +300,13 @@ def load_pytorch_state_dict_in_tf2_model(
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
-            sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
+            sw_name,
+            start_prefix_to_remove=start_prefix_to_remove,
+            tf_weight_shape=symbolic_weight.shape,
+            name_scope=_prefix,
         )
+        if tf_to_pt_weight_rename is not None:
+            name = tf_to_pt_weight_rename(name)
 
         # Find associated numpy array in pytorch model state dict
         if name not in pt_state_dict:
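A hedged sketch of driving the lower-level loader directly with the two new parameters (the model objects and the prefix value are illustrative):

    from transformers.modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model

    tf_model = load_pytorch_weights_in_tf2_model(
        tf_model,                                  # an already-instantiated TF model
        pt_model.state_dict(),                     # weights from the matching PyTorch model
        tf_inputs=tf_model.dummy_inputs,
        allow_missing_keys=True,
        _prefix="tf_encoder_decoder_model",        # build and match under this name scope
        tf_to_pt_weight_rename=lambda name: name,  # identity hook; swap in a real rule as needed
    )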
58 changes: 38 additions & 20 deletions src/transformers/modeling_tf_utils.py
@@ -707,7 +707,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):
     return shards, index
 
 
-def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
+def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
     """
     This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
     the TF weights from the shard file accordingly to their names and shapes.
@@ -729,32 +729,35 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     """
 
     # Load the index
-    missing_keys = []
     unexpected_keys = set()
     saved_keys = set()
-    missmatched_keys = set()
+    mismatched_keys = set()
 
     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
     model_keys = set()
     model_layer_map = {}
     for i, k in enumerate(model.weights):
-        if "model." in k.name or len(k.name.split("/")) == 1:
-            layer_name = k.name
-        else:
-            layer_name = "/".join(k.name.split("/")[1:])
+        layer_name = k.name
+        if _prefix is not None and layer_name.startswith(_prefix):
+            layer_name = layer_name[len(_prefix) :]
+        layer_name = layer_name.lstrip("/")
+        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
+            layer_name = "/".join(layer_name.split("/")[1:])
         model_keys.add(layer_name)
         model_layer_map[layer_name] = i
 
     for shard_file in shard_files:
-        state_dict = tf.io.read_file(shard_file)
-        saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
-            model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
+        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
+            model,
+            model_layer_map,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
        )
         saved_keys.update(saved_weight_names_set)
         unexpected_keys.update(unexpected_keys_set)
-        missmatched_keys.update(missmatched_keys_set)
-        del state_dict
+        mismatched_keys.update(mismatched_keys_set)
         gc.collect()
 
     missing_keys = model_keys - saved_keys
@@ -768,10 +771,10 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)
 
-    return missing_keys, unexpected_keys, missmatched_keys
+    return missing_keys, unexpected_keys, mismatched_keys
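The practical effect of flipping the default to `strict=False`: loading a sharded base checkpoint into a model with a freshly initialized head now warns about the genuinely missing head weights instead of raising. A hedged usage sketch (the repo name is a placeholder):

    from transformers import TFAutoModelForSequenceClassification

    # Before this commit, a sharded TF checkpoint without classifier weights made
    # this call raise a RuntimeError; now the head is simply left randomly initialized.
    model = TFAutoModelForSequenceClassification.from_pretrained(
        "some-org/sharded-tf-base-model", num_labels=3
    )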


-def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
+def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
     Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
@@ -783,11 +786,11 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
     Returns:
         `tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the
-            shard file), one for the missmatched layers, and another one for the unexpected layers.
+            shard file), one for the mismatched layers, and another one for the unexpected layers.
     """
     saved_weight_names_set = set()
     saved_weights = {}
-    missmatched_keys = set()
+    mismatched_keys = set()
     unexpected_keys = set()
     # Read the H5 file
     try:
@@ -822,7 +825,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
                             array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                         except ValueError as e:
                             if ignore_mismatched_sizes:
-                                missmatched_keys.add(
+                                mismatched_keys.add(
                                     (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                 )
                                 continue
@@ -836,7 +839,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
 
         K.batch_set_value(weight_value_tuples)
 
-        return saved_weight_names_set, unexpected_keys, missmatched_keys
+        return saved_weight_names_set, unexpected_keys, mismatched_keys
 
     except Exception as e:
         try:
@@ -2458,6 +2461,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
+            tf_to_pt_weight_rename (`Callable`, *optional*):
+                A function that is called to transform the names of weights during the PyTorch to TensorFlow
+                crossloading process. This is not necessary for most models, but is useful to allow composite models to
+                be crossloaded correctly.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -2506,6 +2513,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
+        tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
 
         if trust_remote_code is True:
             logger.warning(
@@ -2745,7 +2753,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
                 # Load from a PyTorch checkpoint
                 return load_pytorch_checkpoint_in_tf2_model(
-                    model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
+                    model,
+                    resolved_archive_file,
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
+                    tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                 )
 
             # we might need to extend the variable scope for composite models
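For context, `load_weight_prefix` exists because composite models build their sub-models inside a shared `tf.name_scope`, so every weight name carries an extra leading prefix. A hedged illustration using the same pattern as the loading code (the prefix is illustrative):

    import tensorflow as tf
    from transformers import BertConfig, TFBertModel

    model = TFBertModel(BertConfig())
    with tf.name_scope("tf_encoder_decoder_model/encoder"):
        model(model.dummy_inputs, training=False)  # build the weights under the prefix
    print(model.weights[0].name)
    # expected to start with "tf_encoder_decoder_model/encoder/", which is exactly what
    # convert_tf_weight_name_to_pt_weight_name(..., name_scope=_prefix) strips back off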
@@ -2761,7 +2774,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 state_dict = safe_load_file(resolved_archive_file)
                 # Load from a PyTorch checkpoint
                 return load_pytorch_state_dict_in_tf2_model(
-                    model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
+                    model,
+                    state_dict,
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
                 )
 
         # 'by_name' allow us to do transfer learning by skipping/adding layers
@@ -2775,6 +2792,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                     model,
                     resolved_archive_file,
                     ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    _prefix=load_weight_prefix,
                 )
             else:
                 missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -15,9 +15,7 @@
 """ Classes to support TF Encoder-Decoder architectures"""
 
 
-import gc
-import os
-import tempfile
+import re
 import warnings
 from typing import Optional, Tuple, Union
 
@@ -306,46 +304,23 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
         ```"""
 
-        from_pt = kwargs.pop("from_pt", False)
-        if from_pt:
-            import torch
-
-            from transformers import EncoderDecoderModel
-
-            # a workaround to load from pytorch checkpoint
-            _model = EncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            config = _model.config
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                encoder_dir = os.path.join(tmpdirname, "encoder")
-                decoder_dir = os.path.join(tmpdirname, "decoder")
-                _model.encoder.save_pretrained(encoder_dir)
-                _model.decoder.save_pretrained(decoder_dir)
-
-                if hasattr(_model, "enc_to_dec_proj"):
-                    enc_to_dec_proj_kernel = tf.transpose(
-                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
-                    )
-                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
-
-                del _model
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
-                )
-                # This is only for copying some specific attributes of this particular model.
-                model.config = config
-
-                if hasattr(model, "enc_to_dec_proj"):
-                    model(model.dummy_inputs)
-                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
-                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
-
-                return model
-
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
+
+        if kwargs.get("from_pt", False):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            encoder_model_type = config.encoder.model_type
+
+            def tf_to_pt_weight_rename(tf_weight):
+                if "encoder" in tf_weight and "decoder" not in tf_weight:
+                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+                else:
+                    return tf_weight
+
+            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
     @classmethod
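For illustration, what the closure above does when the encoder is a BERT model (so `config.encoder.model_type == "bert"`); the weight name is an example:

    import re

    encoder_model_type = "bert"  # would come from config.encoder.model_type

    def tf_to_pt_weight_rename(tf_weight):
        if "encoder" in tf_weight and "decoder" not in tf_weight:
            return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
        else:
            return tf_weight

    print(tf_to_pt_weight_rename("encoder.bert.embeddings.word_embeddings.weight"))
    # -> "encoder.embeddings.word_embeddings.weight", matching the PT EncoderDecoderModel key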
@@ -451,14 +426,6 @@ def from_encoder_decoder_pretrained(
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_encoder.get("from_pt", None):
-                del kwargs_encoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    encoder.save_pretrained(tmp_dirname)
-                    del encoder
-                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
-
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -493,14 +460,6 @@ def from_encoder_decoder_pretrained(
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_decoder.get("from_pt", None):
-                del kwargs_decoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    decoder.save_pretrained(tmp_dirname)
-                    del decoder
-                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
-
         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
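With this refactor, crossloading a PyTorch encoder-decoder checkpoint is a single call through the standard `from_pretrained` pipeline rather than a save-and-reload round trip. A hedged sketch using the checkpoint from the docstring above (assuming the repo hosts PyTorch weights):

    from transformers import TFEncoderDecoderModel

    model = TFEncoderDecoderModel.from_pretrained(
        "ydshieh/bert2bert-cnn_dailymail-fp16", from_pt=True
    )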
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_tf_opt.py
@@ -486,7 +486,7 @@ def serving(self, inputs):
 class TFOPTDecoder(tf.keras.layers.Layer):
     config_class = OPTConfig
 
-    def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+    def __init__(self, config: OPTConfig, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.padding_idx = config.pad_token_id