Skip to content

Commit

Permalink
[Relay][Bugfix] Fix the wrong implementation about operator Threshold…
Browse files Browse the repository at this point in the history
… in oneflow (apache#15715)

* [Relay][BugFix] fix the wrong implementation of Threshold in OneFlow

* Update test_forward.py

* Update oneflow.py

* Update test_forward.py

* Update oneflow.py

* Update test_forward.py

add version checking
  • Loading branch information
jikechao authored Sep 12, 2023
1 parent 4d7e93c commit d8136fb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/tvm/relay/frontend/oneflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,15 +1025,17 @@ def _impl_v1(cls, inputs, attrs, params):
return out


class ThresholdedRelu(OneFlowOpConverter):
"""Operator converter for ThresholdedRelu."""
class Threshold(OneFlowOpConverter):
"""Operator converter for Threshold."""

@classmethod
def _impl_v1(cls, inputs, attrs, params):
alpha = float(attrs.get("alpha", 1.0))
alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
return inputs[0] * mask
threshold = float(attrs.get("threshold_val", 1.0))
threshold_tensor = _op.full_like(inputs[0], fill_value=_expr.const(threshold))
value = float(attrs.get("value"))
value_tensor = _op.full_like(inputs[0], fill_value=_expr.const(value))
mask = _op.greater(inputs[0], threshold_tensor)
return _op.where(mask, inputs[0], value_tensor)


class Elu(OneFlowOpConverter):
Expand Down Expand Up @@ -1425,6 +1427,7 @@ def get_convert_map():
"relu": Renamer("relu"),
"leaky_relu": Renamer("leaky_relu"),
"prelu": PReLU.get_converter(),
"threshold": Threshold.get_converter(),
"selu": Selu.get_converter(),
"silu": Silu.get_converter(),
"gelu": Gelu.get_converter(),
Expand Down
15 changes: 15 additions & 0 deletions tests/python/frontend/oneflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tvm.testing
import tvm.topi.testing
from tvm import relay
from packaging import version as package_version

MODEL_HOME = "test_model"

Expand Down Expand Up @@ -702,6 +703,15 @@ def forward(self, x):
x = x.softmax(dim=-1)
return x

class Threshold(flow.nn.Module):
def __init__(self):
super().__init__()
self.active = flow.nn.Threshold(0.5, 0.2)

def forward(self, x):
x = self.active(x)
return x

if os.path.exists(MODEL_HOME):
rmdir(MODEL_HOME)

Expand Down Expand Up @@ -738,6 +748,11 @@ def forward(self, x):
inputs=flow.tensor(np.random.rand(1, 12, 197, 197).astype(np.float32)),
)

# Threshold was introduced in the version 0.8.0 of oneflow
if package_version.parse(flow.__version__) >= package_version.parse("0.8.0"):
model14 = Threshold().eval()
verify_activation(model14, device="llvm")


@tvm.testing.uses_gpu
def test_math():
Expand Down

0 comments on commit d8136fb

Please sign in to comment.