Skip to content

Commit

Permalink
[GPU] Fix canonicalization for fused dep's shape (openvinotoolkit#19667)
Browse files Browse the repository at this point in the history
* [GPU] Fix canonicalization for fused dep's shape

Signed-off-by: Andrew Park <[email protected]>

* Update TC to reproducible on the latest master

Signed-off-by: Andrew Park <[email protected]>

* Fix custom canonicalize shapes for Gather

---------

Signed-off-by: Andrew Park <[email protected]>
  • Loading branch information
andrew-k-park authored Sep 19, 2023
1 parent 631d6d3 commit 394e58f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ struct gather_impl : typed_primitive_impl_ocl<gather> {
out_layout.format = format::adjust_to_rank(out_layout.format, output_pshape.size());
}

return primitive_impl::static_canonicalize_shapes(updated_impl_params);
for (auto& input_layout : updated_impl_params.input_layouts) {
input_layout.set_partial_shape(extend_shape_to_rank_from_end(input_layout.get_partial_shape()));
}
out_layout.set_partial_shape(extend_shape_to_rank_from_end(out_layout.get_partial_shape()));

return updated_impl_params;
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,45 @@ TEST(canonicalization, gather) {
}
}

struct fusing_gather_eltwise_params {
ov::PartialShape data_shape;
ov::Shape out_shape;
int64_t axis;
int64_t batch_dim;
bool support_neg_ind;
};

std::vector<std::pair<Shapes, fusing_gather_eltwise_params>> fusing_gather_eltwise_shapes_with_params {
{
{{{}, {}}, {{4624, 4, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {4624, 1, 1, 1}}, {{4624, 1, 1, 1}}},
{{4624, 4}, {4624}, 1, 0, true}
}
};

TEST(canonicalization, fusing_gather_eltwise) {
for (const auto& shapes : fusing_gather_eltwise_shapes_with_params) {
layout input_gather_layout = create_default_layout(shapes.second.data_shape);
layout indices_layout_first = create_default_layout(std::get<0>(shapes.first)[0]);
layout indices_layout_second = create_default_layout(std::get<0>(shapes.first)[0]);
layout input_mul_layout = create_default_layout(std::get<0>(shapes.first)[1]);

topology topology;
topology.add(input_layout("input", input_gather_layout));
topology.add(input_layout("indices_first", indices_layout_first));
topology.add(input_layout("indices_second", indices_layout_second));
topology.add(input_layout("data", input_mul_layout));
topology.add(gather("gather_first", input_info("input"), input_info("indices_first"), shapes.second.axis,
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
topology.add(gather("gather_second", input_info("input"), input_info("indices_second"), shapes.second.axis,
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
topology.add(eltwise("mul", {input_info("gather_first"), input_info("data")}, eltwise_mode::prod));
topology.add(eltwise("add", {input_info("gather_second"), input_info("mul")}, eltwise_mode::sum));
topology.add(reorder("out_reorder", input_info("add"), format::bfyx, data_types::f32));

canonicalization_test(topology, "gather_first", std::get<1>(shapes.first), std::get<2>(shapes.first), true);
}
}

struct fusing_gemm_eltwise_params {
ov::PartialShape input_gemm_first;
ov::PartialShape weights_gemm_first;
Expand Down

0 comments on commit 394e58f

Please sign in to comment.