Skip to content

Commit

Permalink
[Bugfix][Relay] Fix softplus about the wrong calculation formula in R…
Browse files Browse the repository at this point in the history
…elay PyTorch frontend (apache#14821)

* fix softplus operator

* add test cases

* Update pytorch.py

* Update pytorch.py
  • Loading branch information
jikechao authored May 12, 2023
1 parent ae9209b commit 483b87d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,10 @@ def func(x):
def softplus(self, inputs, input_types):
dtype = input_types[0]
beta = _expr.const(float(inputs[1]), dtype=dtype)
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
threshold = int(inputs[2]) if inputs[2] else 20
threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold))
softplus_value = _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
return _op.where(_op.greater(inputs[0] * beta, threshold_), inputs[0], softplus_value)

def make_avg_pool(self, dim):
def avg_pool(inputs, input_types):
Expand Down
3 changes: 3 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,9 @@ def test_forward_softplus():
verify_model(torch.nn.Softplus().eval(), input_data=input_data)
verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)
verify_model(torch.nn.Softplus(beta=5, threshold=1).eval(), input_data=input_data)
verify_model(torch.nn.Softplus(beta=1, threshold=2).eval(), input_data=input_data)
verify_model(torch.nn.Softplus(beta=1, threshold=-1).eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 483b87d

Please sign in to comment.