Skip to content
This repository has been archived by the owner on Dec 30, 2024. It is now read-only.

Commit

Permalink
aten.hardsigmoid.default in unary_ops (pytorch#5396)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#5396

Implement aten.hardsigmoid in unary_ops

Reviewed By: jorgep31415

Differential Revision: D62584402

fbshipit-source-id: 3d3bd5292e73fcd6068142fa37d428b3f566c7b8
  • Loading branch information
Abhi-hpp authored and facebook-github-bot committed Sep 18, 2024
1 parent b14dea8 commit 8a0b48e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) {
(vec4(greaterThan(tex, vec4(lambda))) +
vec4(lessThan(tex, vec4(neg_lambda))));
}

// Scalar hardsigmoid, matching aten.hardsigmoid:
//   0 when x <= -3, 1 when x >= 3, x/6 + 0.5 in between.
// mix() selects the linear ramp inside the |x| <= 3 region and the
// saturated 0/1 value (sign-based) outside it; the boundary points
// agree with the ramp, so either branch yields the same result there.
float hardsigmoid(float x) {
  float ramp = x / 6.0 + 0.5;
  float saturated = float(x >= 0.0);
  float in_linear_region = float(abs(x) <= 3.0);
  return mix(saturated, ramp, in_linear_region);
}

// Componentwise hardsigmoid over a 4-element texel: applies the scalar
// overload to each lane independently.
vec4 hardsigmoid(vec4 tex) {
  vec4 result;
  for (int i = 0; i < 4; ++i) {
    result[i] = hardsigmoid(tex[i]);
  }
  return result;
}
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ unary_op:
OPERATOR: hardshrink(X, A, B)
- NAME: hardswish
OPERATOR: hardswish(X)
- NAME: hardsigmoid
OPERATOR: hardsigmoid(X)
10 changes: 3 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,6 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
"hardshrink"); \
}

#define DEFINE_HARDSWISH_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_unary_op_node( \
graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \
}

void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// args[1] is the `approximate` string
// https://fburl.com/code/9omngmyo
Expand All @@ -140,7 +134,8 @@ DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
DEFINE_RELU_FN(relu);
DEFINE_HARDSHRINK_FN(hardshrink);
DEFINE_HARDSWISH_FN(hardswish);
DEFINE_ACTIVATION_FN(hardswish);
DEFINE_ACTIVATION_FN(hardsigmoid);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
Expand All @@ -157,6 +152,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.tanh.default, tanh);
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
VK_REGISTER_OP(aten.hardswish.default, hardswish);
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
}

} // namespace vkcompute
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,7 @@ def get_softmax_inputs():
"aten.neg.default",
"aten.cos.default",
"aten.hardswish.default",
"aten.hardsigmoid.default",
]
)
def get_unary_ops_inputs():
Expand Down

0 comments on commit 8a0b48e

Please sign in to comment.