From 6a8ac4305f588b12d4073a04ff70f48e48c74cf6 Mon Sep 17 00:00:00 2001
From: Korbinian Poeppel
Date: Mon, 17 Jun 2024 15:10:58 +0200
Subject: [PATCH] Fix typo in error message.

---
 xlstm/components/feedforward.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/xlstm/components/feedforward.py b/xlstm/components/feedforward.py
index 58401e2..5995893 100644
--- a/xlstm/components/feedforward.py
+++ b/xlstm/components/feedforward.py
@@ -25,7 +25,7 @@ def get_act_fn(act_fn_name: str) -> Callable[[torch.Tensor], torch.Tensor]:
     else:
         assert (
             False
-        ), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_cls_registry.keys())}'
+        ), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_registry.keys())}'
 
 
 @dataclass
@@ -41,7 +41,9 @@ class FeedForwardConfig(UpProjConfigMixin):
 
     def __post_init__(self):
         self._set_proj_up_dim(embedding_dim=self.embedding_dim)
-        assert self.act_fn in _act_fn_registry, f"Unknown activation function {self.act_fn}"
+        assert (
+            self.act_fn in _act_fn_registry
+        ), f"Unknown activation function {self.act_fn}"
 
 
 class GatedFeedForward(nn.Module):
@@ -76,7 +78,11 @@ def reset_parameters(self):
         small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
         if self.proj_up.bias is not None:
             nn.init.zeros_(self.proj_up.bias)
-        wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
+        wang_init_(
+            self.proj_down.weight,
+            dim=self.config.embedding_dim,
+            num_blocks=self.config._num_blocks,
+        )
         if self.proj_down.bias is not None:
             nn.init.zeros_(self.proj_down.bias)
 