Skip to content

Commit

Permalink
[Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (a…
Browse files Browse the repository at this point in the history
…pache#14820)

* fix threshold

* add test case

* Update pytorch.py

* Update pytorch.py
  • Loading branch information
jikechao authored May 11, 2023
1 parent fd2a510 commit 2cafa87
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def softmax(self, inputs, input_types):

def threshold(self, inputs, input_types):
data = inputs[0]
return _op.nn.relu(data)
threshold_f = float(inputs[1])
threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold_f))
value_f = float(inputs[2])
value = _op.full_like(inputs[0], fill_value=_expr.const(value_f))
return _op.where(_op.greater(data, threshold_), data, value)

def contiguous(self, inputs, input_types):
return inputs[0]
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,8 @@ def test_forward_threshold():
input_shape = [1, 3]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.Threshold(0, 0).float().eval(), input_data=input_data)
input_data = torch.tensor([[-1.0, 2.0]], dtype=torch.float32)
verify_model(torch.nn.Threshold(1, 1).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 2cafa87

Please sign in to comment.