[Feature] Add ability to load HF checkpoints into T5 model (pytorch#1918)

* Add ability to load HF checkpoints into T5 model

* Add HuggingFace to integrations tests

* Remove duplicate code

* Revert fix

* Add setup

* Remove ability to download from remote URL

* Remove line break from docstring
joecummings authored Oct 5, 2022
1 parent 3f9c349 commit de54db6
Showing 3 changed files with 244 additions and 2 deletions.
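
For context, a minimal usage sketch of the new loading path added in torchtext/prototype/models/t5/bundler.py below; the temporary checkpoint directory and the "t5-small" identifier are illustrative, and the pattern mirrors the new integration tests:

```python
# Minimal sketch of the new HF-checkpoint loading path (paths/model name are illustrative).
import tempfile

from torchtext.prototype.models.t5.bundler import T5Bundle
from transformers import T5Model  # HF model, used here only to produce a local checkpoint

with tempfile.TemporaryDirectory() as ckpt_dir:
    # The loader expects a local directory containing config.json and pytorch_model.bin.
    T5Model.from_pretrained("t5-small").save_pretrained(ckpt_dir)

    # Build a torchtext T5Model from the HF weights; encoder_only and linear_head
    # are inferred from the keys present in the checkpoint.
    model = T5Bundle.build_model_from_huggingface_ckpt(ckpt_dir)
    model.eval()
```
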
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
@@ -24,7 +24,7 @@ jobs:
        run: |
          python -m pip install --quiet --upgrade pip
          python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
          python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest
          python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest transformers
          python setup.py install
      - name: Run integration test
        run: |
110 changes: 109 additions & 1 deletion test/integration_tests/prototype/test_models.py
@@ -1,3 +1,5 @@
import tempfile

import pytest # noqa: F401
import torch
from parameterized import parameterized, parameterized_class
@@ -14,11 +16,12 @@
    T5Conf,
    T5Transform,
)
from torchtext.prototype.models.t5.bundler import T5Bundle
from torchtext.prototype.models.t5.wrapper import T5Wrapper
from torchtext_unittest.common.assets import get_asset_path
from torchtext_unittest.common.parameterized_utils import nested_params
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase

from transformers import T5Model, T5EncoderModel, T5ForConditionalGeneration

BUNDLERS = {
"base_model": T5_BASE,
@@ -135,3 +138,108 @@ def test_t5_wrapper_checkpoint(self, name) -> None:

        output_text = model(test_text, beam_size, max_seq_len)
        self.assertEqual(output_text, expected_text)


class TestLoadFromHFCheckpoints(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])

    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
        # check that encoder layers match
        for i in range(config.num_encoder_layers + 1):
            if i < config.num_encoder_layers:
                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
                # self-attention scores
                assert torch.equal(
                    our_output["encoder_sa_scores"][i], hf_output_sa
                ), f"Mismatched self-attention scores for encoder layer {i}"
            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
            # encoder hidden states
            assert torch.equal(
                our_output["encoder_hidden_states"][i], hf_output_hs
            ), f"Mismatched hidden states for encoder layer {i}"

        if not encoder_only:
            # check that decoder layers match
            for i in range(config.num_decoder_layers + 1):
                if i < config.num_decoder_layers:
                    # self-attention scores
                    assert torch.equal(
                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
                    ), f"Mismatched self-attention scores for decoder layer {i}"
                    # cross-attention scores
                    assert torch.equal(
                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
                    ), f"Mismatched cross-attention scores for decoder layer {i}"
                # decoder hidden states
                assert torch.equal(
                    our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
                ), f"Mismatched hidden states for decoder layer {i}"

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_enc"

            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
            t5_small_enc.save_pretrained(model_path)

            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_enc(
                input_ids=self.encoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_encoder(self.encoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, encoder_only=True)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small"

            t5_small = T5Model.from_pretrained("t5-small")
            t5_small.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_gen"

            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
            t5_small_gen.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_gen(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)
134 changes: 134 additions & 0 deletions torchtext/prototype/models/t5/bundler.py
@@ -1,4 +1,6 @@
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union
from urllib.parse import urljoin
@@ -133,6 +135,138 @@ def build_model(

        return model

    @staticmethod
    def build_model_from_huggingface_ckpt(
        ckpt_path: Union[str, os.PathLike],
        *,
        freeze_model: bool = False,
        strict: bool = True,
    ) -> T5Model:
        """Build a T5Model from a HuggingFace checkpoint.
        Note: Only works with HuggingFace models saved in the PyTorch format; will not work with TensorFlow or JAX checkpoints.
        Args:
            ckpt_path (str, Path): Path to the local HF checkpoint directory (containing config.json and pytorch_model.bin).
            freeze_model (bool): Freeze the model upon loading. (Default: `False`)
            strict (bool): Load the model in strict mode. (Default: `True`)
        Returns:
            T5Model loaded with the weights of the HuggingFace checkpoint provided
        """
        config_path = f"{ckpt_path}/config.json"
        model_path = f"{ckpt_path}/pytorch_model.bin"

        with open(config_path, "r") as handle:
            config_json = json.load(handle)
        hf_weights = torch.load(model_path)

        # TODO(joecummings): find better way to determine `encoder_only` and `linear_head`
        config = T5Conf(
            encoder_only="decoder.final_layer_norm.weight" not in hf_weights.keys(),
            linear_head="lm_head.weight" in hf_weights.keys(),
            embedding_dim=config_json["d_model"],
            num_attention_heads=config_json["num_heads"],
            num_encoder_layers=config_json["num_layers"],
            num_decoder_layers=config_json["num_decoder_layers"],
            ffn_dimension=config_json["d_ff"],
        )

        t5_model = T5Model(config, freeze_model)

        t5_model_state_dict = {
            "token_embeddings.weight": hf_weights["shared.weight"],
            "norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
            "encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[
                "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
            ],
        }
        # Convert encoder layers
        for i in range(config.num_encoder_layers):
            t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[
                f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.linear2.weight"] = hf_weights[
                f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.norm1.weight"] = hf_weights[
                f"encoder.block.{i}.layer.0.layer_norm.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.norm2.weight"] = hf_weights[
                f"encoder.block.{i}.layer.1.layer_norm.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.self_attn.out_proj.weight"] = hf_weights[
                f"encoder.block.{i}.layer.0.SelfAttention.o.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.self_attn.q_proj_weight"] = hf_weights[
                f"encoder.block.{i}.layer.0.SelfAttention.q.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.self_attn.k_proj_weight"] = hf_weights[
                f"encoder.block.{i}.layer.0.SelfAttention.k.weight"
            ]
            t5_model_state_dict[f"encoder.layers.{i}.self_attn.v_proj_weight"] = hf_weights[
                f"encoder.block.{i}.layer.0.SelfAttention.v.weight"
            ]

        # Convert decoder layers if model is encoder-decoder
        if not config.encoder_only:
            t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"]
            t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[
                "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
            ]

            for i in range(config.num_decoder_layers):
                t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.linear2.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.norm1.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.0.layer_norm.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.norm2.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.2.layer_norm.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.norm3.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.1.layer_norm.weight"
                ]

                t5_model_state_dict[f"decoder.layers.{i}.self_attn.out_proj.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.0.SelfAttention.o.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.self_attn.q_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.0.SelfAttention.q.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.self_attn.k_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.0.SelfAttention.k.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.self_attn.v_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.0.SelfAttention.v.weight"
                ]

                t5_model_state_dict[f"decoder.layers.{i}.cross_attn.out_proj.weight"] = hf_weights[
                    f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.cross_attn.q_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.cross_attn.k_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"
                ]
                t5_model_state_dict[f"decoder.layers.{i}.cross_attn.v_proj_weight"] = hf_weights[
                    f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"
                ]

        # Convert language modeling head if there is one
        if config.linear_head:
            t5_model_state_dict["lm_head.weight"] = hf_weights["lm_head.weight"]

        # Load state dict into our model
        t5_model.load_state_dict(t5_model_state_dict, strict)

        return t5_model

    @property
    def config(self) -> T5Conf:
        return self._config
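
For completeness, a condensed parity check in the spirit of the new TestLoadFromHFCheckpoints tests above; the checkpoint directory path is illustrative and assumed to contain a checkpoint previously written with save_pretrained():

```python
import torch

from torchtext.prototype.models.t5.bundler import T5Bundle
from transformers import T5EncoderModel

# Illustrative: a local directory previously created with save_pretrained(...).
ckpt_dir = "/path/to/hf_t5_small_enc"

hf_encoder = T5EncoderModel.from_pretrained(ckpt_dir)
tt_encoder = T5Bundle.build_model_from_huggingface_ckpt(ckpt_dir)

input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
hf_out = hf_encoder(input_ids=input_ids, output_hidden_states=True)
tt_out = tt_encoder(input_ids)

# Final encoder hidden states should match exactly, as asserted by the new integration tests.
assert torch.equal(tt_out["encoder_hidden_states"][-1], hf_out.hidden_states[-1])
```
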
