From 06c5fcb619a0e35982c0d3eee9ec30907938bbc2 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 1 May 2024 23:08:07 +0100 Subject: [PATCH] Mistral testing (#888) Co-authored-by: ebsmothers --- .../llama2/scripts/compare_attention.py | 2 +- .../models/mistral/scripts/README.md | 13 + .../models/mistral/scripts/__init__.py | 5 + .../mistral/scripts/compare_feed_forward.py | 65 ++++ .../models/mistral/scripts/compare_mistral.py | 184 +++++++++++ .../mistral/scripts/mistral_reference.py | 285 ++++++++++++++++++ .../mistral/scripts/mistral_test_config.py | 23 ++ tests/torchtune/models/test_mistral.py | 49 +++ 8 files changed, 625 insertions(+), 1 deletion(-) create mode 100644 tests/torchtune/models/mistral/scripts/README.md create mode 100644 tests/torchtune/models/mistral/scripts/__init__.py create mode 100644 tests/torchtune/models/mistral/scripts/compare_feed_forward.py create mode 100644 tests/torchtune/models/mistral/scripts/compare_mistral.py create mode 100644 tests/torchtune/models/mistral/scripts/mistral_reference.py create mode 100644 tests/torchtune/models/mistral/scripts/mistral_test_config.py create mode 100644 tests/torchtune/models/test_mistral.py diff --git a/tests/torchtune/models/llama2/scripts/compare_attention.py b/tests/torchtune/models/llama2/scripts/compare_attention.py index 985bf5111..a140fb082 100644 --- a/tests/torchtune/models/llama2/scripts/compare_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_attention.py @@ -38,7 +38,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class Attention(nn.Module): - def __init__(self, n_heads, n_kv_heads, dim): + def __init__(self, n_heads: int, n_kv_heads: int, dim: int): super().__init__() self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads self.n_heads = n_heads diff --git a/tests/torchtune/models/mistral/scripts/README.md b/tests/torchtune/models/mistral/scripts/README.md new file mode 100644 index 000000000..6b61ffa9f --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/README.md @@ -0,0 +1,13 @@ +## Verifying correctness +This directory compares the current implementation of `mistral` to the reference implementation at https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py. Additionally, `torchtune.models.mistral._component_builders.mistral_mlp` is compared in `tests.torchtune.models.mistral.scripts.compare_feed_forward.py` + +Since `torchtune.models.mistral` shares nearly all components with `torchtune.models.llama2`, please see `tests.torchtune.models.llama2.scripts` for comparison scripts for individual components. + +## Running the scripts + +You can run the scripts using the following command as an example. +Each script should print out the value being used in the associated unit tests. + +``` +python3 -m tests.torchtune.models.mistral.scripts.compare_mistral +``` diff --git a/tests/torchtune/models/mistral/scripts/__init__.py b/tests/torchtune/models/mistral/scripts/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
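Note on running the new scripts: the feed-forward comparison added below can be invoked the same way as the example shown in the README above. A typical invocation (module path exactly as added in this patch; the `--embed_dim` / `--intermediate_dim` flags default to the values in `MistralTestConfig`) would be:

```
python3 -m tests.torchtune.models.mistral.scripts.compare_feed_forward
```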
diff --git a/tests/torchtune/models/mistral/scripts/compare_feed_forward.py b/tests/torchtune/models/mistral/scripts/compare_feed_forward.py new file mode 100644 index 000000000..baae85be3 --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/compare_feed_forward.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from tests.test_utils import fixed_init_model +from tests.torchtune.models.mistral.scripts.mistral_reference import FeedForward + +from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig + +from torchtune.models.mistral._component_builders import mistral_mlp + + +def compare_feed_forward(embed_dim: int, intermediate_dim: int) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(MistralTestConfig.SEED) + + # generate input tensor used by both implementations + input_t = torch.randn(1, embed_dim) + + # reference implementation + ff_ref = FeedForward(dim=embed_dim, hidden_dim=intermediate_dim) + fixed_init_model(ff_ref) + + with torch.no_grad(): + ff_out_ref = ff_ref(input_t) + + ff = mistral_mlp(embed_dim, intermediate_dim) + fixed_init_model(ff) + + with torch.no_grad(): + ff_out = ff(input_t) + + torch.testing.assert_close(ff_out, ff_out_ref, atol=1e-5, rtol=1e-5) + print(f"ff_out.mean(): {ff_out.mean()}") + print(f"ff_out.max(): {ff_out.max()}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare FeedForward implementations") + parser.add_argument( + "--embed_dim", + type=int, + default=MistralTestConfig.EMBED_DIM, + help="Embedding dimension for self-attention", + ) + parser.add_argument( + "--intermediate_dim", + type=int, + default=MistralTestConfig.INTERMEDIATE_DIM, + help="Intermediate dimension for MLP", + ) + + args = parser.parse_args() + + compare_feed_forward(args.embed_dim, args.intermediate_dim) diff --git a/tests/torchtune/models/mistral/scripts/compare_mistral.py b/tests/torchtune/models/mistral/scripts/compare_mistral.py new file mode 100644 index 000000000..a861a2fd6 --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/compare_mistral.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
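+
+# Compares the torchtune implementation of the mistral decoder against the
+# reference implementation in mistral_reference.py: both models are
+# initialized with identical fixed weights and run on the same input, and the
+# script prints the output means used as expected values in
+# tests/torchtune/models/test_mistral.py.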
+
+import torch
+
+from tests.test_utils import fixed_init_model
+from tests.torchtune.models.mistral.scripts.mistral_reference import Transformer
+from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig
+
+from torchtune.models.mistral import mistral
+
+
+def compare_decoder(
+    bsz: int,
+    vocab_size: int,
+    seq_len: int,
+    embed_dim: int,
+    intermediate_dim: int,
+    n_layers: int,
+    num_heads: int,
+    num_kv_heads: int,
+    max_seq_len: int,
+    rope_base: int,
+    norm_eps: float,
+) -> None:
+    # make sure we have the right seed for generating outputs
+    # this should match up the seed value set in the corresponding
+    # unit test
+    torch.manual_seed(MistralTestConfig.SEED)
+
+    head_dim = embed_dim // num_heads
+
+    # generate input tensor used by both implementations
+    x_input = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len))
+
+    # current implementation; initialize with constant to compare outputs
+    mistral_model = mistral(
+        vocab_size=vocab_size,
+        embed_dim=embed_dim,
+        num_layers=n_layers,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        max_seq_len=max_seq_len,
+        intermediate_dim=intermediate_dim,
+        norm_eps=norm_eps,
+        rope_base=rope_base,
+    )
+    fixed_init_model(mistral_model)
+
+    with torch.no_grad():
+        mistral_model_out = mistral_model(x_input)
+
+    # initialize reference implementation with constant weights
+    ref_mistral_model = Transformer(
+        vocab_size=vocab_size,
+        n_layers=n_layers,
+        n_heads=num_heads,
+        head_dim=head_dim,
+        dim=embed_dim,
+        n_kv_heads=num_kv_heads,
+        hidden_dim=intermediate_dim,
+        max_seq_len=max_seq_len,
+        rope_base=rope_base,
+        norm_eps=norm_eps,
+    )
+
+    # map torchtune parameter names onto the reference implementation's names
+    mapped_sd = {}
+    for k, v in mistral_model.state_dict().items():
+        new_k = k.replace("attn", "attention")
+        new_k = (
+            new_k.replace("q_proj", "wq")
+            .replace("k_proj", "wk")
+            .replace("v_proj", "wv")
+            .replace("output_proj", "wo")
+        )
+        new_k = new_k.replace("mlp", "feed_forward")
+        new_k = new_k.replace("feed_forward_norm.scale", "ffn_norm.weight")
+        new_k = new_k.replace("sa_norm.scale", "attention_norm.weight")
+
+        new_k = new_k.replace("norm.scale", "norm.weight")
+        mapped_sd[new_k] = v
+
+    ref_mistral_model.load_state_dict(mapped_sd)
+
+    with torch.no_grad():
+        ref_mistral_model_out = ref_mistral_model(x_input, torch.arange(seq_len))
+
+    # expected value used in the unit test: torch.tensor(18.2749)
+    print(f"mistral_model_out.mean(): {mistral_model_out.mean()}")
+    print(f"ref_mistral_model_out.mean(): {ref_mistral_model_out.mean()}")
+
+    torch.testing.assert_close(
+        mistral_model_out, ref_mistral_model_out, atol=1e-2, rtol=1e-2
+    )
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Compare Decoder implementations")
+    parser.add_argument(
+        "--bsz",
+        type=int,
+        default=MistralTestConfig.BSZ,
+        help="Batch size of input tensor",
+    )
+    parser.add_argument(
+        "--seq_len",
+        type=int,
+        default=MistralTestConfig.SEQ_LEN,
+        help="input sequence length",
+    )
+    parser.add_argument(
+        "--vocab_size",
+        type=int,
+        default=MistralTestConfig.VOCAB_SIZE,
+        help="vocab size",
+    )
+    parser.add_argument(
+        "--embed_dim",
+        type=int,
+        default=MistralTestConfig.EMBED_DIM,
+        help="Embedding dimension used to compute the dim for RoPE",
+    )
+    parser.add_argument(
+        "--intermediate_dim",
+        type=int,
+        default=MistralTestConfig.INTERMEDIATE_DIM,
+        help="Intermediate dimension for MLP",
+    )
+    parser.add_argument(
+        "--num_layers",
+        type=int,
+        default=MistralTestConfig.NUM_LAYERS,
+        help="number of transformer layers",
+    )
+    parser.add_argument(
+
"--num_heads", + type=int, + default=MistralTestConfig.NUM_HEADS, + help="Number of heads in the attention layer", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + default=MistralTestConfig.NUM_KV_HEADS, + help="Number of key/value heads in the attention layer", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=MistralTestConfig.MAX_SEQ_LEN, + help="max sequence length", + ) + parser.add_argument( + "--norm_eps", + type=float, + default=MistralTestConfig.NORM_EPS, + help="RMSNorm epsilon", + ) + parser.add_argument( + "--rope_base", + type=float, + default=MistralTestConfig.ROPE_BASE, + help="Base for the rotary positional embeddings", + ) + args = parser.parse_args() + + compare_decoder( + args.bsz, + args.vocab_size, + args.seq_len, + args.embed_dim, + args.intermediate_dim, + args.num_layers, + args.num_heads, + args.num_kv_heads, + args.max_seq_len, + args.rope_base, + args.norm_eps, + ) diff --git a/tests/torchtune/models/mistral/scripts/mistral_reference.py b/tests/torchtune/models/mistral/scripts/mistral_reference.py new file mode 100644 index 000000000..79134aace --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/mistral_reference.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from torch import nn + +""" +Reference mistral implementation from the official repo: +https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py + +Components are copied here with minimal modifications. +""" + + +""" +Reference implementation of Attention from: +https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py + +Note, there's another implementation in the same repo which uses xformers for attention: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L60 + +The implementation for this test uses `one_file_ref.py` since the xformers attention implementation +expects the input `[b, s, ...]` to be flattened `[b * s, ...]` which makes comparison difficult. + +Replicating code here to minimize dependencies. The code is modified to +remove dependencies from xformers and features like KV Caching. 
+""" + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): + keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) + values = torch.repeat_interleave(values, repeats=repeats, dim=2) + return keys, values + + +class Attention(nn.Module): + def __init__(self, n_heads: int, head_dim: int, dim: int, n_kv_heads: int): + super().__init__() + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = head_dim**-0.5 + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # removed positions as it was only used for cache retrieval + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + key, value = repeat_kv(xk, xv, self.repeats) + + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # scores : [bsz, n_heads, seqlen | 1, seqlen] + scores = torch.matmul(query, key.transpose(2, 3)) * self.scale + print(scores.mean()) + if mask is not None: + scores += mask[None, None, ...] + print(scores.mean()) + scores = scores.float() + scores = nn.functional.softmax(scores, dim=-1).type_as(query) + output = torch.matmul(scores, value) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + return self.wo(output) + + +""" +Reference implementation of RoPE from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L47 + +The original code structures this as stand-alone functions instead of +a class. Replicating code here to minimize dependencies. 
+""" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +""" +Reference impementation of FeedForward from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L152 + +The original code structures this as stand-alone functions in +`torchtune.models.mistral._component_builders.mistral_mlp` instead of +a standalone class. +""" + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +""" +Reference implementation of TransformerBlock from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L190 +""" + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class TransformerBlock(nn.Module): + def __init__( + self, + n_heads: int, + head_dim: int, + dim: int, + n_kv_heads: int, + hidden_dim: int, + norm_eps: float, + ): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.attention = Attention( + n_heads=n_heads, head_dim=head_dim, dim=dim, n_kv_heads=n_kv_heads + ) + self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim) + self.attention_norm = RMSNorm(dim=dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor] + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, mask) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +""" +Reference implementation of Transformer from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L217 +""" + + +class Transformer(nn.Module): + def __init__( + self, + 
vocab_size: int, + n_layers: int, + n_heads: int, + head_dim: int, + dim: int, + n_kv_heads: int, + hidden_dim: int, + max_seq_len: int, + rope_base: int, + norm_eps: float, + ): + super().__init__() + self.vocab_size = vocab_size + self.n_layers = n_layers + assert self.vocab_size > 0 + + self.tok_embeddings = nn.Embedding(vocab_size, dim) + + self.layers = torch.nn.ModuleList( + [ + TransformerBlock( + n_heads=n_heads, + head_dim=head_dim, + dim=dim, + n_kv_heads=n_kv_heads, + hidden_dim=hidden_dim, + norm_eps=norm_eps, + ) + for _ in range(n_layers) + ] + ) + + self.norm = RMSNorm(dim, eps=norm_eps) + + self.output = nn.Linear(dim, vocab_size, bias=False) + + # our RoPE implementation is a bit different from the reference: + # mistral hardcodes max_seq_len and uses a `positions` argument + # in forward to index `freqs_cis` for the current sequence length + # before using it in the attention layer. + + self.freqs_cis = precompute_freqs_cis( + head_dim, max_seq_len * 2, theta=rope_base + ) # removed .to("cuda") + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor): + _, seqlen = input_ids.shape + h = self.tok_embeddings(input_ids) + freqs_cis = self.freqs_cis[positions] + mask: Optional[torch.Tensor] = None + if input_ids.shape[1] > 1: + seqlen = input_ids.shape[1] + tensor = torch.full( + (seqlen, seqlen), + dtype=h.dtype, + fill_value=1, + device=h.device, + ) + mask = torch.tril(tensor, diagonal=0).to(h.dtype) + # removed mask banding + mask = torch.triu(mask, diagonal=-1) # setting sliding window to 1 + mask = torch.log(mask) + for layer in self.layers: + h = layer(h, freqs_cis, mask) + + return self.output(self.norm(h)).float() diff --git a/tests/torchtune/models/mistral/scripts/mistral_test_config.py b/tests/torchtune/models/mistral/scripts/mistral_test_config.py new file mode 100644 index 000000000..0084d3dcc --- /dev/null +++ b/tests/torchtune/models/mistral/scripts/mistral_test_config.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class MistralTestConfig: + BSZ = 2 + SEQ_LEN = 128 + EMBED_DIM = 64 + VOCAB_SIZE = 512 + NUM_LAYERS = 4 + NUM_HEADS = 4 + NUM_KV_HEADS = 2 + INTERMEDIATE_DIM = 512 + MAX_SEQ_LEN = 256 + ROPE_BASE = 10000 + NORM_EPS = 1e-5 + SEED = 16 diff --git a/tests/torchtune/models/test_mistral.py b/tests/torchtune/models/test_mistral.py new file mode 100644 index 000000000..9fb5505fc --- /dev/null +++ b/tests/torchtune/models/test_mistral.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
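+
+# The expected values in these tests are generated by the comparison scripts
+# under tests/torchtune/models/mistral/scripts/ (see compare_mistral.py),
+# which run a reference implementation with the same fixed weights and inputs.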
+ +import pytest +import torch +from tests.test_utils import fixed_init_model +from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig +from torchtune.models.mistral import mistral +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(MistralTestConfig.SEED) + + +class TestMistral: + @pytest.fixture + def inputs(self): + return torch.randint( + 0, + MistralTestConfig.VOCAB_SIZE, + (MistralTestConfig.BSZ, MistralTestConfig.SEQ_LEN), + ) + + def test_forward(self, inputs): + model = mistral( + vocab_size=MistralTestConfig.VOCAB_SIZE, + embed_dim=MistralTestConfig.EMBED_DIM, + num_heads=MistralTestConfig.NUM_HEADS, + num_layers=MistralTestConfig.NUM_LAYERS, + num_kv_heads=MistralTestConfig.NUM_KV_HEADS, + max_seq_len=MistralTestConfig.MAX_SEQ_LEN, + intermediate_dim=MistralTestConfig.INTERMEDIATE_DIM, + norm_eps=MistralTestConfig.NORM_EPS, + rope_base=MistralTestConfig.ROPE_BASE, + ) + fixed_init_model(model) + actual = model(inputs) + expected = torch.tensor(18.2749) + assert actual.shape == ( + MistralTestConfig.BSZ, + MistralTestConfig.SEQ_LEN, + MistralTestConfig.VOCAB_SIZE, + ) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
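A note on the parameter-name mapping used in compare_mistral.py above: the renaming logic can be isolated into a small helper, sketched below. The helper name `map_key` and the example key are illustrative only; the chain of `str.replace` calls is copied verbatim from the script.

```
# Hypothetical helper isolating the state-dict key mapping from
# compare_mistral.py: it renames torchtune parameter names to the names
# expected by the reference Transformer in mistral_reference.py.


def map_key(k: str) -> str:
    new_k = k.replace("attn", "attention")
    new_k = (
        new_k.replace("q_proj", "wq")
        .replace("k_proj", "wk")
        .replace("v_proj", "wv")
        .replace("output_proj", "wo")
    )
    new_k = new_k.replace("mlp", "feed_forward")
    new_k = new_k.replace("feed_forward_norm.scale", "ffn_norm.weight")
    new_k = new_k.replace("sa_norm.scale", "attention_norm.weight")
    new_k = new_k.replace("norm.scale", "norm.weight")
    return new_k


if __name__ == "__main__":
    # e.g. a torchtune attention projection maps onto the reference's "wq":
    print(map_key("layers.0.attn.q_proj.weight"))
    # -> layers.0.attention.wq.weight
```

This renaming is what allows `ref_mistral_model.load_state_dict(mapped_sd)` in compare_mistral.py to load the torchtune weights directly into the reference model.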