Skip to content

Commit

Permalink
[GPU] Enable unet2d on DG2 (openvinotoolkit#9522)
Browse files Browse the repository at this point in the history
* [GPU] Enable unet2d on DG2

Add support for the is_os_yx_isa2_osa8_isv8_osv2 format, which is used in
weight reorder.

Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback authored Jan 7, 2022
1 parent 89f48e0 commit 2a476f6
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 28 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/runtime/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ struct format {
os_is_yx_osa2_isa8_osv8_isv2,
os_is_yx_osa2_isa8_osv16_isv2,
os_is_yx_osa2_isa8_osv16_isv4,
is_os_yx_isa2_osa8_isv8_osv2,
is_o_yx_isv32, ///< format for weights for 1x1 MMAD convolutions
is_o32_yx_isv32_swizzled_by_4, ///< format for weights for 1x1 MMAD convolutions
os_is_y_x8_osv8_isv4, ///< format for weights for 1x1 MMAD convolutions
Expand Down Expand Up @@ -301,6 +302,7 @@ struct format {
{ os_is_zyx_isa8_osv16_isv4, { 1, 1, 3, 0, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}}},
{ os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, { 1, 1, 2, 0, "oiyx", "oixy?", {{0, 32}, {1, 32}}}},
{ os_is_zyx_osa4_isa8_osv8_isv4_swizzled_by_4, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 32}}}},
{ is_os_yx_isa2_osa8_isv8_osv2, { 1, 1, 2, 0, "ioyx", "ioxy?", {{1, 16}, {0, 16}}}},
{ is_o_yx_isv32, { 1, 1, 2, 0, "oyxi", "oixy?", {{1, 32}}}},
{ is_o32_yx_isv32_swizzled_by_4, { 1, 1, 2, 0, "oyxi", "oixy?", {}}},
{ os_is_y_x8_osv8_isv4, { 1, 1, 2, 0, "oyxi", "oixy?", {}}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ struct convolution_onednn : typed_primitive_onednn_impl<convolution, dnnl::convo
auto cldnn_prim = arg.get_primitive();
auto weights_layout = arg.get_dependency(1).get_output_layout();
auto grouped_weights = format::is_grouped(weights_layout.format) || arg.get_primitive()->grouped_weights_shape;
cldnn::format out_fmt = onednn::convert_format(onednn::get_format_by_desc(pd.weights_desc(0)), grouped_weights);
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);

set_params(arg, r_params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct deconvolution_onednn : typed_primitive_onednn_impl<deconvolution, dnnl::d
auto cldnn_prim = arg.get_primitive();
auto weights_layout = arg.get_dependency(1).get_output_layout();
auto grouped_weights = format::is_grouped(weights_layout.format) || arg.get_primitive()->grouped_weights_shape;
cldnn::format out_fmt = onednn::convert_format(onednn::get_format_by_desc(pd.weights_desc(0)), grouped_weights);
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);

set_params(arg, r_params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected, dnn

auto cldnn_prim = arg.get_primitive();
auto weights_layout = arg.get_dependency(1).get_output_layout();
cldnn::format out_fmt = onednn::convert_format(onednn::get_format_by_desc(pd.weights_desc(0)));
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0));
kernel_selector::WeightsLayout req_layout = to_weights_layout(out_fmt, false);

// set engine info & forcing
Expand Down
65 changes: 44 additions & 21 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ static bool isSame(dnnl::memory::desc desc, dnnl::memory::format_tag fmt) {
return true;
}

dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc) {
static dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc) {
// TODO [OneDNN]: Previously it was a field of tdesc, but now the brute
// force search here. Please avoid of using this method.
const auto ndims = desc.dims().size();
Expand All @@ -239,25 +239,8 @@ dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc) {
return dnnl::memory::format_tag::undef;
}

dnnl::algorithm convert_activation_func(cldnn::activation_func func) {
switch (func) {
case cldnn::activation_func::relu: return dnnl::algorithm::eltwise_relu;
case cldnn::activation_func::relu_negative_slope: return dnnl::algorithm::eltwise_relu;
case cldnn::activation_func::gelu: return dnnl::algorithm::eltwise_gelu;
case cldnn::activation_func::elu: return dnnl::algorithm::eltwise_elu;
case cldnn::activation_func::mish: return dnnl::algorithm::eltwise_mish;
case cldnn::activation_func::swish: return dnnl::algorithm::eltwise_swish;
case cldnn::activation_func::hswish: return dnnl::algorithm::eltwise_hardswish;
case cldnn::activation_func::abs: return dnnl::algorithm::eltwise_abs;
case cldnn::activation_func::exp: return dnnl::algorithm::eltwise_exp;
case cldnn::activation_func::logistic: return dnnl::algorithm::eltwise_logistic;
case cldnn::activation_func::clamp: return dnnl::algorithm::eltwise_clip;
case cldnn::activation_func::hyperbolic_tan: return dnnl::algorithm::eltwise_tanh;
default: throw std::runtime_error("Unsupported activation func for onednn primitive " + std::to_string(static_cast<int>(func)));
}
}

cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
// onednn -> cldnn
static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
if (is_grouped) {
switch (fmt) {
case dnnl::memory::format_tag::abcde: return cldnn::format::goiyx;
Expand All @@ -278,7 +261,7 @@ cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
switch (fmt) {
case dnnl::memory::format_tag::ab: return cldnn::format::oiyx;
case dnnl::memory::format_tag::abcd: return cldnn::format::oiyx;
case dnnl::memory::format_tag::bacd: return cldnn::format::oiyx;
case dnnl::memory::format_tag::bacd: return cldnn::format::ioyx;
case dnnl::memory::format_tag::BAcd16b16a: return cldnn::format::is_os_yx_isv16_osv16;
case dnnl::memory::format_tag::ABcd16b16a: return cldnn::format::os_is_yx_isv16_osv16;
case dnnl::memory::format_tag::abcde: return cldnn::format::oizyx;
Expand All @@ -299,6 +282,46 @@ cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
}
}

// Map a onednn memory descriptor to the matching cldnn weights format.
// First tries the format_tag brute-force lookup; if the descriptor has no
// named tag, falls back to recognizing known blocking layouts directly from
// the descriptor. Throws std::runtime_error for anything unrecognized.
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped) {
    const auto tag = get_format_by_desc(desc);
    if (tag != dnnl::memory::format_tag::undef)
        return convert_format(tag, is_grouped);

    // No named tag: only a small set of blocked layouts is recognized, and
    // none of them are grouped.
    if (is_grouped)
        throw std::runtime_error(std::string("Unsupported grouped onednn dnnl::memory::desc find_format"));

    const auto blk = desc.data.format_desc.blocking;
    // is_os_yx_isa2_osa8_isv8_osv2: 4D weights blocked as i2-o8-i8-o2
    // (inner_idxs 1,0,1,0 with inner block sizes 2,8,8,2).
    const bool is_isa2_osa8_isv8_osv2 =
        desc.data.ndims == 4 && blk.inner_nblks == 4 &&
        blk.inner_blks[0] == 2 && blk.inner_blks[1] == 8 &&
        blk.inner_blks[2] == 8 && blk.inner_blks[3] == 2 &&
        blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 &&
        blk.inner_idxs[2] == 1 && blk.inner_idxs[3] == 0;
    if (is_isa2_osa8_isv8_osv2)
        return cldnn::format::is_os_yx_isa2_osa8_isv8_osv2;

    throw std::runtime_error(std::string("Unsupported onednn dnnl::memory::desc find_format"));
}

// Translate a cldnn activation function into the corresponding onednn
// eltwise algorithm kind. Throws std::runtime_error for activations that
// have no onednn eltwise counterpart.
dnnl::algorithm convert_activation_func(cldnn::activation_func func) {
    switch (func) {
    // Both relu variants map onto onednn's single relu algorithm.
    case cldnn::activation_func::relu:
    case cldnn::activation_func::relu_negative_slope:
        return dnnl::algorithm::eltwise_relu;
    case cldnn::activation_func::gelu:
        return dnnl::algorithm::eltwise_gelu;
    case cldnn::activation_func::elu:
        return dnnl::algorithm::eltwise_elu;
    case cldnn::activation_func::mish:
        return dnnl::algorithm::eltwise_mish;
    case cldnn::activation_func::swish:
        return dnnl::algorithm::eltwise_swish;
    case cldnn::activation_func::hswish:
        return dnnl::algorithm::eltwise_hardswish;
    case cldnn::activation_func::abs:
        return dnnl::algorithm::eltwise_abs;
    case cldnn::activation_func::exp:
        return dnnl::algorithm::eltwise_exp;
    case cldnn::activation_func::logistic:
        return dnnl::algorithm::eltwise_logistic;
    case cldnn::activation_func::clamp:
        return dnnl::algorithm::eltwise_clip;
    case cldnn::activation_func::hyperbolic_tan:
        return dnnl::algorithm::eltwise_tanh;
    default:
        throw std::runtime_error("Unsupported activation func for onednn primitive " + std::to_string(static_cast<int>(func)));
    }
}

template <typename T>
void make_per_tensor_if_possible(cldnn::data_node& node) {
auto ptr = node.get_attached_memory_ptr();
Expand Down
5 changes: 1 addition & 4 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ dnnl::memory::dims flatten_tensor(cldnn::tensor t);
dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc);
dnnl::algorithm convert_activation_func(cldnn::activation_func func);

// onednn -> cldnn
cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped = false);
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped = false);

