Mamba slow_forward gradient fix (huggingface#29563)
* FIX: Cached slow forward in mamba
- additionally added mamba cached test
- added a test for the previously unused forward/backward check (mamba causal lm forward and backward)
- fixed typo: "causl" --> "causal"

* formatting

* fix: use real `slow_forward` call instead of torch module's

* add shape assertion for mixer block test

* adjust shape assertion
vasqu authored Mar 27, 2024
1 parent 1c39974 commit cefb819
Showing 2 changed files with 35 additions and 4 deletions.
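
Editor's context, not part of this diff: MambaMixer.forward dispatches to the fused CUDA kernels when they are importable and the weights live on a CUDA device, and falls back to the pure-PyTorch slow_forward otherwise, so the cached slow path patched below is what runs on CPU or without the mamba-ssm/causal-conv1d kernels. A paraphrase of that dispatch from memory of modeling_mamba.py (the exact availability check may differ):

# Paraphrased fragment of MambaMixer.forward; `is_fast_path_available` is a
# module-level flag set when the mamba-ssm/causal-conv1d kernels import cleanly.
def forward(self, hidden_states, cache_params=None):
    if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
        return self.cuda_kernels_forward(hidden_states, cache_params)
    return self.slow_forward(hidden_states, cache_params)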
src/transformers/models/mamba/modeling_mamba.py (2 changes: 1 addition & 1 deletion)

@@ -230,7 +230,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None):

        # 2. Convolution sequence transformation
        if cache_params is not None:
-            ssm_state = cache_params.ssm_states[self.layer_idx]
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            if cache_params.seqlen_offset > 0:
                conv_state = cache_params.conv_states[self.layer_idx]                   # [batch, intermediate_size, conv_kernel_size]
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
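
Editor's note on why the one-line change matters, not part of the commit: without the clone(), the local ssm_state aliases the tensor stored in the cache. Autograd saves that tensor when it is used in the scan computation, and the later in-place write-back into the cache further down in slow_forward (a copy_ into cache_params.ssm_states[self.layer_idx]) bumps its version counter, so a subsequent loss.backward() fails with an in-place-modification RuntimeError. A minimal, generic PyTorch sketch of that failure mode (names are illustrative, not the Mamba code):

import torch

cache_state = torch.zeros(2, 4)             # stands in for cache_params.ssm_states[layer_idx]
weight = torch.randn(4, 4, requires_grad=True)

# Buggy pattern: `state` aliases the cached tensor.
state = cache_state
out = (state @ weight).sum()                # autograd saves `state` to compute weight's gradient
cache_state.copy_(torch.ones(2, 4))         # in-place write-back bumps the version counter
# out.backward()                            # RuntimeError: ... modified by an inplace operation

# Fixed pattern: clone first, so the write-back cannot clobber the saved tensor.
state = cache_state.clone()
out = (state @ weight).sum()
cache_state.copy_(torch.ones(2, 4))
out.backward()                              # succeeds

Cloning costs one extra copy of the SSM state per layer, which is negligible next to the scan itself.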
tests/models/mamba/test_modeling_mamba.py (37 changes: 34 additions & 3 deletions)

@@ -170,7 +170,7 @@ def create_and_check_mamba_model(self, config, input_ids, *args):
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)

-    def create_and_check_causl_lm(self, config, input_ids, *args):
+    def create_and_check_causal_lm(self, config, input_ids, *args):
        model = MambaForCausalLM(config)
        model.to(torch_device)
        model.eval()
@@ -197,7 +197,30 @@ def create_and_check_state_equivalency(self, config, input_ids, *args):
        self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
        # TODO the original mamba does not support decoding more than 1 token; neither do we

-    def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False):
+    def create_and_check_mamba_cached_slow_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
+        model = MambaModel(config)
+        model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+
+        # create cache
+        cache = model(input_ids, use_cache=True).cache_params
+        cache.seqlen_offset = 0
+
+        # use cache
+        token_emb = model.embeddings(input_ids)
+        outputs = model.layers[0].mixer.slow_forward(token_emb, cache)
+
+        loss = torch.log(1 + torch.abs(outputs.sum()))
+        self.parent.assertEqual(loss.shape, ())
+        self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        loss.backward()
+
+    def create_and_check_mamba_lm_head_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
        model = MambaForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
@@ -304,12 +327,20 @@ def test_mamba_model(self):

    def test_mamba_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_causl_lm(*config_and_inputs)
+        self.model_tester.create_and_check_causal_lm(*config_and_inputs)

    def test_state_equivalency(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_state_equivalency(*config_and_inputs)

+    def test_mamba_cached_slow_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_cached_slow_forward_and_backwards(*config_and_inputs)
+
+    def test_mamba_lm_head_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_lm_head_forward_and_backwards(*config_and_inputs)

    def test_initialization(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

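Editor's note: outside the test harness, the fix can be sanity-checked end to end. A minimal sketch using only the public transformers API (the config sizes are arbitrary assumptions; on a machine without the fused CUDA kernels this routes through slow_forward):

import torch
from transformers import MambaConfig, MambaForCausalLM

# Tiny, arbitrary config so the check runs quickly.
config = MambaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2)
model = MambaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
# use_cache=True forces the cached path this commit fixes; labels make the
# model return a language-modeling loss to backprop through.
out = model(input_ids, labels=input_ids, use_cache=True)
out.loss.backward()  # failed with an in-place-modification RuntimeError before this fix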
