
Keras symbolic inputs/outputs do not implement __len__ #22

Closed
federicoparra opened this issue Feb 5, 2024 · 8 comments

federicoparra commented Feb 5, 2024

Hi! Your library is amazing, thank you so much!
I'm trying to convert this LLM https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b from PyTorch to TensorFlow. As usual, the input is dynamic. When I run:

keras_model = nobuco.pytorch_to_keras(
    model,
    args=[padded_input], kwargs=None,
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW
)

the conversion works flawlessly and the resulting Keras model produces the same result as the original model, except that the input size of the Keras model is fixed to whatever the size of padded_input was.

If instead I run the conversion like so:

keras_model = nobuco.pytorch_to_keras(
    model,
    args=[padded_input], kwargs=None,
    input_shapes={padded_input: (1, None)},
    trace_shape=True,
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW
)

then it crashes towards the end of the conversion with the error:
TypeError: Keras symbolic inputs/outputs do not implement __len__. You may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register dispatching, preventing Keras from automatically converting the API call to a lambda layer in the Functional Model. This error will also get raised if you try asserting a symbolic input/output directly.

Any pointers on what the problem might be?

AlexanderLutsenko (Owner) commented Feb 6, 2024

Hi, thanks! May I see the full script?

federicoparra (Author) commented Feb 6, 2024

Hey, I created a Google Colab so that you can check it. I included a cell where Nobuco is executed to produce a fixed-size Keras model (works fine) and another cell where it tries to convert with a dynamic size (crashes):
https://colab.research.google.com/drive/1vxSotUx_tfAl1gjEdsg542SKRlL7ypBV?usp=sharing

NOTE: you need a premium account to run it because you need a high-RAM machine (CPU only).

AlexanderLutsenko (Owner) commented Feb 6, 2024

Okay, this one is... complicated.

  • First, it turned out my implementation of torch.Tensor.expand was a bit too dynamic for TensorFlow. Got it fixed in v0.11.5.

  • Second, there were "Subgraph disconnected" warnings for the RotaryEmbedding modules.
    These mean some of the module's outputs do not depend on its inputs. In our case, the positional vectors will always be the same size, which is obviously not how it should work with dynamic inputs. Let's follow the links Nobuco provided and have a look inside:

    class RotaryEmbedding(nn.Module):
        def __init__(
            self,
            dim: int,
            max_position_embeddings: int,
            base: int = 10_000,
            device: Optional[torch.device] = None,
        ):
            super().__init__()
    
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)
    
            # Build here to make `torch.jit.trace` work.
            self._set_cos_sin_cache(
                seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
            )
    
        def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
    
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
    
        def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
            # x: [batch_size, num_heads, seq_len, head_size]
            if seq_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
            return (
                self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
                self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            )

    The module is fine: it cuts the cached tensors to the appropriate size. But the __getitem__ operation ([:, :, :seq_len, ...]) is not captured. The reason for that is a bit technical. I use an unorthodox method to capture the execution graph: wrapping PyTorch ops with decorators. Unfortunately, it does not always work. Putting a decorator on top of torch.Tensor.__getitem__ may sometimes break the internals of PyTorch, so for that one I have to follow the recommended path via tensor subclassing. I currently do not subclass tensors created inside modules, so __getitem__ is not recorded there.

    Until I fix it somehow, one possible workaround is to prepend the [] with another (decoratable) operation, for example detach:

    return (
        self.cos_cached.detach()[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached.detach()[:, :, :seq_len, ...].to(dtype=x.dtype),
    )


  • Finally, there are numerous bugs in Keras 2 which prevent some models from being serialized/deserialized [1, 2]. Many of them are fixed in Keras 3, but TensorFlow has not yet switched to it.
    If you encounter those issues, I'd recommend sticking to the SavedModel format and loading the model with TensorFlow directly:

    keras_model.save(model_path)
    
    keras_model_restored = tf.saved_model.load(model_path)
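
    For reference, here is a minimal sketch of calling the restored object through its default serving signature; the exact input and output names depend on the exported model and can be inspected as shown below:

    import tensorflow as tf

    restored = tf.saved_model.load(model_path)

    # The default signature is usually registered under this key;
    # inspect restored.signatures to confirm what was exported.
    infer = restored.signatures["serving_default"]
    print(infer.structured_input_signature)  # expected input names and shapes
    print(infer.structured_outputs)          # output names

    # Hypothetical call: the keyword must match the exported input name.
    # outputs = infer(input_ids=tf.constant(example_input))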

federicoparra (Author) commented

First of all, THANK YOU!

I went ahead and made modifications to address your points in the same Colab: https://colab.research.google.com/drive/1vxSotUx_tfAl1gjEdsg542SKRlL7ypBV?usp=sharing

  • I used pip install from the GitHub source to get the latest version of the library.
  • I introduced a cell that programmatically modifies the file modeling_stablelm_epoch.py when it is loaded from HF and adds the detach() calls (a rough sketch of that patch follows this list).
  • I saved and reloaded the model using your instructions above.
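
For reference, a rough sketch of that patch cell (the file path variable is hypothetical; the actual Colab first locates the cached modeling_stablelm_epoch.py):

from pathlib import Path

# Hypothetical path to the HF-cached modeling_stablelm_epoch.py
path = Path(modeling_file_path)

src = path.read_text()
# Prepend .detach() to the cached-tensor slicing, as suggested above
src = src.replace("self.cos_cached[", "self.cos_cached.detach()[")
src = src.replace("self.sin_cached[", "self.sin_cached.detach()[")
path.write_text(src)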

With the first two modifications, I am happy to let you know the model now works dynamically just fine, and moreover it produces the EXACT same result as its original PyTorch counterpart - VICTORY!

Now, I would like to ask you for your help with something that is probably more complex: key-value caching.

This model in fact accepts a fourth parameter, past_key_values.
When it is given, the model can generate the next token roughly 20x faster.

I see a few problems with even converting that aspect of the model:

  1. In the original PyTorch model, past_key_values is not a simple tensor but a series of nested tuples and tensors, so I can't easily tell what kind of shape I should declare for it in Nobuco.
  2. The model behaves differently when this fourth parameter is given vs. when it is not.
    In fact, in typical usage, when you send your prompt to the LLM to begin generating, you pass past_key_values = None, and the model reacts by calculating the attention values for all words in the prompt. The model returns not only the logits but also the past_key_values.
    After this first pass, and for every other generated word in the generation loop, we don't pass the whole prompt plus the last generated word as input; instead, we pass only the last generated word (as the first parameter), and as the fourth parameter we pass the past_key_values from the previous run.

Technically this means:
a) the PyTorch model must have some mechanism to detect whether past_key_values is None (i.e. whether the fourth parameter is None) and handle that case, and to do something different when the fourth parameter actually contains the key-values from the previous run;
b) when the fourth parameter is None, the model expects the input (the first parameter, i.e. the prompt) to be of any size, but when the fourth parameter is an actual past_key_values from a previous run, the input (first parameter) is expected to be just one token (the last one that was produced). A rough sketch of the loop I have in mind is below.
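
To illustrate, here's a rough sketch of the kind of generation loop I mean, using the usual transformers interface (input_ids and max_new_tokens are placeholders):

import torch

generated = input_ids          # [1, prompt_len], the tokenized prompt
past_key_values = None         # first pass: no cache, the whole prompt goes in

for _ in range(max_new_tokens):
    if past_key_values is None:
        model_input = generated          # whole prompt on the first pass
    else:
        model_input = generated[:, -1:]  # only the last generated token afterwards

    out = model(input_ids=model_input, past_key_values=past_key_values, use_cache=True)
    past_key_values = out.past_key_values  # nested tuples of (key, value) tensors
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)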

Anyway, I realize my explanation is not clear, and that if you don't have a background in HF transformers it would be difficult to follow.

Perhaps the first pointer you can give me is this: in the Colab, right before the Nobuco conversion, when we generate the inputs for conversion, you'll see I call the PyTorch model to create an example past_key_values. If you inspect that variable, you'll notice it's composed of tuples, some of which contain tensors. If you could tell me how to pass that variable as the fourth parameter to the conversion (that is, what shape to declare for it, taking into account that one of its dimensions is dynamic), that would be a phenomenal start.

Thank you so much!

Federico

AlexanderLutsenko (Owner) commented Feb 7, 2024

TensorFlow only accepts tensors as inputs, so None is not allowed. We can, however, represent an empty past_key_values as a series of tensors with a zero-sized sequence dimension, i.e.

past_key_values = []
for i in range(24):
    k = torch.zeros(size=(1, 32, 0, 64))
    v = torch.zeros(size=(1, 32, 0, 64))
    past_key_values.append((k, v))
past_key_values = tuple(past_key_values)

Also keep in mind that input padding is extremely important for TensorFlow, as changes in input sizes will trigger graph re-tracing.
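
As an illustration only (the padding value and the handling of the attention mask depend on the model), padding the token ids to a fixed length keeps the traced graph's input shape constant:

import torch
import torch.nn.functional as F

MAX_LEN = 128  # assumption: fixed length chosen for the converted graph

def pad_to_fixed_length(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Right-pad [1, seq_len] token ids up to MAX_LEN so every call
    # presents the same input shape to the graph.
    pad_amount = MAX_LEN - input_ids.shape[1]
    return F.pad(input_ids, (0, pad_amount), value=pad_token_id)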

I put together an example for Zephyr here. I'm not an LLM guy, so it might be substandard, but it works quite well overall.

federicoparra (Author) commented Feb 7, 2024

I'll try that. You make me wonder, though: if changes in input shapes trigger retracing in TF (I'm assuming this is true for TFLite as well), then that would be problematic for LLM inference speed, since you typically begin with a large input (the prompt), say 24 tokens/words, but then during generation you only feed back to the LLM the last generated token (since the context is kept precisely through the use of past_key_values). So the main input (the first parameter passed to the model) goes from shape [1,24] on the first iteration to [1,1] on all further iterations. Does this mean you would get a full retracing happening between iteration 1 and the other ones? Also, past_key_values grows linearly with each iteration...so you would be getting retrace every iteration? Maybe Tensorflow is just not adapted to work with LLMs? I find it difficult to believe they would let such an important framework die?

federicoparra (Author) commented

@AlexanderLutsenko your example runs perfectly! Thank you so much - everybody should know about this library!

AlexanderLutsenko (Owner) commented Feb 9, 2024

@federicoparra Awesome!

Does this mean you would get a full retracing happening between iteration 1 and the other ones? Also, past_key_values grows linearly with each iteration...so you would be getting retrace every iteration?

Well, not quite, as you would allocate buffers of the maximum allowed size in advance. BTW, TFLite does not support dynamic shapes, although they do work sometimes. See PINTO0309/onnx2tf#543 for an example of how it can be made to work, and the drawbacks.
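
As a rough sketch of that idea (shapes follow the zero-size cache example above; this is not the actual example code), the cache buffers can be preallocated at the maximum length so their shapes never change:

import numpy as np

MAX_CACHE_LEN = 1024   # assumption: maximum context length baked into the graph
NUM_LAYERS, NUM_HEADS, HEAD_DIM = 24, 32, 64

# Fixed-shape key/value buffers: the graph always sees the same shapes,
# so the growing context never triggers re-tracing.
k_cache = np.zeros((NUM_LAYERS, 1, NUM_HEADS, MAX_CACHE_LEN, HEAD_DIM), dtype=np.float32)
v_cache = np.zeros((NUM_LAYERS, 1, NUM_HEADS, MAX_CACHE_LEN, HEAD_DIM), dtype=np.float32)
cache_len = 0  # number of positions filled so far

# Each step writes the new key/value at index cache_len and masks out
# positions >= cache_len in the attention computation.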

Maybe Tensorflow is just not adapted to work with LLMs?

Surely feels like it. PyTorch already supports FlashAttention-2 out of the box, while in TensorFlow we are stuck with the naive implementation.

I find it difficult to believe they would let such an important framework die?

Dunno, that'd be a shame. TFLite runs really well on mobile and the web, so they've still got the upper hand there.
