Skip to content

Commit

Permalink
[bugfix][relay] fix wrong calculate logic about celu (apache#14796)
Browse files Browse the repository at this point in the history
* fix celu

* add test case

* Update test_forward.py

* Update pytorch.py
  • Loading branch information
jikechao authored May 8, 2023
1 parent d6e0f1d commit e01cb47
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,9 @@ def celu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
alpha = _expr.const(float(inputs[1]), dtype=dtype)
return alpha * _op.nn.relu(
_expr.const(1, dtype=dtype) - _op.exp(data / alpha)
zero = _op.const(0, dtype)
return alpha * _op.minimum(
zero, _op.exp(data / alpha) - _expr.const(1, dtype=dtype)
) + _op.nn.relu(data)

def gelu(self, inputs, input_types):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,8 @@ def test_forward_celu():
verify_model(torch.nn.CELU(alpha=0.3).eval(), input_data=input_data)
verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data)
verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data)
input_data = torch.tensor([-1.0, 2.0], dtype=torch.float32)
verify_model(torch.nn.CELU().eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit e01cb47

Please sign in to comment.