int64_t get_offset(dnnl::memory::desc desc);

Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/kernel_selector_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::g_os_zyx_is_osv32_isv16;
case format::g_os_zyx_is_osv32_isv32:
return kernel_selector::weights_layout::g_os_zyx_is_osv32_isv32;
case format::is_os_yx_isa2_osa8_isv8_osv2:
return kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2;
default:
throw std::invalid_argument("Unable to convert tensor layout " + fmt_to_str(f) + " to weights layout");
}
Expand Down Expand Up @@ -506,6 +508,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::is_os_zyx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_isv16_osv16:
return cldnn::format::is_os_yx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2:
return cldnn::format::is_os_yx_isa2_osa8_isv8_osv2;
case kernel_selector::weights_layout::os_is_yx_osv8_isv2:
return cldnn::format::os_is_yx_osv8_isv2;
case kernel_selector::weights_layout::os_is_yx_osv8_isv4:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_is_yx_isv16_osv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::os_is_osv32_isv32_swizzled_by_4, { -1, -1, -1, 0, 1, -1 } },
{ WeightsLayout::os_is_zyx_isv8_osv16_isv2, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::os_is_yx_isv8_osv16_isv2, { 0, 1, -1, 2, 3, -1 } },
Expand Down Expand Up @@ -534,6 +535,7 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
newDims[3] = RoundUp(newDims[3], 32);
break;
case os_is_yx_osa2_isa8_osv8_isv2:
case is_os_yx_isa2_osa8_isv8_osv2:
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 16);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ enum WeightsLayout {
os_is_yx_osa2_isa8_osv8_isv2,
os_is_yx_osa2_isa8_osv16_isv4,
os_is_yx_osa2_isa8_osv16_isv2,
is_os_yx_isa2_osa8_isv8_osv2,
g_os_is_yx_osa2_isa8_osv16_isv4,
g_os_is_yx_osa2_isa8_osv16_isv2,
os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,12 @@ inline uint get_g_os_is_yx_osa2_isa8_osv16_isv2(uint g, uint o, uint i, uint y,
return idx;
}

// is_os_yx_isa2_osa8_isv8_osv2 is the i/o transpose of
// os_is_yx_osa2_isa8_osv8_isv2, so reuse that layout's indexer with the
// o/i coordinates and the ofm/ifm sizes swapped.
inline uint get_g_is_os_yx_isa2_osa8_isv8_osv2(uint g, uint o, uint i, uint z, uint y, uint x,
    uint size_x, uint size_y, uint size_z, uint size_ifm, uint size_ofm, uint offset)
{
    const uint transposed_idx = get_g_os_is_yx_osa2_isa8_osv8_isv2(g, i, o, z, y, x,
        size_x, size_y, size_z,
        size_ofm, size_ifm, offset);
    return transposed_idx;
}

#define GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV4_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osa4_isa8_osv8_isv4( \
0, o, i, 0, y, x, \
Expand Down Expand Up @@ -895,6 +901,16 @@ inline uint get_g_os_is_yx_osa2_isa8_osv16_isv2(uint g, uint o, uint i, uint y,
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))

// Flat index of weight element (o, i, y, x) in the is_os_yx_isa2_osa8_isv8_osv2
// layout: delegates to the ungrouped (g = 0) helper with size_z fixed to 1,
// taking sizes and offset from the prefix'd tensor macros.
// NOTE(review): comments cannot go on the continuation lines below — a '//'
// would comment out the trailing backslash.
#define GET_FILTER_IS_OS_YX_ISA2_OSA8_ISV8_OSV2_INDEX(prefix, o, i, y, x) \
get_g_is_os_yx_isa2_osa8_isv8_osv2( \
0, o, i, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))


inline uint get_is_o_yx_isv32_index(uint o, uint i, uint y, uint x, uint i_size, uint o_size, uint x_size, uint y_size)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_OS_IS_YX_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, g, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA2_ISA8_OSV8_ISV2
return GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISA2_OSA8_ISV8_OSV2
return GET_FILTER_IS_OS_YX_ISA2_OSA8_ISV8_OSV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA4_ISA8_OSV8_ISV2
return GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::os_is_yx_osa2_isa8_osv16_isv2: return "OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv2: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
case WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2: return "OS_IS_YX_OSA2_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2: return "IS_OS_YX_ISA2_OSA8_ISV8_OSV2";
case WeightsLayout::g_os_is_yx_isv16_osv16: return "G_OS_IS_YX_ISV16_OSV16";
case WeightsLayout::g_os_is_yx_osv16_isv4: return "G_OS_IS_YX_OSV16_ISV4";
case WeightsLayout::g_os_is_zyx_osv16_isv16: return "G_OS_IS_ZYX_OSV16_ISV16";
Expand Down

0 comments on commit 2a476f6

Please sign in to comment.