--activation relu-squared (facebookresearch#2458)
Summary: Pull Request resolved: fairinternal/fairseq-py#2458

Reviewed By: ngoyal2707

Differential Revision: D31721732

Pulled By: sshleifer

fbshipit-source-id: 620fbeece5ad4101baaf98cf2150027288ebad33
sshleifer authored and facebook-github-bot committed Oct 18, 2021
1 parent 1ef3d6a commit 29be3fe
1 changed file: fairseq/utils.py (5 additions, 0 deletions)

@@ -536,13 +536,18 @@ def deprecation_warning(message, stacklevel=3):
     # don't use DeprecationWarning, since it's ignored by default
     warnings.warn(message, stacklevel=stacklevel)
 
+def relu_squared(x: torch.Tensor):
+    return F.relu(x).pow(2)
+
 
 def get_activation_fn(activation: str) -> Callable:
     """Returns the activation function corresponding to `activation`"""
     from fairseq.modules import gelu, gelu_accurate
 
     if activation == "relu":
         return F.relu
+    elif activation == "relu_squared":
+        return relu_squared
     elif activation == "gelu":
         return gelu
     elif activation == "gelu_fast":
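For reference, a minimal sketch of how the new activation can be exercised through the get_activation_fn helper shown in the diff (assumes a fairseq checkout that includes this commit; the sample tensor is illustrative, not taken from the PR):

import torch
from fairseq.utils import get_activation_fn

# Resolve the new activation by name; model code does the same when the
# activation is selected via the command-line flag named in the commit title.
act = get_activation_fn("relu_squared")

x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
print(act(x))  # tensor([0., 0., 0., 1., 9.]) -- negatives clamp to 0, positives are squared

Unlike plain ReLU, relu_squared has matching one-sided derivatives at zero (both 0), so it is differentiable there while still growing quadratically for positive inputs.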
