Skip to content

Commit

Permalink
Fix typo in error message.
Browse files Browse the repository at this point in the history
  • Loading branch information
kpoeppel committed Jun 17, 2024
1 parent 7a04945 commit 6a8ac43
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions xlstm/components/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_act_fn(act_fn_name: str) -> Callable[[torch.Tensor], torch.Tensor]:
else:
assert (
False
), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_cls_registry.keys())}'
), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_registry.keys())}'


@dataclass
Expand All @@ -41,7 +41,9 @@ class FeedForwardConfig(UpProjConfigMixin):

def __post_init__(self):
self._set_proj_up_dim(embedding_dim=self.embedding_dim)
assert self.act_fn in _act_fn_registry, f"Unknown activation function {self.act_fn}"
assert (
self.act_fn in _act_fn_registry
), f"Unknown activation function {self.act_fn}"


class GatedFeedForward(nn.Module):
Expand Down Expand Up @@ -76,7 +78,11 @@ def reset_parameters(self):
small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
if self.proj_up.bias is not None:
nn.init.zeros_(self.proj_up.bias)
wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
wang_init_(
self.proj_down.weight,
dim=self.config.embedding_dim,
num_blocks=self.config._num_blocks,
)
if self.proj_down.bias is not None:
nn.init.zeros_(self.proj_down.bias)

Expand Down

0 comments on commit 6a8ac43

Please sign in to comment.