From 51e83a61385d2819f9b8822c016478e2ba8298aa Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 28 May 2024 13:32:18 +0200 Subject: [PATCH 001/297] Refactor gemm generator Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 232 ++++++++++++++++++++++++ src/codegen_tools.hpp | 71 +++++++- src/gemm_generator.cpp | 311 +++++++++----------------------- src/gemm_generator.hpp | 8 +- src/visitor/alias_analysis.cpp | 1 + src/visitor/alias_analysis.hpp | 1 - src/visitor/insert_barrier.cpp | 1 - src/visitor/insert_barrier.hpp | 1 - src/visitor/stack.hpp | 2 +- test/codegen/atomic.ir | 2 +- tools/offline_compiler/main.cpp | 1 + 11 files changed, 391 insertions(+), 240 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 1ff7acf7..511f2c18 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "codegen_tools.hpp" +#include "error.hpp" #include "scalar_type.hpp" #include "util.hpp" @@ -12,6 +13,7 @@ #include #include +#include #include #include @@ -19,6 +21,34 @@ using namespace clir; namespace tinytc { +expr as_type(builtin_type ty, expr e) { + switch (ty) { + case builtin_type::char_t: + return as_char(std::move(e)); + case builtin_type::uchar_t: + return as_uchar(std::move(e)); + case builtin_type::short_t: + return as_short(std::move(e)); + case builtin_type::ushort_t: + return as_ushort(std::move(e)); + case builtin_type::int_t: + return as_int(std::move(e)); + case builtin_type::uint_t: + return as_uint(std::move(e)); + case builtin_type::long_t: + return as_long(std::move(e)); + case builtin_type::ulong_t: + return as_ulong(std::move(e)); + case builtin_type::float_t: + return as_float(std::move(e)); + case builtin_type::double_t: + return as_double(std::move(e)); + default: + break; + } + return e; +} + expr vload_helper(short vec_size, expr offset, expr ptr) { switch (vec_size) { case 1: @@ -39,6 +69,77 @@ expr vload_helper(short vec_size, expr offset, expr ptr) { return nullptr; } +builtin_type block_rw_op_type(builtin_type scalar_ty) { + switch (scalar_ty) { + case builtin_type::short_t: + return builtin_type::ushort_t; + case builtin_type::int_t: + case builtin_type::float_t: + return builtin_type::uint_t; + case builtin_type::long_t: + case builtin_type::double_t: + return builtin_type::ulong_t; + default: + break; + } + return scalar_ty; +} + +expr sub_group_block_read_helper(expr pointer, builtin_type scalar_ty, address_space as) { + auto const make_read = [](builtin_type bt, expr pointer) -> expr { + switch (bt) { + case builtin_type::short_t: + case builtin_type::ushort_t: + return intel_sub_group_block_read_us(std::move(pointer)); + case builtin_type::int_t: + case builtin_type::uint_t: + case builtin_type::float_t: + return intel_sub_group_block_read_ui(std::move(pointer)); + case builtin_type::long_t: + case builtin_type::ulong_t: + case builtin_type::double_t: + return intel_sub_group_block_read_ul(std::move(pointer)); + default: + break; + } + return pointer[get_sub_group_local_id()]; + }; + auto const bt = block_rw_op_type(scalar_ty); + pointer = cast(pointer_to(clir::data_type(bt, as)), std::move(pointer)); + auto inst = make_read(bt, std::move(pointer)); + if (bt != scalar_ty) { + return as_type(scalar_ty, std::move(inst)); + } + return inst; +} +expr sub_group_block_write_helper(expr pointer, expr data, builtin_type scalar_ty, + address_space as) { + auto const make_write = [](builtin_type bt, expr pointer, expr data) -> expr { + switch (bt) { + case builtin_type::short_t: + case builtin_type::ushort_t: + return intel_sub_group_block_write_us(std::move(pointer), std::move(data)); + case builtin_type::int_t: + case builtin_type::uint_t: + case builtin_type::float_t: + return intel_sub_group_block_write_ui(std::move(pointer), std::move(data)); + case builtin_type::long_t: + case builtin_type::ulong_t: + case builtin_type::double_t: + return intel_sub_group_block_write_ul(std::move(pointer), std::move(data)); + default: + break; + } + return pointer[get_sub_group_local_id()] = std::move(data); + }; + auto const bt = block_rw_op_type(scalar_ty); + pointer = cast(pointer_to(clir::data_type(bt, as)), std::move(pointer)); + if (bt != scalar_ty) { + data = as_type(bt, std::move(data)); + } + return make_write(bt, std::move(pointer), std::move(data)); +} + void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, expr value, expr beta) { if (is_atomic) { @@ -229,4 +330,135 @@ void tile_loop_uniformly_dynamic(block_builder &bb, expr loop_trip_count, unsign .get_product()); } +block_accessor_regular::block_accessor_regular(expr block, int Kb, expr offset) + : block_(std::move(block)), Kb_(Kb), offset_(std::move(offset)) {} +auto block_accessor_regular::get(int m_block, int k) const -> expr { + const auto i = k + m_block * Kb_; + if (offset_) { + return block_[offset_ + i]; + } + return block_[i]; +} + +block_accessor_vector::block_accessor_vector(expr block) : block_(std::move(block)) {} +auto block_accessor_vector::get(int m_block, int k) const -> expr { return block_[m_block].s(k); } + +int matrix_block_description::first_block_with_check(std::int32_t subgroup_size) const { + int fb = 0; + dispatch_constant_dynamic( + M, [&](std::int64_t m) { fb = m / subgroup_size; }, [](expr const &) {}); + return fb; +} + +bool matrix_block_description::is_unit_stride(int mode) const { + bool is_unit = false; + dispatch_constant_dynamic( + stride[mode], [&](std::int64_t s) { is_unit = s == 1; }, [](expr const &) {}); + return is_unit; +} + +expr matrix_block_description::condition(int m_block, std::int32_t subgroup_size) const { + return get_sub_group_local_id() + m_block * subgroup_size < M; +} + +auto read_matrix_block_regular(block_builder &bb, matrix_block_description const &d, int M_mode, + core_config const &core_cfg, char const *block_name) + -> std::unique_ptr { + assert(M_mode == 0 || M_mode == 1); + + const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; + auto const scalar_ty = to_clir_builtin_ty(d.ty); + auto block = bb.declare(array_of(clir::data_type(scalar_ty), m_blocks * d.Kb), block_name); + + const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); + const bool enable_sub_group_reads = + core_cfg.block_read_write_supported && d.is_unit_stride(M_mode); + for (int k = 0; k < d.Kb; ++k) { + for (int m_block = 0; m_block < m_blocks; ++m_block) { + auto const store = [&](expr rhs) { + bb.assign(block[k + m_block * d.Kb], std::move(rhs)); + }; + if (enable_sub_group_reads && m_block < first_m_block_with_check) { + store(sub_group_block_read_helper(d.pointer, scalar_ty, d.as)); + } else { + auto rhs = d.pointer[d.stride[M_mode] * + (get_sub_group_local_id() + m_block * core_cfg.subgroup_size)]; + if (m_block >= first_m_block_with_check) { + rhs = ternary_conditional(d.condition(m_block, core_cfg.subgroup_size), + std::move(rhs), 0); + } + store(std::move(rhs)); + } + } + bb.add(add_into(d.pointer, d.stride[1 - M_mode])); + } + return std::make_unique(std::move(block), d.Kb); +} + +auto read_matrix_block_vector(block_builder &bb, matrix_block_description const &d, int M_mode, + core_config const &core_cfg, char const *block_name) + -> std::unique_ptr { + assert(M_mode == 0 || M_mode == 1); + + const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; + const auto dt = clir::data_type(to_clir_builtin_ty(d.ty), d.Kb); + auto block = bb.declare(array_of(dt, m_blocks), block_name); + + int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); + for (int m_block = 0; m_block < m_blocks; ++m_block) { + auto rhs = vload_helper(d.Kb, 0, + d.pointer + d.stride[M_mode] * (get_sub_group_local_id() + + m_block * core_cfg.subgroup_size)); + if (!bool(rhs)) { + throw internal_compiler_error(); + } + if (m_block >= first_m_block_with_check) { + rhs = ternary_conditional(d.condition(m_block, core_cfg.subgroup_size), rhs, + init_vector(dt, {0})); + } + bb.assign(block[m_block], std::move(rhs)); + } + bb.add(add_into(d.pointer, d.Kb * d.stride[1 - M_mode])); + + return std::make_unique(std::move(block)); +} + +auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int M_mode, + core_config const &core_cfg, char const *block_name) + -> std::unique_ptr { + assert(M_mode == 0 || M_mode == 1); + + if (d.is_unit_stride(1 - M_mode) && + (d.Kb == 2 || d.Kb == 3 || d.Kb == 4 || d.Kb == 8 || d.Kb == 16)) { + return read_matrix_block_vector(bb, d, M_mode, core_cfg, block_name); + } + return read_matrix_block_regular(bb, d, M_mode, core_cfg, block_name); +} + +void write_matrix_block(block_builder &bb, block_accessor const &block, + matrix_block_description const &d, bool is_atomic, expr alpha, expr beta, + core_config const &core_cfg) { + const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; + + const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); + for (int k = 0; k < d.Kb; ++k) { + for (int m_block = 0; m_block < m_blocks; ++m_block) { + const auto write = [&](block_builder &bb) { + store_helper(bb, is_atomic, + d.pointer + d.stride[0] * (get_sub_group_local_id() + + m_block * core_cfg.subgroup_size), + d.ty, d.as, alpha * block.get(m_block, k), beta); + }; + if (m_block >= first_m_block_with_check) { + bb.add(if_selection_builder(d.condition(m_block, core_cfg.subgroup_size)) + .then(write) + .get_product()); + } else { + write(bb); + } + } + bb.add(add_into(d.pointer, d.stride[1])); + } +} + } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 4825b0a5..926fbd25 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -4,19 +4,28 @@ #ifndef CODEGEN_TOOLS_20240229_HPP #define CODEGEN_TOOLS_20240229_HPP +#include "device_info.hpp" #include "tinytc/types.hpp" -#include -#include - #include #include #include #include +#include +#include +#include +#include +#include + namespace tinytc { +clir::expr as_type(clir::builtin_type ty, clir::expr e); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); +clir::expr sub_group_block_read_helper(clir::expr pointer, clir::builtin_type scalar_ty, + clir::address_space as); +clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, + clir::builtin_type scalar_ty, clir::address_space as); void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, clir::address_space as, clir::expr value, clir::expr beta); @@ -52,6 +61,62 @@ void tile_loop_uniformly_dynamic(clir::block_builder &bb, clir::expr loop_trip_c unsigned block_size, unsigned num_tiles, clir::var sg_id, uniform_loop_body_builder const &body); +class block_accessor { + public: + virtual ~block_accessor() = default; + virtual auto get(int m_block, int k) const -> clir::expr = 0; +}; + +class block_accessor_regular : public block_accessor { + public: + block_accessor_regular(clir::expr block, int Kb, clir::expr offset = clir::expr{nullptr}); + auto get(int m_block, int k) const -> clir::expr override; + inline auto offset(clir::expr offset) { offset_ = std::move(offset); } + + private: + clir::expr block_; + int Kb_; + clir::expr offset_; +}; + +class block_accessor_vector : public block_accessor { + public: + block_accessor_vector(clir::expr block); + auto get(int m_block, int k) const -> clir::expr override; + + private: + clir::expr block_; +}; + +struct matrix_block_description { + scalar_type ty; ///< Matrix scalar type + clir::address_space as; ///< Matrix address space + int Mb; ///< Number of rows if M_mode == 0; number of columns if M_mode == 1 + int Kb; ///< Number of columns if M_mode == 0; number of rows if M_mode == 0 + clir::expr pointer; ///< Pointer to block start + clir::expr M; ///< Size of row mode if M_mode == 0; size of column mode if M_mode == 1 + std::array stride; ///< Matrix stride + + int first_block_with_check(std::int32_t subgroup_size) const; + clir::expr condition(int m_block, std::int32_t subgroup_size) const; + bool is_unit_stride(int mode) const; +}; + +auto read_matrix_block_regular(clir::block_builder &bb, matrix_block_description const &d, + int M_mode, core_config const &core_cfg, char const *block_name) + -> std::unique_ptr; +auto read_matrix_block_vector(clir::block_builder &bb, matrix_block_description const &d, + int M_mode, core_config const &core_cfg, char const *block_name) + -> std::unique_ptr; +// Read MbxKb block +auto read_matrix_block(clir::block_builder &bb, matrix_block_description const &d, int M_mode, + core_config const &core_cfg, char const *block_name) + -> std::unique_ptr; + +void write_matrix_block(clir::block_builder &bb, block_accessor const &block, + matrix_block_description const &d, bool is_atomic, clir::expr alpha, + clir::expr beta, core_config const &core_cfg); + } // namespace tinytc #endif // CODEGEN_TOOLS_20240229_HPP diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index f101d219..c98c5fd6 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -24,8 +24,8 @@ #include #include #include +#include #include -#include using namespace clir; @@ -86,11 +86,11 @@ std::string gemm_configuration::identifier(std::string_view prefix) const { constexpr static int max_K_unrolling = 8; -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction) - -> std::pair { - auto const arithmetic_intensity = [&sgs](std::uint32_t row_blocks, std::uint32_t cols) { +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, + std::int32_t register_space, + std::pair max_fill_fraction) + -> std::pair { + auto const arithmetic_intensity = [&sgs](std::int32_t row_blocks, std::int32_t cols) { return (row_blocks * sgs * cols) / static_cast(row_blocks * sgs + cols); }; @@ -99,18 +99,18 @@ auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uin // The required number of scalars is given by // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling - auto const max_row_blocks = [&sgs, &max_scalars](std::uint32_t cols) { + auto const max_row_blocks = [&sgs, &max_scalars](std::int32_t cols) { return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); }; - auto const max_cols = [&sgs, &max_scalars](std::uint32_t row_blocks) { + auto const max_cols = [&sgs, &max_scalars](std::int32_t row_blocks) { return (max_scalars - row_blocks * sgs * max_K_unrolling) / (row_blocks * sgs + max_K_unrolling); }; double max_ai = 0.0; - std::uint32_t row_blocks = 1, cols = 1; - for (std::uint32_t r = 1; r <= max_row_blocks(1); ++r) { - for (std::uint32_t c = 1; c <= max_cols(r); ++c) { + std::int32_t row_blocks = 1, cols = 1; + for (std::int32_t r = 1; r <= max_row_blocks(1); ++r) { + for (std::int32_t c = 1; c <= max_cols(r); ++c) { auto const ai = arithmetic_intensity(r, c); if (ai > max_ai) { max_ai = ai; @@ -129,8 +129,8 @@ class generator { core_config const &core_cfg, address_space As, address_space Bs, address_space Cs) : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), Cspace(Cs) {} - void add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, var C, - expr C_offset, expr alpha, expr beta); + void add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, expr C_offset, + expr alpha, expr beta); void add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, expr beta); void add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta); @@ -141,181 +141,81 @@ class generator { local_tiling const tiling; core_config const core_cfg; address_space Aspace, Bspace, Cspace; - unsigned row_blocks_in_register = 1; - unsigned cols_in_register = 1; + int row_blocks_in_register = 1; + int cols_in_register = 1; var c, m; std::array MNK; std::array A_stride, B_stride, C_stride; }; -void generator::add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, - var C, expr C_offset, expr alpha, expr beta) { - std::int64_t n_bs = 0; - bool is_N_constant = false; +void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, + expr C_offset, expr alpha, expr beta) { + int n_bs = 0; dispatch_constant_dynamic( - N, - [&](std::int64_t n) { - n_bs = n; - is_N_constant = true; - }, - [&](expr) { - n_bs = static_cast(cols_in_register); - is_N_constant = false; - }); - std::int64_t const n_blocks = - 1 + (n_bs - 1) / static_cast(core_cfg.subgroup_size); - auto n = var("n"); + N, [&](std::int64_t n) { n_bs = n; }, + [&](expr) { n_bs = static_cast(cols_in_register); }); auto my_row_blocks_in_register = row_blocks_in_register; dispatch_constant_dynamic( M, - [&](std::int64_t m) { - while (my_row_blocks_in_register > 1 && - m < static_cast(my_row_blocks_in_register) * - core_cfg.subgroup_size) { - --my_row_blocks_in_register; - } - }, + [&](std::int64_t m) { my_row_blocks_in_register = 1 + (m - 1) / core_cfg.subgroup_size; }, [&](expr) {}); + auto const Mb = my_row_blocks_in_register * core_cfg.subgroup_size; - auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; - auto const ak = gemm_cfg.transA == transpose::T ? 0 : 1; auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), "Ab", A); - auto const Aoffset = [&](unsigned m_block) { - return A_stride[am] * (m + m_block * core_cfg.subgroup_size); - }; - - auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; - auto const bk = gemm_cfg.transB == transpose::T ? 1 : 0; auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), "Bb", B); - auto const Boffset = [&](int n_block) { - return B_stride[bn] * (m + n_block * core_cfg.subgroup_size); - }; - auto const cmn = [&](unsigned m_block, expr n) { - return c[m_block + row_blocks_in_register * std::move(n)]; - }; + auto c_block = block_accessor_regular(c, n_bs); - bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < n_bs, ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - bb.assign(cmn(m_block, n), precision_helper{gemm_cfg.ty.C}.zero()); - } - }) - .attribute(opencl_unroll_hint(n_bs)) - .get_product()); + for (int n = 0; n < n_bs; ++n) { + for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { + bb.assign(c_block.get(m_block, n), precision_helper{gemm_cfg.ty.C}.zero()); + } + } - auto const compute_c = [&](block_builder &bb, std::int64_t Kb, ::clir::expr K0, - ::clir::expr K1) { + auto const compute_c = [&](block_builder &bb, int Kb, ::clir::expr K0, ::clir::expr K1) { auto kb = var("kb"); - bb.add( - for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), - kb < std::move(K1), add_into(kb, Kb)) - .body([&](block_builder &bb) { - auto at = precision_helper{gemm_cfg.ty.A}; - auto a = bb.declare(array_of(at.type(), my_row_blocks_in_register * Kb), "a"); - auto amk = [&](unsigned m_block, unsigned k) { - return a[m_block + my_row_blocks_in_register * k]; - }; - bool const map_b_to_vec_type = - gemm_cfg.B_stride[bk] == 1 && - (Kb == 2 || Kb == 3 || Kb == 4 || Kb == 8 || Kb == 16); - int k_load_block_size = map_b_to_vec_type ? Kb : 1; - auto bt = precision_helper{gemm_cfg.ty.B}; - auto b = map_b_to_vec_type - ? bb.declare(array_of(bt.type(Kb), n_blocks), "b") - : bb.declare(array_of(bt.type(), n_blocks * Kb), "b"); - auto const read_A = [&](block_builder &bb, unsigned m_block, unsigned k, - bool check) { - auto condition = m + m_block * core_cfg.subgroup_size < M; - auto rhs = Ab[Aoffset(m_block)]; - auto rhs_checked = - check ? ternary_conditional(std::move(condition), rhs, 0) : rhs; - bb.assign(amk(m_block, k), std::move(rhs_checked)); - }; - auto block_read_A = [&](block_builder &bb, unsigned m_block, unsigned k) { - bb.assign( - amk(m_block, k), - at.sub_group_block_read(Ab + m_block * core_cfg.subgroup_size, Aspace)); - }; - for (unsigned k = 0; k < Kb; ++k) { - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - if (!is_remainder && core_cfg.block_read_write_supported && - gemm_cfg.A_stride[am] == 1) { - block_read_A(bb, m_block, k); - } else { - read_A(bb, m_block, k, is_remainder); - } - } - bb.add(add_into(Ab, A_stride[ak])); - } - - auto const read_B = [&](block_builder &bb, int k, int n_block, bool check) { - auto condition = m + n_block * core_cfg.subgroup_size < N; - if (map_b_to_vec_type) { - auto rhs = vload_helper(Kb, 0, Bb + Boffset(n_block)); - if (rhs) { - auto rhs_checked = - check ? ternary_conditional(condition, rhs, - init_vector(bt.type(Kb), {0})) - : rhs; - bb.assign(b[n_block], std::move(rhs_checked)); - } else { - throw std::logic_error("Vload for native type missing"); - } - } else { - auto rhs = Bb[Boffset(n_block)]; - auto rhs_checked = check ? ternary_conditional(condition, rhs, 0) : rhs; - bb.assign(b[k + n_block * Kb], std::move(rhs_checked)); - } - }; - int first_n_block_with_check = - n_bs < n_blocks * static_cast(core_cfg.subgroup_size) - ? n_blocks - 1 - : n_blocks; - if (!is_N_constant) { - first_n_block_with_check = 0; - } - for (int k = 0; k < Kb; k += k_load_block_size) { - for (int n_block = 0; n_block < first_n_block_with_check; ++n_block) { - read_B(bb, k, n_block, false); - } - for (int n_block = first_n_block_with_check; n_block < n_blocks; - ++n_block) { - read_B(bb, k, n_block, true); - } - bb.add(add_into(Bb, k_load_block_size * B_stride[bk])); - } - - const int nbb = 4; - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - for (std::int64_t nb = 0; nb < n_bs; nb += nbb) { - for (int k = 0; k < Kb; ++k) { - for (std::int64_t n = 0; n < nbb; ++n) { - if (nb + n < n_bs) { - auto const n_block = (nb + n) / core_cfg.subgroup_size; - auto const n_offset = (nb + n) % core_cfg.subgroup_size; - auto my_a = amk(m_block, k); - auto bkn = map_b_to_vec_type ? b[n_block].s(k) - : b[k + n_block * Kb]; - auto my_b = sub_group_broadcast(std::move(bkn), n_offset); - auto my_c = cmn(m_block, nb + n); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && - gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(my_c, - fma(std::move(my_a), std::move(my_b), my_c)); - } else { - bb.add(add_into(std::move(my_c), - std::move(my_a) * std::move(my_b))); - } - } - } - } - } - } - }) - .attribute(opencl_unroll_hint(1)) - .get_product()); + bb.add(for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), + kb < std::move(K1), add_into(kb, Kb)) + .body([&](block_builder &bb) { + auto const a_descr = + matrix_block_description{gemm_cfg.ty.A, Aspace, Mb, Kb, Ab, M, A_stride}; + auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; + auto const a = read_matrix_block(bb, a_descr, am, core_cfg, "a"); + + auto const b_descr = matrix_block_description{ + gemm_cfg.ty.B, Bspace, n_bs, Kb, Bb, N, B_stride}; + auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; + auto const b = read_matrix_block(bb, b_descr, bn, core_cfg, "b"); + + const int nbb = 4; + for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { + for (int nb = 0; nb < n_bs; nb += nbb) { + for (int k = 0; k < Kb; ++k) { + for (int n = 0; n < nbb; ++n) { + if (nb + n < n_bs) { + auto const n_block = (nb + n) / core_cfg.subgroup_size; + auto const n_offset = (nb + n) % core_cfg.subgroup_size; + auto my_a = a->get(m_block, k); + auto my_b = + sub_group_broadcast(b->get(n_block, k), n_offset); + auto my_c = c_block.get(m_block, nb + n); + if (gemm_cfg.ty.A == gemm_cfg.ty.B && + gemm_cfg.ty.B == gemm_cfg.ty.C) { + bb.assign(my_c, fma(std::move(my_a), std::move(my_b), + my_c)); + } else { + bb.add(add_into(std::move(my_c), + std::move(my_a) * std::move(my_b))); + } + } + } + } + } + } + }) + .attribute(opencl_unroll_hint(1)) + .get_product()); }; dispatch_constant_dynamic( MNK[2], @@ -339,61 +239,17 @@ void generator::add_microkernel(block_builder &bb, bool is_remainder, expr M, ex .then([&](block_builder &bb) { compute_c(bb, 1, KmultipleKb, K); }) .get_product()); }); - auto write_C = [&](block_builder &bb) { - auto n_to = is_N_constant ? n_bs : min(N, cast(generic_uint(), n_bs)); - auto n_unroll = is_N_constant ? n_bs : 1; -#if 0 - // We can use block writes if - // 1. They are supported (no block writes for SIMD32 before PVC) - // 2. We are not in a remainder loop - // 3. Data is adjacent in memory - // 4. The address is 16 byte aligned - // 5. We are not writing atomically - bool const use_block_write = - core_cfg.block_read_write_supported && !is_remainder && gemm_cfg.C_stride[0] == 1 && - gemm_cfg.C_stride[1] * size(gemm_cfg.ty.C) % 16 == 0 && !gemm_cfg.atomic -#endif - - // Block writes are disabled for now; would need to track memref alignment in subview / load - // instruction AND would need to impose alignment requirement in calling convention - constexpr bool use_block_write = false; - auto Cb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), "Cb", - C + C_offset); - if (!use_block_write) { - bb.add(add_into(Cb, C_stride[0] * m)); - } - bb.add( - for_loop_builder(declaration_assignment(generic_short(), n, 0), n < std::move(n_to), - ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - auto my_c = alpha * cmn(m_block, n); - auto C_offset_m = C_stride[0] * (m_block * core_cfg.subgroup_size); - if (use_block_write) { - bb.add(precision_helper{gemm_cfg.ty.C}.sub_group_block_write( - Cb + m_block * core_cfg.subgroup_size, - std::move(my_c) + beta * Cb[std::move(C_offset_m)], Cspace)); - } else { - auto const write_C_mn = [&](block_builder &bb) { - store_helper(bb, gemm_cfg.atomic, Cb + C_offset_m, gemm_cfg.ty.C, - Cspace, my_c, beta); - }; - if (is_remainder) { - bb.add( - if_selection_builder(m + m_block * core_cfg.subgroup_size < M) - .then(write_C_mn) - .get_product()); - } else { - write_C_mn(bb); - } - } - } - bb.add(add_into(Cb, cast(generic_uint(), C_stride[1]))); - }) - .attribute(opencl_unroll_hint(n_unroll)) - .get_product()); - }; - write_C(bb); + + auto Cb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), "Cb", + C + C_offset); + auto const c_descr = matrix_block_description{gemm_cfg.ty.C, Cspace, Mb, 1, Cb, M, C_stride}; + auto n = var("n"); + bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < N, ++n) + .body([&](block_builder &bb) { + c_block.offset(n); + write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, alpha, beta, core_cfg); + }) + .get_product()); } void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, @@ -401,12 +257,11 @@ void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C auto sg_m = bb.declare_assign(generic_uint(), "sg_m", get_sub_group_id() % tiling.m_tiles()); tile_loop_by_sgs( bb, MNK[0], core_cfg.subgroup_size * row_blocks_in_register, tiling.m_tiles(), - std::move(sg_m), - [&](block_builder &bb, expr block, bool is_remainder, expr inner_trip_count) { + std::move(sg_m), [&](block_builder &bb, expr block, bool, expr inner_trip_count) { auto Astride_m = gemm_cfg.transA == transpose::T ? A_stride[1] : A_stride[0]; auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), "Ab", A + std::move(Astride_m) * block); - add_microkernel(bb, is_remainder, std::move(inner_trip_count), N, std::move(Ab), B, C, + add_microkernel(bb, std::move(inner_trip_count), N, std::move(Ab), B, C, C_stride[0] * std::move(block) + C_offset, alpha, beta); }); } diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp index 8646ef70..cf0ad937 100644 --- a/src/gemm_generator.hpp +++ b/src/gemm_generator.hpp @@ -106,10 +106,10 @@ ::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling cons * * @return {number of row-blocks (block size = subgroup size), number of columns} */ -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction = {1, 2}) - -> std::pair; +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, + std::int32_t register_space, + std::pair max_fill_fraction = {1, 2}) + -> std::pair; } // namespace tinytc diff --git a/src/visitor/alias_analysis.cpp b/src/visitor/alias_analysis.cpp index 4b6cbf63..25e78ff8 100644 --- a/src/visitor/alias_analysis.cpp +++ b/src/visitor/alias_analysis.cpp @@ -5,6 +5,7 @@ #include "error.hpp" #include "node/data_type_node.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" #include diff --git a/src/visitor/alias_analysis.hpp b/src/visitor/alias_analysis.hpp index 3de375c5..d1ce730c 100644 --- a/src/visitor/alias_analysis.hpp +++ b/src/visitor/alias_analysis.hpp @@ -11,7 +11,6 @@ #include "visitor/aa_results.hpp" #include -#include namespace tinytc { diff --git a/src/visitor/insert_barrier.cpp b/src/visitor/insert_barrier.cpp index 74005d21..d4be98cf 100644 --- a/src/visitor/insert_barrier.cpp +++ b/src/visitor/insert_barrier.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include #include diff --git a/src/visitor/insert_barrier.hpp b/src/visitor/insert_barrier.hpp index ef0ccc42..ce87904a 100644 --- a/src/visitor/insert_barrier.hpp +++ b/src/visitor/insert_barrier.hpp @@ -12,7 +12,6 @@ #include "node/value_node.hpp" #include "visitor/aa_results.hpp" -#include #include namespace tinytc { diff --git a/src/visitor/stack.hpp b/src/visitor/stack.hpp index 48988e93..d360be46 100644 --- a/src/visitor/stack.hpp +++ b/src/visitor/stack.hpp @@ -10,7 +10,7 @@ #include "node/region_node.hpp" #include "node/value_node.hpp" -#include +#include #include namespace tinytc { diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index 58b06400..2cddc42c 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -26,7 +26,7 @@ func @axpby_atomic_general(%alpha: f32, %A: memref, %B: memref) { func @gemm_atomic(%A: memref, %B: memref, %C: memref) { gemm.n.n.atomic 1.0, %A, %B, 1.0, %C : f32, memref, memref, f32, memref -; CHECK: atomic_fetch_add_explicit((global volatile atomic_float*) Cb, c[n], memory_order_relaxed, memory_scope_work_group); +; CHECK: atomic_fetch_add_explicit((global volatile atomic_float*) (Cb + get_sub_group_local_id()), c[n], memory_order_relaxed, memory_scope_work_group); } func @ger_atomic(%A: memref, %B: memref, %C: memref) { diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 31ed33ac..83945726 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -9,6 +9,7 @@ #include #include #include +#include using namespace tinytc; From 7715710255a00414c44634701fe24936b044d6b0 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 28 May 2024 13:42:22 +0200 Subject: [PATCH 002/297] Bugfix Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 511f2c18..749e8d66 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -379,7 +379,8 @@ auto read_matrix_block_regular(block_builder &bb, matrix_block_description const bb.assign(block[k + m_block * d.Kb], std::move(rhs)); }; if (enable_sub_group_reads && m_block < first_m_block_with_check) { - store(sub_group_block_read_helper(d.pointer, scalar_ty, d.as)); + store(sub_group_block_read_helper(d.pointer + m_block * core_cfg.subgroup_size, + scalar_ty, d.as)); } else { auto rhs = d.pointer[d.stride[M_mode] * (get_sub_group_local_id() + m_block * core_cfg.subgroup_size)]; From 9436697e52965f8908cd92b3fa1615af3b8e11e1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 28 May 2024 15:08:22 +0200 Subject: [PATCH 003/297] Refactor codegen Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 1 - src/codegen_tools.cpp | 107 ++++++++++++---------------- src/codegen_tools.hpp | 10 +-- src/gemm_generator.cpp | 53 +++++++------- src/precision_helper.cpp | 149 --------------------------------------- src/precision_helper.hpp | 37 ---------- src/scalar_type.cpp | 84 ++++++++++++---------- src/scalar_type.hpp | 5 +- 8 files changed, 123 insertions(+), 323 deletions(-) delete mode 100644 src/precision_helper.cpp delete mode 100644 src/precision_helper.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3d74f3db..7243d976 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,7 +32,6 @@ set(SOURCES parser/parse_context.cpp parser.cpp passes.cpp - precision_helper.cpp prog.cpp recipe.cpp recipe/small_gemm_batched.cpp diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 749e8d66..b136e4ce 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -21,6 +21,10 @@ using namespace clir; namespace tinytc { +short bits(scalar_type ty) { return size(ty) * 8; } +expr constant(scalar_type ty, std::int64_t value) { return expr(value, bits(ty)); } +expr constant(scalar_type ty, double value) { return expr(value, bits(ty)); } + expr as_type(builtin_type ty, expr e) { switch (ty) { case builtin_type::char_t: @@ -69,75 +73,53 @@ expr vload_helper(short vec_size, expr offset, expr ptr) { return nullptr; } -builtin_type block_rw_op_type(builtin_type scalar_ty) { - switch (scalar_ty) { - case builtin_type::short_t: - return builtin_type::ushort_t; - case builtin_type::int_t: - case builtin_type::float_t: - return builtin_type::uint_t; - case builtin_type::long_t: - case builtin_type::double_t: - return builtin_type::ulong_t; +struct block_rw_config { + builtin_type cast_type; + expr (*sub_group_block_read)(expr); + expr (*sub_group_block_write)(expr, expr); + expr (*as_type)(expr); +}; + +auto get_block_rw_config(scalar_type ty) { + switch (ty) { + case scalar_type::i16: + return block_rw_config{builtin_type::ushort_t, &intel_sub_group_block_read_us, + &intel_sub_group_block_write_us, &as_short}; + case scalar_type::i32: + return block_rw_config{builtin_type::uint_t, &intel_sub_group_block_read_ui, + &intel_sub_group_block_write_ui, &as_int}; + case scalar_type::f32: + return block_rw_config{builtin_type::uint_t, &intel_sub_group_block_read_ui, + &intel_sub_group_block_write_ui, &as_float}; + case scalar_type::i64: + return block_rw_config{builtin_type::ulong_t, &intel_sub_group_block_read_ul, + &intel_sub_group_block_write_ul, &as_long}; + case scalar_type::f64: + return block_rw_config{builtin_type::ulong_t, &intel_sub_group_block_read_ul, + &intel_sub_group_block_write_ul, &as_double}; default: break; } - return scalar_ty; + return block_rw_config{builtin_type::void_t, nullptr, nullptr, nullptr}; } -expr sub_group_block_read_helper(expr pointer, builtin_type scalar_ty, address_space as) { - auto const make_read = [](builtin_type bt, expr pointer) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_read_us(std::move(pointer)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_read_ui(std::move(pointer)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_read_ul(std::move(pointer)); - default: - break; - } +expr sub_group_block_read_helper(expr pointer, scalar_type ty, address_space as) { + const auto cfg = get_block_rw_config(ty); + if (cfg.sub_group_block_read == nullptr) { return pointer[get_sub_group_local_id()]; - }; - auto const bt = block_rw_op_type(scalar_ty); - pointer = cast(pointer_to(clir::data_type(bt, as)), std::move(pointer)); - auto inst = make_read(bt, std::move(pointer)); - if (bt != scalar_ty) { - return as_type(scalar_ty, std::move(inst)); } - return inst; + pointer = cast(pointer_to(clir::data_type(cfg.cast_type, as)), std::move(pointer)); + auto inst = (*cfg.sub_group_block_read)(std::move(pointer)); + return (*cfg.as_type)(std::move(inst)); } -expr sub_group_block_write_helper(expr pointer, expr data, builtin_type scalar_ty, - address_space as) { - auto const make_write = [](builtin_type bt, expr pointer, expr data) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_write_us(std::move(pointer), std::move(data)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_write_ui(std::move(pointer), std::move(data)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_write_ul(std::move(pointer), std::move(data)); - default: - break; - } +expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, address_space as) { + const auto cfg = get_block_rw_config(ty); + if (cfg.sub_group_block_write == nullptr) { return pointer[get_sub_group_local_id()] = std::move(data); - }; - auto const bt = block_rw_op_type(scalar_ty); - pointer = cast(pointer_to(clir::data_type(bt, as)), std::move(pointer)); - if (bt != scalar_ty) { - data = as_type(bt, std::move(data)); } - return make_write(bt, std::move(pointer), std::move(data)); + pointer = cast(pointer_to(clir::data_type(cfg.cast_type, as)), std::move(pointer)); + data = (*cfg.as_type)(std::move(data)); + return (*cfg.sub_group_block_write)(std::move(pointer), std::move(data)); } void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, @@ -367,8 +349,7 @@ auto read_matrix_block_regular(block_builder &bb, matrix_block_description const assert(M_mode == 0 || M_mode == 1); const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; - auto const scalar_ty = to_clir_builtin_ty(d.ty); - auto block = bb.declare(array_of(clir::data_type(scalar_ty), m_blocks * d.Kb), block_name); + auto block = bb.declare(array_of(to_clir_ty(d.ty), m_blocks * d.Kb), block_name); const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); const bool enable_sub_group_reads = @@ -380,7 +361,7 @@ auto read_matrix_block_regular(block_builder &bb, matrix_block_description const }; if (enable_sub_group_reads && m_block < first_m_block_with_check) { store(sub_group_block_read_helper(d.pointer + m_block * core_cfg.subgroup_size, - scalar_ty, d.as)); + d.ty, d.as)); } else { auto rhs = d.pointer[d.stride[M_mode] * (get_sub_group_local_id() + m_block * core_cfg.subgroup_size)]; @@ -402,7 +383,7 @@ auto read_matrix_block_vector(block_builder &bb, matrix_block_description const assert(M_mode == 0 || M_mode == 1); const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; - const auto dt = clir::data_type(to_clir_builtin_ty(d.ty), d.Kb); + const auto dt = to_clir_ty(d.ty, d.Kb); auto block = bb.declare(array_of(dt, m_blocks), block_name); int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 926fbd25..8d2d8933 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -20,12 +20,14 @@ namespace tinytc { +short bits(scalar_type ty); +clir::expr constant(scalar_type ty, std::int64_t value); +clir::expr constant(scalar_type ty, double value); clir::expr as_type(clir::builtin_type ty, clir::expr e); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); -clir::expr sub_group_block_read_helper(clir::expr pointer, clir::builtin_type scalar_ty, - clir::address_space as); -clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, - clir::builtin_type scalar_ty, clir::address_space as); +clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); +clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, scalar_type ty, + clir::address_space as); void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, clir::address_space as, clir::expr value, clir::expr beta); diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index c98c5fd6..7337d0f1 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -4,7 +4,6 @@ #include "gemm_generator.hpp" #include "codegen_tools.hpp" #include "device_info.hpp" -#include "precision_helper.hpp" #include "scalar_type.hpp" #include "tiling.hpp" #include "tinytc/tinytc.hpp" @@ -162,14 +161,14 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, [&](expr) {}); auto const Mb = my_row_blocks_in_register * core_cfg.subgroup_size; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), "Ab", A); - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), "Bb", B); + auto Ab = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), "Ab", A); + auto Bb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), "Bb", B); auto c_block = block_accessor_regular(c, n_bs); for (int n = 0; n < n_bs; ++n) { for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - bb.assign(c_block.get(m_block, n), precision_helper{gemm_cfg.ty.C}.zero()); + bb.assign(c_block.get(m_block, n), constant(gemm_cfg.ty.C, 0.0)); } } @@ -240,8 +239,7 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, .get_product()); }); - auto Cb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), "Cb", - C + C_offset); + auto Cb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.C, Cspace)), "Cb", C + C_offset); auto const c_descr = matrix_block_description{gemm_cfg.ty.C, Cspace, Mb, 1, Cb, M, C_stride}; auto n = var("n"); bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < N, ++n) @@ -259,8 +257,8 @@ void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C bb, MNK[0], core_cfg.subgroup_size * row_blocks_in_register, tiling.m_tiles(), std::move(sg_m), [&](block_builder &bb, expr block, bool, expr inner_trip_count) { auto Astride_m = gemm_cfg.transA == transpose::T ? A_stride[1] : A_stride[0]; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), - "Ab", A + std::move(Astride_m) * block); + auto Ab = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), "Ab", + A + std::move(Astride_m) * block); add_microkernel(bb, std::move(inner_trip_count), N, std::move(Ab), B, C, C_stride[0] * std::move(block) + C_offset, alpha, beta); }); @@ -292,20 +290,19 @@ void generator::add_function_body(block_builder &bb, var A, var B, var C, expr a cols_in_register = tile_loop_uniformly_max_block_size(gemm_cfg.N, cols_in_register, tiling.n_tiles()); } - bb.declare( - array_of(precision_helper{gemm_cfg.ty.C}.type(), row_blocks_in_register * cols_in_register), - c); + bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), c); auto sg_n = bb.declare_assign(generic_uint(), "sg_n", get_sub_group_id() / tiling.m_tiles()); - tile_loop_uniformly( - bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), - [&](block_builder &bb, expr block, expr inner_trip_count) { - auto Bstride_n = gemm_cfg.transB == transpose::T ? B_stride[0] : B_stride[1]; - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), - "Bb", B + std::move(Bstride_n) * block); - add_mloop(bb, std::move(inner_trip_count), A, std::move(Bb), C, - C_stride[1] * std::move(block), alpha, beta); - }); + tile_loop_uniformly(bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), + [&](block_builder &bb, expr block, expr inner_trip_count) { + auto Bstride_n = + gemm_cfg.transB == transpose::T ? B_stride[0] : B_stride[1]; + auto Bb = + bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), + "Bb", B + std::move(Bstride_n) * block); + add_mloop(bb, std::move(inner_trip_count), A, std::move(Bb), C, + C_stride[1] * std::move(block), alpha, beta); + }); } ::clir::func generator::function(std::string_view name) { @@ -314,11 +311,11 @@ ::clir::func generator::function(std::string_view name) { auto C = var("C"); auto fb = ::clir::function_builder{std::string(name)}; - auto const scalar = [&](precision_helper const &fph, std::optional const &val, + auto const scalar = [&](scalar_type ty, std::optional const &val, std::string const &prefix) -> expr { auto v = var{prefix}; - fb.argument(fph.type(), v); - return val ? fph.constant(*val) : v; + fb.argument(to_clir_ty(ty), v); + return val ? constant(ty, *val) : v; }; auto const shape = [&](std::int64_t shape, expr &target, std::string const &prefix) { auto v = var{prefix}; @@ -337,13 +334,13 @@ ::clir::func generator::function(std::string_view name) { shape(gemm_cfg.M, MNK[0], "M"); shape(gemm_cfg.N, MNK[1], "N"); shape(gemm_cfg.K, MNK[2], "K"); - expr alpha = scalar(precision_helper{gemm_cfg.ty.alpha}, gemm_cfg.alpha, "alpha"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), A); + expr alpha = scalar(gemm_cfg.ty.alpha, gemm_cfg.alpha, "alpha"); + fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), A); stride(gemm_cfg.A_stride, A_stride, "A_stride"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), B); + fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), B); stride(gemm_cfg.B_stride, B_stride, "B_stride"); - expr beta = scalar(precision_helper{gemm_cfg.ty.beta}, gemm_cfg.beta, "beta"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), C); + expr beta = scalar(gemm_cfg.ty.beta, gemm_cfg.beta, "beta"); + fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.C, Cspace)), C); stride(gemm_cfg.C_stride, C_stride, "C_stride"); fb.body([&](block_builder &bb) { add_function_body(bb, A, B, C, alpha, beta); }); diff --git a/src/precision_helper.cpp b/src/precision_helper.cpp deleted file mode 100644 index e5498a03..00000000 --- a/src/precision_helper.cpp +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "precision_helper.hpp" -#include "scalar_type.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include -#include -#include - -#include - -using clir::address_space; -using clir::as_char; -using clir::as_double; -using clir::as_float; -using clir::as_int; -using clir::as_long; -using clir::as_short; -using clir::as_uchar; -using clir::as_uint; -using clir::as_ulong; -using clir::builtin_type; -using clir::cast; -using clir::expr; -using clir::get_sub_group_local_id; -using clir::intel_sub_group_block_read_ui; -using clir::intel_sub_group_block_read_ul; -using clir::intel_sub_group_block_read_us; -using clir::intel_sub_group_block_write_ui; -using clir::intel_sub_group_block_write_ul; -using clir::intel_sub_group_block_write_us; -using clir::pointer_to; - -namespace tinytc { - -precision_helper::precision_helper(scalar_type ty) : ty_(ty) {} -builtin_type precision_helper::base_type() const { return to_clir_builtin_ty(ty_); } -builtin_type precision_helper::block_rw_base_type() const { - auto bt = base_type(); - switch (bt) { - case builtin_type::short_t: - return builtin_type::ushort_t; - case builtin_type::int_t: - case builtin_type::float_t: - return builtin_type::uint_t; - case builtin_type::long_t: - case builtin_type::double_t: - return builtin_type::ulong_t; - default: - break; - } - return bt; -} -expr precision_helper::as_type(builtin_type ty, expr e) const { - switch (ty) { - case builtin_type::char_t: - return as_char(std::move(e)); - case builtin_type::uchar_t: - return as_uchar(std::move(e)); - case builtin_type::short_t: - return as_short(std::move(e)); - case builtin_type::ushort_t: - return as_ushort(std::move(e)); - case builtin_type::int_t: - return as_int(std::move(e)); - case builtin_type::uint_t: - return as_uint(std::move(e)); - case builtin_type::long_t: - return as_long(std::move(e)); - case builtin_type::ulong_t: - return as_ulong(std::move(e)); - case builtin_type::float_t: - return as_float(std::move(e)); - case builtin_type::double_t: - return as_double(std::move(e)); - default: - break; - } - return e; -} -short precision_helper::bits() const { return size(ty_) * 8; } -clir::data_type precision_helper::type(address_space as) const { - return clir::data_type(base_type(), as); -} -clir::data_type precision_helper::type(short size, address_space as) const { - return clir::data_type(base_type(), size, as); -} -// TODO: Think of something for integer constants -expr precision_helper::constant(double value) const { return expr(value, bits()); } -expr precision_helper::zero() const { return constant(0.0); } - -expr precision_helper::sub_group_block_read(expr address, address_space as) const { - auto const make_read = [](builtin_type bt, expr address) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_read_us(std::move(address)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_read_ui(std::move(address)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_read_ul(std::move(address)); - default: - break; - } - return address[get_sub_group_local_id()]; - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - auto inst = make_read(bt, std::move(address)); - if (bt != base_type()) { - return as_type(base_type(), std::move(inst)); - } - return inst; -} -expr precision_helper::sub_group_block_write(expr address, expr data, address_space as) const { - auto const make_write = [](builtin_type bt, expr address, expr data) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_write_us(std::move(address), std::move(data)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_write_ui(std::move(address), std::move(data)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_write_ul(std::move(address), std::move(data)); - default: - break; - } - return address[get_sub_group_local_id()] = std::move(data); - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - if (bt != base_type()) { - data = as_type(bt, std::move(data)); - } - return make_write(bt, std::move(address), std::move(data)); -} - -} // namespace tinytc diff --git a/src/precision_helper.hpp b/src/precision_helper.hpp deleted file mode 100644 index 32445697..00000000 --- a/src/precision_helper.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef PRECISION_HELPER_20230214_HPP -#define PRECISION_HELPER_20230214_HPP - -#include "tinytc/types.hpp" - -#include "clir/builtin_type.hpp" -#include "clir/data_type.hpp" -#include "clir/expr.hpp" - -namespace tinytc { - -class precision_helper { - public: - precision_helper(scalar_type ty); - clir::builtin_type base_type() const; - clir::builtin_type block_rw_base_type() const; - clir::expr as_type(clir::builtin_type ty, clir::expr e) const; - short bits() const; - clir::data_type type(clir::address_space as = clir::address_space::generic_t) const; - clir::data_type type(short size, clir::address_space as = clir::address_space::generic_t) const; - clir::expr constant(double value) const; - clir::expr zero() const; - clir::expr sub_group_block_read(clir::expr address, - clir::address_space as = clir::address_space::generic_t) const; - clir::expr sub_group_block_write(clir::expr address, clir::expr data, - clir::address_space as = clir::address_space::generic_t) const; - - private: - scalar_type ty_; -}; - -} // namespace tinytc - -#endif // PRECISION_HELPER_20230214_HPP diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index af4ef4d1..fb5cfb12 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -21,52 +21,58 @@ bool is_floating_type(scalar_type ty) { return false; } -clir::builtin_type to_clir_builtin_ty(scalar_type ty) { - switch (ty) { - case scalar_type::i1: - return clir::builtin_type::bool_t; - case scalar_type::i8: - return clir::builtin_type::char_t; - case scalar_type::i16: - return clir::builtin_type::short_t; - case scalar_type::i32: - return clir::builtin_type::int_t; - case scalar_type::i64: - return clir::builtin_type::long_t; - case scalar_type::index: - return clir::builtin_type::long_t; - case scalar_type::f32: - return clir::builtin_type::float_t; - case scalar_type::f64: - return clir::builtin_type::double_t; - } - return clir::builtin_type::void_t; -} - clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return clir::data_type(to_clir_builtin_ty(ty), as, q); + return to_clir_ty(ty, 1, as, q); } -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty) { - switch (ty) { - case scalar_type::i32: - return clir::builtin_type::atomic_int_t; - case scalar_type::i64: - return clir::builtin_type::atomic_long_t; - case scalar_type::index: - return clir::builtin_type::atomic_long_t; - case scalar_type::f32: - return clir::builtin_type::atomic_float_t; - case scalar_type::f64: - return clir::builtin_type::atomic_double_t; - default: - break; +clir::data_type to_clir_ty(scalar_type ty, short size, clir::address_space as, + clir::type_qualifier q) { + const auto base_type = [](scalar_type ty) { + switch (ty) { + case scalar_type::i1: + return clir::builtin_type::bool_t; + case scalar_type::i8: + return clir::builtin_type::char_t; + case scalar_type::i16: + return clir::builtin_type::short_t; + case scalar_type::i32: + return clir::builtin_type::int_t; + case scalar_type::i64: + return clir::builtin_type::long_t; + case scalar_type::index: + return clir::builtin_type::long_t; + case scalar_type::f32: + return clir::builtin_type::float_t; + case scalar_type::f64: + return clir::builtin_type::double_t; + } + return clir::builtin_type::void_t; + }; + if (size == 1) { + return clir::data_type(base_type(ty), as, q); } - return clir::builtin_type::void_t; + return clir::data_type(base_type(ty), size, as, q); } clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return clir::data_type(to_clir_atomic_builtin_ty(ty), as, q); + auto const base_type = [](scalar_type ty) { + switch (ty) { + case scalar_type::i32: + return clir::builtin_type::atomic_int_t; + case scalar_type::i64: + return clir::builtin_type::atomic_long_t; + case scalar_type::index: + return clir::builtin_type::atomic_long_t; + case scalar_type::f32: + return clir::builtin_type::atomic_float_t; + case scalar_type::f64: + return clir::builtin_type::atomic_double_t; + default: + break; + } + return clir::builtin_type::void_t; + }; + return clir::data_type(base_type(ty), as, q); } } // namespace tinytc diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index ff0428c1..be685e1e 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -12,10 +12,11 @@ namespace tinytc { bool is_floating_type(scalar_type ty); -clir::builtin_type to_clir_builtin_ty(scalar_type ty); -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty); clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); +clir::data_type to_clir_ty(scalar_type ty, short size, + clir::address_space as = clir::address_space::generic_t, + clir::type_qualifier q = clir::type_qualifier::none); clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); From fbb67e9e43fb92761378bb1d487b558a2bf7029b Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 1 Aug 2024 04:04:10 -0700 Subject: [PATCH 004/297] SYCL/OpenCL bugfix Signed-off-by: Carsten Uphoff --- include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/cl/device_info.cpp | 50 +++++++++++++++++++++---------- src/cl/kernel.cpp | 17 ++++++++++- src/error.cpp | 2 ++ src/recipe/small_gemm_batched.cpp | 2 +- src/recipe/tall_and_skinny.cpp | 13 ++++---- 7 files changed, 64 insertions(+), 22 deletions(-) diff --git a/include/tinytc/types.h b/include/tinytc/types.h index a404daa1..85511574 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -39,6 +39,7 @@ typedef enum { tinytc_status_unsupported_backend = 0xc, ///< Unsupported backend (SYCL runtime) tinytc_status_invalid_kernel_arguments = 0xd, ///< Kernel got invalid arguments tinytc_status_unsupported_device = 0xe, ///< Unsupported device + tinytc_status_invalid_core_info = 0xf, ///< Invalid core info object // IR errors tinytc_status_ir_out_of_bounds = 0x100, ///< Out of bounds access tinytc_status_ir_invalid_shape = 0x101, ///< Invalid tensor shape diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 06514da5..99bc1429 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -49,6 +49,7 @@ enum class status { unsupported_backend = tinytc_status_unsupported_backend, invalid_kernel_arguments = tinytc_status_invalid_kernel_arguments, unsupported_device = tinytc_status_unsupported_device, + invalid_core_info = tinytc_status_invalid_core_info, // IR errors ir_out_of_bounds = tinytc_status_ir_out_of_bounds, ir_invalid_shape = tinytc_status_ir_invalid_shape, diff --git a/src/cl/device_info.cpp b/src/cl/device_info.cpp index 76c7fe2b..dbd98a6d 100644 --- a/src/cl/device_info.cpp +++ b/src/cl/device_info.cpp @@ -81,10 +81,12 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i clGetDeviceInfo(device, CL_DEVICE_VENDOR_ID, sizeof(vendor_id), &vendor_id, nullptr)); if (vendor_id == 0x8086) { - cl_version ip_ver; - cl_uint num_eus_per_subslice, num_threads_per_eu; + cl_device_type device_type; std::size_t subgroup_sizes_size = 0; + TINYTC_CL_CHECK_STATUS( + clGetDeviceInfo(device, CL_DEVICE_TYPE, sizeof(device_type), &device_type, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_SUB_GROUP_SIZES_INTEL, 0, nullptr, &subgroup_sizes_size)); auto subgroup_sizes_long = @@ -95,19 +97,37 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i auto subgroup_sizes = std::vector(subgroup_sizes_long.begin(), subgroup_sizes_long.end()); - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, sizeof(ip_ver), &ip_ver, nullptr)); - - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, - sizeof(num_eus_per_subslice), &num_eus_per_subslice, - nullptr)); - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, - sizeof(num_threads_per_eu), &num_threads_per_eu, - nullptr)); - - TINYTC_CHECK_STATUS(tinytc_core_info_intel_create(info, ip_ver, num_eus_per_subslice, - num_threads_per_eu, subgroup_sizes.size(), - subgroup_sizes.data())); + if (device_type == CL_DEVICE_TYPE_GPU) { + cl_version ip_ver; + cl_uint num_eus_per_subslice, num_threads_per_eu; + + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, + sizeof(ip_ver), &ip_ver, nullptr)); + + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, + sizeof(num_eus_per_subslice), + &num_eus_per_subslice, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, + sizeof(num_threads_per_eu), &num_threads_per_eu, + nullptr)); + + TINYTC_CHECK_STATUS(tinytc_core_info_intel_create( + info, ip_ver, num_eus_per_subslice, num_threads_per_eu, subgroup_sizes.size(), + subgroup_sizes.data())); + } else if (device_type == CL_DEVICE_TYPE_CPU) { + // 32 zmm registers + // @todo: need to do something smarter here + std::uint32_t register_space = 32 * 64; + size_t max_work_group_size; + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, + sizeof(max_work_group_size), + &max_work_group_size, nullptr)); + TINYTC_CHECK_STATUS( + tinytc_core_info_generic_create(info, register_space, max_work_group_size, + subgroup_sizes.size(), subgroup_sizes.data())); + } else { + return tinytc_status_unsupported_device; + } } else if (vendor_id == 0x1002) { // 512 KB / 32 wavefronts // @todo: can this info be queried? diff --git a/src/cl/kernel.cpp b/src/cl/kernel.cpp index 23814944..416d7a35 100644 --- a/src/cl/kernel.cpp +++ b/src/cl/kernel.cpp @@ -14,6 +14,7 @@ #include #include #include +#include extern "C" { @@ -147,10 +148,24 @@ tinytc_status_t tinytc_cl_get_group_size(cl_kernel kernel, size_t *local_size) { if (local_size == nullptr) { return tinytc_status_invalid_arguments; } + constexpr int short_dev_list = 4; cl_program p; cl_device_id d; + cl_uint num_devices; TINYTC_CL_CHECK_STATUS(clGetKernelInfo(kernel, CL_KERNEL_PROGRAM, sizeof(p), &p, nullptr)); - TINYTC_CL_CHECK_STATUS(clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(d), &d, nullptr)); + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_NUM_DEVICES, sizeof(num_devices), &num_devices, nullptr)); + if (num_devices <= short_dev_list) { + cl_device_id dbuf[4]; + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(dbuf), &dbuf, nullptr)); + d = dbuf[0]; + } else { + auto dbuf = std::vector(num_devices); + TINYTC_CL_CHECK_STATUS(clGetProgramInfo( + p, CL_PROGRAM_DEVICES, num_devices * sizeof(cl_device_id), dbuf.data(), nullptr)); + d = dbuf[0]; + } return tinytc_cl_convert_status( clGetKernelWorkGroupInfo(kernel, d, CL_KERNEL_COMPILE_WORK_GROUP_SIZE, 3 * sizeof(std::size_t), local_size, nullptr)); diff --git a/src/error.cpp b/src/error.cpp index 2431eecc..d315e54c 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -112,6 +112,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Invalid arguments passed to kernel"; case tinytc_status_unsupported_device: return "Unsupported device"; + case tinytc_status_invalid_core_info: + return "Invalid core info object (e.g. max work group size is 0 or subgroup sizes vector is empty)"; // IR case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 7ace3283..0c01ddfc 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -62,7 +62,7 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co std::int32_t source_id = 0; if (ctx) { TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "small gemm batched recipe", "", &source_id)); + tinytc_source_context_add_source(ctx, "recipe/small_gemm_batched.cpp", "", &source_id)); } auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 90b2505f..46ec8109 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -73,7 +73,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( std::int32_t source_id = 0; if (ctx) { TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "tall and skinny recipe", "", &source_id)); + tinytc_source_context_add_source(ctx, "recipe/tall_and_skinny.cpp", "", &source_id)); } auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { @@ -86,7 +86,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return l; }; - if (M_block_size == 0u) { + if (M_block_size == 0) { TINYTC_CHECK_STATUS(tinytc_recipe_tall_and_skinny_suggest_block_size(info, &M_block_size)); } @@ -174,9 +174,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size(const_tinytc_co if (info == nullptr || M_block_size == nullptr) { return tinytc_status_invalid_arguments; } - - return tinytc::exception_to_status_code( - [&] { *M_block_size = std::min(128, info->minmax_work_group_size()); }); + return tinytc::exception_to_status_code([&] { + if (info->minmax_work_group_size() <= 0) { + throw tinytc::status::invalid_core_info; + } + *M_block_size = std::min(128, info->minmax_work_group_size()); + }); } tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( From 6a390f7b9adb53e4d6a8d67580261c84bf82dd5f Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 29 May 2024 12:49:17 +0200 Subject: [PATCH 005/297] Initial complex support in GEMM Signed-off-by: Carsten Uphoff --- examples/benchmark/args.cpp | 15 ++- examples/benchmark/args.hpp | 2 +- examples/benchmark/main.cpp | 47 +++++++--- include/tinytc/tinytc.hpp | 9 ++ include/tinytc/types.h | 4 +- include/tinytc/types.hpp | 4 +- src/codegen_tools.cpp | 55 +++++------ src/codegen_tools.hpp | 15 +-- src/gemm_generator.cpp | 180 ++++++++++++++++++++++++++---------- src/parser/lexer.re | 4 +- src/recipe.cpp | 4 + src/scalar_type.cpp | 50 ++++++++++ src/scalar_type.hpp | 2 + 13 files changed, 280 insertions(+), 111 deletions(-) diff --git a/examples/benchmark/args.cpp b/examples/benchmark/args.cpp index 61030e6b..cdd1ca6e 100644 --- a/examples/benchmark/args.cpp +++ b/examples/benchmark/args.cpp @@ -12,6 +12,7 @@ args arg_parser::parse_args(int argc, char **argv) { args a = {}; a.internal_repetitions = 1; + a.ty = tinytc::scalar_type::f32; a.transA = tinytc::transpose::N; a.transB = tinytc::transpose::N; a.beta = 0.0; @@ -42,10 +43,14 @@ args arg_parser::parse_args(int argc, char **argv) { } else if (std::strcmp(argv[i], "-p") == 0 || std::strcmp(argv[i], "--precision") == 0) { ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; + if (argv[i][0] == 'd' || strcmp(argv[i], "f64") == 0) { + a.ty = tinytc::scalar_type::f64; + } else if (argv[i][0] == 's' || strcmp(argv[i], "f32") == 0) { + a.ty = tinytc::scalar_type::f32; + } else if (strcmp(argv[i], "c64") == 0) { + a.ty = tinytc::scalar_type::c64; + } else if (strcmp(argv[i], "c32") == 0) { + a.ty = tinytc::scalar_type::c32; } else { fail(); } @@ -83,7 +88,7 @@ positional arguments: optional arguments: -h, --help Show help and quit -i, --internal-reps Number of GEMM repetitions inside kernel (default: 1) - -p, --precision Precision (single = s, double = d) + -p, --precision Precision (single = s or f32, double = d or f64, complex = c32, long complex = c64) --trans-a Transpose A matrix --trans-b Transpose B matrix -v, --verify Verify optimized implementation diff --git a/examples/benchmark/args.hpp b/examples/benchmark/args.hpp index 23f69c65..636c2fb3 100644 --- a/examples/benchmark/args.hpp +++ b/examples/benchmark/args.hpp @@ -19,7 +19,7 @@ struct test_case { struct args { std::vector tc; int internal_repetitions; - bool double_precision; + tinytc::scalar_type ty; bool help; tinytc::transpose transA; tinytc::transpose transB; diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index f62101a7..e414c3d6 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -129,8 +130,8 @@ template void test(queue q, args &a) { q.copy(C, C_host, total_reals).wait(); std::size_t num_err = 0; for (std::size_t i = 0; i < M * N * howmany; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { + const auto err = std::abs(C_host[i] - C_ref_host[i]); + if (err > 10.0 * std::numeric_limits::epsilon()) { if (num_err < 10) { std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] << std::endl; @@ -143,7 +144,6 @@ template void test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { auto na = c.m * c.k; auto nb = c.k * c.n; @@ -163,7 +163,7 @@ template void test(queue q, args &a) { auto c_ref = C_ref + batch * nc; for (std::int64_t mb = m; mb < c.m; mb += 32) { for (std::int64_t n = 0; n < c.n; ++n) { - auto c_acc = 0.0f; + auto c_acc = T(0.0); for (std::int64_t k = 0; k < c.k; ++k) { c_acc += a[transa ? k + mb * c.k : mb + k * c.m] * b[transb ? n + k * c.n : k + n * c.k]; @@ -185,9 +185,10 @@ template void test(queue q, args &a) { } double min_exec_time_ns = 0.0; + constexpr auto element_ty = to_scalar_type_v; try { auto src = gemm_kernel_with_inner_repetition( - to_scalar_type_v, a.transA, a.transB, a.atomic, c.m, c.n, c.k, + element_ty, a.transA, a.transB, a.atomic, c.m, c.n, c.k, {1, a.transA == transpose::T ? c.k : c.m}, {1, a.transB == transpose::T ? c.n : c.k}, a.beta, {1, c.m}, a.internal_repetitions, q); @@ -209,14 +210,25 @@ template void test(queue q, args &a) { }).wait(); }); - auto gflops = - a.internal_repetitions * 2 * c.m * c.n * c.k * howmany / min_exec_time_ns; + auto ops_per_mnk = 0; + switch (element_ty) { + case scalar_type::c32: + case scalar_type::c64: + ops_per_mnk = 8; + break; + default: + ops_per_mnk = 2; + break; + } + + auto gflops = a.internal_repetitions * ops_per_mnk * c.m * c.n * c.k * howmany / + min_exec_time_ns; auto roofline_gflops = std::min(512 * 32 * 1.6e9, a.internal_repetitions * 2 * c.m * c.n * c.k / (sizeof(T) * (na + nb + nc) / 1.1e12)) / 1e9; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << howmany - << "," << min_exec_time_ns / 1e9 << "," << gflops << "," + std::cout << to_string(element_ty) << "," << c.m << "," << c.n << "," << c.k << "," + << howmany << "," << min_exec_time_ns / 1e9 << "," << gflops << "," << roofline_gflops << "," << std::round(gflops / roofline_gflops * 100) << "%," << a.internal_repetitions << std::endl; } @@ -260,10 +272,21 @@ int main(int argc, char **argv) { "repetitions" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { + switch (a.ty) { + case scalar_type::f32: test(std::move(q), a); + break; + case scalar_type::f64: + test(std::move(q), a); + break; + case scalar_type::c32: + test>(std::move(q), a); + break; + case scalar_type::c64: + test>(std::move(q), a); + break; + default: + return -1; } } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 99b19b41..6c828935 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -7,6 +7,7 @@ #include "tinytc/tinytc.h" #include "tinytc/types.hpp" +#include #include #include #include @@ -107,6 +108,14 @@ template <> struct to_scalar_type { template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::f64; ///< value }; +//! to_scalar_type specialization +template <> struct to_scalar_type> { + static constexpr scalar_type value = scalar_type::c32; ///< value +}; +//! to_scalar_type specialization +template <> struct to_scalar_type> { + static constexpr scalar_type value = scalar_type::c64; ///< value +}; /** * Convenience variable for to_scalar_type. * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 85511574..f9f1f7ac 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -220,7 +220,9 @@ typedef enum { tinytc_scalar_type_i64 = 4, ///< Signed 64 bit integer tinytc_scalar_type_index = 5, ///< Integer type for indices tinytc_scalar_type_f32 = 6, ///< Single precision floating point (32 bit) - tinytc_scalar_type_f64 = 7 ///< Double precision floating point (64 bit) + tinytc_scalar_type_f64 = 7, ///< Double precision floating point (64 bit) + tinytc_scalar_type_c32 = 8, ///< Single precision complex (2x32 bit) + tinytc_scalar_type_c64 = 9 ///< Double precision complex (2x64 bit) } tinytc_scalar_type_t; //! Arithmetic operations diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 99bc1429..91ed0eb6 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -205,7 +205,9 @@ enum class scalar_type { i64 = tinytc_scalar_type_i64, ///< Signed 64 bit integer index = tinytc_scalar_type_index, ///< Unsigned Integer type for indices f32 = tinytc_scalar_type_f32, ///< Single precision floating point (32 bit) - f64 = tinytc_scalar_type_f64 ///< Double precision floating point (64 bit) + f64 = tinytc_scalar_type_f64, ///< Double precision floating point (64 bit) + c32 = tinytc_scalar_type_c32, ///< Single precision complex (2x32 bit) + c64 = tinytc_scalar_type_c64 ///< Double precision complex (2x64 bit) }; //! Arithmetic operations diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index b136e4ce..e46b2d57 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -23,34 +23,15 @@ namespace tinytc { short bits(scalar_type ty) { return size(ty) * 8; } expr constant(scalar_type ty, std::int64_t value) { return expr(value, bits(ty)); } -expr constant(scalar_type ty, double value) { return expr(value, bits(ty)); } - -expr as_type(builtin_type ty, expr e) { - switch (ty) { - case builtin_type::char_t: - return as_char(std::move(e)); - case builtin_type::uchar_t: - return as_uchar(std::move(e)); - case builtin_type::short_t: - return as_short(std::move(e)); - case builtin_type::ushort_t: - return as_ushort(std::move(e)); - case builtin_type::int_t: - return as_int(std::move(e)); - case builtin_type::uint_t: - return as_uint(std::move(e)); - case builtin_type::long_t: - return as_long(std::move(e)); - case builtin_type::ulong_t: - return as_ulong(std::move(e)); - case builtin_type::float_t: - return as_float(std::move(e)); - case builtin_type::double_t: - return as_double(std::move(e)); - default: - break; +expr constant(scalar_type ty, double value) { + if (is_complex_type(ty)) { + const auto ety = element_type(ty); + return init_vector(to_clir_ty(ty), {constant(ety, value), constant(ety, 0.0)}); } - return e; + return expr(value, bits(ty)); +} +expr complex_mul(scalar_type ty, expr a, expr b) { + return a * b.s(0) + init_vector(to_clir_ty(ty), {-a.s(1), a.s(0)}) * b.s(1); } expr vload_helper(short vec_size, expr offset, expr ptr) { @@ -127,7 +108,15 @@ void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, a if (is_atomic) { atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), std::move(beta)); } else { - bb.assign(dereference(dst), std::move(value) + std::move(beta) * dereference(dst)); + auto c_scaled = clir::expr{nullptr}; + if (is_complex_type(ty)) { + c_scaled = bb.declare_assign(to_clir_ty(ty), "c_scaled", dereference(dst)); + auto beta1 = bb.declare_assign(to_clir_ty(ty), "beta", beta); + bb.assign(c_scaled, complex_mul(ty, beta1, c_scaled)); + } else { + c_scaled = beta * dereference(dst); + } + bb.assign(dereference(dst), std::move(value) + std::move(c_scaled)); } } @@ -312,8 +301,8 @@ void tile_loop_uniformly_dynamic(block_builder &bb, expr loop_trip_count, unsign .get_product()); } -block_accessor_regular::block_accessor_regular(expr block, int Kb, expr offset) - : block_(std::move(block)), Kb_(Kb), offset_(std::move(offset)) {} +block_accessor_regular::block_accessor_regular(expr block, int Kb) + : block_(std::move(block)), offset_{clir::expr{nullptr}}, Kb_(Kb) {} auto block_accessor_regular::get(int m_block, int k) const -> expr { const auto i = k + m_block * Kb_; if (offset_) { @@ -410,7 +399,7 @@ auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int -> std::unique_ptr { assert(M_mode == 0 || M_mode == 1); - if (d.is_unit_stride(1 - M_mode) && + if (d.is_unit_stride(1 - M_mode) && !is_complex_type(d.ty) && (d.Kb == 2 || d.Kb == 3 || d.Kb == 4 || d.Kb == 8 || d.Kb == 16)) { return read_matrix_block_vector(bb, d, M_mode, core_cfg, block_name); } @@ -418,7 +407,7 @@ auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int } void write_matrix_block(block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, expr alpha, expr beta, + matrix_block_description const &d, bool is_atomic, expr beta, core_config const &core_cfg) { const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; @@ -429,7 +418,7 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, store_helper(bb, is_atomic, d.pointer + d.stride[0] * (get_sub_group_local_id() + m_block * core_cfg.subgroup_size), - d.ty, d.as, alpha * block.get(m_block, k), beta); + d.ty, d.as, block.get(m_block, k), beta); }; if (m_block >= first_m_block_with_check) { bb.add(if_selection_builder(d.condition(m_block, core_cfg.subgroup_size)) diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 8d2d8933..d8c5b59d 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -23,7 +23,7 @@ namespace tinytc { short bits(scalar_type ty); clir::expr constant(scalar_type ty, std::int64_t value); clir::expr constant(scalar_type ty, double value); -clir::expr as_type(clir::builtin_type ty, clir::expr e); +clir::expr complex_mul(scalar_type ty, clir::expr a, clir::expr b); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, scalar_type ty, @@ -71,14 +71,13 @@ class block_accessor { class block_accessor_regular : public block_accessor { public: - block_accessor_regular(clir::expr block, int Kb, clir::expr offset = clir::expr{nullptr}); + block_accessor_regular(clir::expr block, int Kb); auto get(int m_block, int k) const -> clir::expr override; - inline auto offset(clir::expr offset) { offset_ = std::move(offset); } + inline void offset(clir::expr offset) { offset_ = std::move(offset); } private: - clir::expr block_; + clir::expr block_, offset_; int Kb_; - clir::expr offset_; }; class block_accessor_vector : public block_accessor { @@ -110,14 +109,16 @@ auto read_matrix_block_regular(clir::block_builder &bb, matrix_block_description auto read_matrix_block_vector(clir::block_builder &bb, matrix_block_description const &d, int M_mode, core_config const &core_cfg, char const *block_name) -> std::unique_ptr; + // Read MbxKb block auto read_matrix_block(clir::block_builder &bb, matrix_block_description const &d, int M_mode, core_config const &core_cfg, char const *block_name) -> std::unique_ptr; +// Write MbxKb block void write_matrix_block(clir::block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, clir::expr alpha, - clir::expr beta, core_config const &core_cfg); + matrix_block_description const &d, bool is_atomic, clir::expr beta, + core_config const &core_cfg); } // namespace tinytc diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index 7337d0f1..30bff390 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -128,6 +129,8 @@ class generator { core_config const &core_cfg, address_space As, address_space Bs, address_space Cs) : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), Cspace(Cs) {} + bool use_double_buffering() const; + void multiply_update(block_builder &bb, expr a, expr b, int n_offset, expr c, expr c_im); void add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, expr C_offset, expr alpha, expr beta); void add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, @@ -142,11 +145,43 @@ class generator { address_space Aspace, Bspace, Cspace; int row_blocks_in_register = 1; int cols_in_register = 1; - var c, m; + var c_acc, c_acc_im, m; std::array MNK; std::array A_stride, B_stride, C_stride; }; +bool generator::use_double_buffering() const { + return is_complex_type(gemm_cfg.ty.A) && is_complex_type(gemm_cfg.ty.B); +} + +void generator::multiply_update(block_builder &bb, expr a, expr b, int n_offset, expr c, + expr c_im) { + if (is_complex_type(gemm_cfg.ty.A)) { + if (is_complex_type(gemm_cfg.ty.B)) { + assert(use_double_buffering()); + auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); + auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); + bb.add(add_into(c, a * b_bc_re)); + bb.add(add_into(c_im, a * b_bc_im)); + } else { + auto b_bc = sub_group_broadcast(b, n_offset); + bb.add(add_into(std::move(c), std::move(a) * std::move(b_bc))); + } + } else if (is_complex_type(gemm_cfg.ty.B)) { + auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); + auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); + bb.add(add_into(c.s(0), a * b_bc_re)); + bb.add(add_into(c.s(1), a * b_bc_im)); + } else { + auto b_bc = sub_group_broadcast(b, n_offset); + if (gemm_cfg.ty.A == gemm_cfg.ty.B && gemm_cfg.ty.B == gemm_cfg.ty.C) { + bb.assign(c, fma(std::move(a), std::move(b_bc), c)); + } else { + bb.add(add_into(std::move(c), std::move(a) * std::move(b_bc))); + } + } +} + void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, expr C_offset, expr alpha, expr beta) { int n_bs = 0; @@ -164,57 +199,68 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, auto Ab = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), "Ab", A); auto Bb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), "Bb", B); - auto c_block = block_accessor_regular(c, n_bs); + auto c_block = block_accessor_regular(c_acc, n_bs); + auto c_block_im = block_accessor_regular(c_acc_im, n_bs); for (int n = 0; n < n_bs; ++n) { for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { bb.assign(c_block.get(m_block, n), constant(gemm_cfg.ty.C, 0.0)); + if (use_double_buffering()) { + bb.assign(c_block_im.get(m_block, n), constant(gemm_cfg.ty.C, 0.0)); + } } } auto const compute_c = [&](block_builder &bb, int Kb, ::clir::expr K0, ::clir::expr K1) { auto kb = var("kb"); - bb.add(for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), - kb < std::move(K1), add_into(kb, Kb)) - .body([&](block_builder &bb) { - auto const a_descr = - matrix_block_description{gemm_cfg.ty.A, Aspace, Mb, Kb, Ab, M, A_stride}; - auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; - auto const a = read_matrix_block(bb, a_descr, am, core_cfg, "a"); - - auto const b_descr = matrix_block_description{ - gemm_cfg.ty.B, Bspace, n_bs, Kb, Bb, N, B_stride}; - auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; - auto const b = read_matrix_block(bb, b_descr, bn, core_cfg, "b"); - - const int nbb = 4; - for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - for (int nb = 0; nb < n_bs; nb += nbb) { - for (int k = 0; k < Kb; ++k) { - for (int n = 0; n < nbb; ++n) { - if (nb + n < n_bs) { - auto const n_block = (nb + n) / core_cfg.subgroup_size; - auto const n_offset = (nb + n) % core_cfg.subgroup_size; - auto my_a = a->get(m_block, k); - auto my_b = - sub_group_broadcast(b->get(n_block, k), n_offset); - auto my_c = c_block.get(m_block, nb + n); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && - gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(my_c, fma(std::move(my_a), std::move(my_b), - my_c)); - } else { - bb.add(add_into(std::move(my_c), - std::move(my_a) * std::move(my_b))); - } - } - } - } - } - } - }) - .attribute(opencl_unroll_hint(1)) - .get_product()); + bb.add( + for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), + kb < std::move(K1), add_into(kb, Kb)) + .body([&](block_builder &bb) { + auto const a_descr = + matrix_block_description{gemm_cfg.ty.A, Aspace, Mb, Kb, Ab, M, A_stride}; + auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; + auto const a = read_matrix_block(bb, a_descr, am, core_cfg, "a"); + + auto const b_descr = + matrix_block_description{gemm_cfg.ty.B, Bspace, n_bs, Kb, Bb, N, B_stride}; + auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; + auto const b = read_matrix_block(bb, b_descr, bn, core_cfg, "b"); + + const int nbb = 4; + for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { + for (int nb = 0; nb < n_bs; nb += nbb) { + for (int k = 0; k < Kb; ++k) { + for (int n = 0; n < nbb; ++n) { + if (nb + n < n_bs) { + auto const n_block = (nb + n) / core_cfg.subgroup_size; + auto const n_offset = (nb + n) % core_cfg.subgroup_size; + /*auto my_a = a->get(m_block, k); + auto my_b = + sub_group_broadcast(b->get(n_block, k), n_offset); + auto my_c = c_block.get(m_block, nb + n); + if (gemm_cfg.ty.A == gemm_cfg.ty.B && + gemm_cfg.ty.B == gemm_cfg.ty.C) { + bb.assign(my_c, fma(std::move(my_a), std::move(my_b), + my_c)); + } else { + bb.add(add_into(std::move(my_c), + std::move(my_a) * std::move(my_b))); + }*/ + auto my_a = a->get(m_block, k); + auto my_b = b->get(n_block, k); + auto c_re = c_block.get(m_block, nb + n); + auto c_im = c_block_im.get(m_block, nb + n); + multiply_update(bb, std::move(my_a), std::move(my_b), + n_offset, std::move(c_re), std::move(c_im)); + } + } + } + } + } + }) + .attribute(opencl_unroll_hint(1)) + .get_product()); }; dispatch_constant_dynamic( MNK[2], @@ -242,10 +288,32 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, auto Cb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.C, Cspace)), "Cb", C + C_offset); auto const c_descr = matrix_block_description{gemm_cfg.ty.C, Cspace, Mb, 1, Cb, M, C_stride}; auto n = var("n"); + c_block.offset(n); + c_block_im.offset(n); bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < N, ++n) .body([&](block_builder &bb) { - c_block.offset(n); - write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, alpha, beta, core_cfg); + if (use_double_buffering()) { + for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { + auto c_im = c_block_im.get(m_block, 0); + auto c_ty = to_clir_ty(gemm_cfg.ty.C); + bb.add(add_into(c_block.get(m_block, 0), + init_vector(c_ty, {-c_im.s(1), c_im.s(0)}))); + } + } + for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { + if (is_complex_type(gemm_cfg.ty.alpha)) { + auto c = c_block.get(m_block, 0); + auto myalpha = gemm_cfg.alpha + ? bb.declare_assign(to_clir_ty(gemm_cfg.ty.alpha), + "alpha", alpha) + : alpha; + bb.assign(c, complex_mul(gemm_cfg.ty.C, myalpha, c)); + + } else { + bb.add(multiply_into(c_block.get(m_block, 0), alpha)); + } + } + write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, beta, core_cfg); }) .get_product()); } @@ -266,10 +334,17 @@ void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C void generator::add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta) { m = bb.declare_assign(generic_uint(), "m", get_sub_group_local_id()); - c = var("c"); - - auto [max_row_blocks, max_cols] = max_register_block_gemm( - size(gemm_cfg.ty.C), core_cfg.subgroup_size, core_cfg.register_space); + c_acc = var("c"); + c_acc_im = var("c_im"); + + auto register_space = core_cfg.register_space; + if (use_double_buffering()) { + // We buffer the real / imag part separately, so we only have half the register space + // available for one of the buffers + register_space /= 2; + } + auto [max_row_blocks, max_cols] = + max_register_block_gemm(size(gemm_cfg.ty.C), core_cfg.subgroup_size, register_space); row_blocks_in_register = max_row_blocks; cols_in_register = max_cols; if (!is_dynamic_value(gemm_cfg.M)) { @@ -290,7 +365,12 @@ void generator::add_function_body(block_builder &bb, var A, var B, var C, expr a cols_in_register = tile_loop_uniformly_max_block_size(gemm_cfg.N, cols_in_register, tiling.n_tiles()); } - bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), c); + bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), + c_acc); + if (use_double_buffering()) { + bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), + c_acc_im); + } auto sg_n = bb.declare_assign(generic_uint(), "sg_n", get_sub_group_id() / tiling.m_tiles()); tile_loop_uniformly(bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), diff --git a/src/parser/lexer.re b/src/parser/lexer.re index e76b2b87..b59772ca 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -40,7 +40,7 @@ lex: global_identifier = "@" identifier; integer_type = "i" ("1" | "8" | "16" | "32" | "64") | "index"; - floating_type = "f" ("32" | "64"); + floating_type = ("f" | "c") ("32" | "64"); digit = [0-9]; hexdigit = [0-9a-fA-F]; @@ -260,6 +260,8 @@ scalar_type lexer::lex_floating_type(char const *s, char const *) { "f32" { return scalar_type::f32; } "f64" { return scalar_type::f64; } + "c32" { return scalar_type::c32; } + "c64" { return scalar_type::c64; } $ { return {}; } * { return {}; } */ diff --git a/src/recipe.cpp b/src/recipe.cpp index 75c9f18b..9e547295 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -32,6 +32,10 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va return is_argument_zero(arg_size, arg_value); case scalar_type::f64: return is_argument_zero(arg_size, arg_value); + case scalar_type::c32: + //return is_argument_zero(arg_size, arg_value); + case scalar_type::c64: + //return is_argument_zero(arg_size, arg_value); case scalar_type::i1: break; }; diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index fb5cfb12..e55da42a 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -21,6 +21,29 @@ bool is_floating_type(scalar_type ty) { return false; } +bool is_complex_type(scalar_type ty) { + switch (ty) { + case scalar_type::c32: + case scalar_type::c64: + return true; + default: + break; + } + return false; +} + +scalar_type element_type(scalar_type ty) { + switch (ty) { + case scalar_type::c32: + return scalar_type::f32; + case scalar_type::c64: + return scalar_type::f64; + default: + break; + } + return ty; +} + clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { return to_clir_ty(ty, 1, as, q); } @@ -42,12 +65,32 @@ clir::data_type to_clir_ty(scalar_type ty, short size, clir::address_space as, case scalar_type::index: return clir::builtin_type::long_t; case scalar_type::f32: + case scalar_type::c32: return clir::builtin_type::float_t; case scalar_type::f64: + case scalar_type::c64: return clir::builtin_type::double_t; } return clir::builtin_type::void_t; }; + const auto components = [](scalar_type ty) -> short { + switch (ty) { + case scalar_type::i1: + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + case scalar_type::f32: + case scalar_type::f64: + return 1; + case scalar_type::c32: + case scalar_type::c64: + return 2; + } + return 0; + }; + size *= components(ty); if (size == 1) { return clir::data_type(base_type(ty), as, q); } @@ -95,6 +138,10 @@ char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { return "f32"; case tinytc_scalar_type_f64: return "f64"; + case tinytc_scalar_type_c32: + return "c32"; + case tinytc_scalar_type_c64: + return "c64"; } return "unknown"; } @@ -111,7 +158,10 @@ size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { case tinytc_scalar_type_i64: case tinytc_scalar_type_index: case tinytc_scalar_type_f64: + case tinytc_scalar_type_c32: return 8; + case tinytc_scalar_type_c64: + return 16; } return 0; } diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index be685e1e..d813973e 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -12,6 +12,8 @@ namespace tinytc { bool is_floating_type(scalar_type ty); +bool is_complex_type(scalar_type ty); +scalar_type element_type(scalar_type ty); clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); clir::data_type to_clir_ty(scalar_type ty, short size, From d629065fa2e6b3e110f3977b15e16884cd975652 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 5 Jul 2024 14:05:36 +0200 Subject: [PATCH 006/297] Add test for complex gemm Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 7 +++++++ src/codegen_tools.hpp | 1 + src/gemm_generator.cpp | 13 ++----------- src/recipe.cpp | 4 ++-- test/smm.hpp | 14 ++++++++++++-- test/tensor3.hpp | 5 +++-- test/ze/smm.cpp | 22 ++++++++++++++++++++++ 7 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index e46b2d57..60dc33a1 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -30,6 +30,13 @@ expr constant(scalar_type ty, double value) { } return expr(value, bits(ty)); } +expr multiply(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { + if (is_complex_type(ty_a) && is_complex_type(ty_b)) { + return a * b.s(0) + init_vector(to_clir_ty(ty_a), {-a.s(1), a.s(0)}) * b.s(1); + } + return a * b; +} + expr complex_mul(scalar_type ty, expr a, expr b) { return a * b.s(0) + init_vector(to_clir_ty(ty), {-a.s(1), a.s(0)}) * b.s(1); } diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index d8c5b59d..b4b8f9f2 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -23,6 +23,7 @@ namespace tinytc { short bits(scalar_type ty); clir::expr constant(scalar_type ty, std::int64_t value); clir::expr constant(scalar_type ty, double value); +clir::expr multiply(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); clir::expr complex_mul(scalar_type ty, clir::expr a, clir::expr b); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index 30bff390..a8ff249d 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -301,17 +301,8 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, } } for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - if (is_complex_type(gemm_cfg.ty.alpha)) { - auto c = c_block.get(m_block, 0); - auto myalpha = gemm_cfg.alpha - ? bb.declare_assign(to_clir_ty(gemm_cfg.ty.alpha), - "alpha", alpha) - : alpha; - bb.assign(c, complex_mul(gemm_cfg.ty.C, myalpha, c)); - - } else { - bb.add(multiply_into(c_block.get(m_block, 0), alpha)); - } + auto c = c_block.get(m_block, 0); + bb.assign(c, multiply(gemm_cfg.ty.alpha, gemm_cfg.ty.C, alpha, c)); } write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, beta, core_cfg); }) diff --git a/src/recipe.cpp b/src/recipe.cpp index 9e547295..ef98462c 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -33,9 +33,9 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va case scalar_type::f64: return is_argument_zero(arg_size, arg_value); case scalar_type::c32: - //return is_argument_zero(arg_size, arg_value); + return is_argument_zero>(arg_size, arg_value); case scalar_type::c64: - //return is_argument_zero(arg_size, arg_value); + return is_argument_zero>(arg_size, arg_value); case scalar_type::i1: break; }; diff --git a/test/smm.hpp b/test/smm.hpp index 4d100c1a..717b4a5f 100644 --- a/test/smm.hpp +++ b/test/smm.hpp @@ -40,6 +40,10 @@ #define TEST_PRECISIONS float, double +template struct is_complex : public std::false_type {}; +template struct is_complex> : public std::true_type {}; +template inline constexpr bool is_complex_v = is_complex::value; + template void small_gemm_batched_ref(tinytc::transpose transA, tinytc::transpose transB, T alpha, tensor3 const &A, tensor3 const &B, T beta, tensor3 &C) { @@ -111,7 +115,7 @@ void check_small_gemm_batched(tinytc::transpose transA, tinytc::transpose transB }; auto gpu_rt = std::make_shared(); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v>) { if (!gpu_rt->supports_fp64()) { WARN_MESSAGE(false, "Double precision tests need double precision device support"); return; @@ -122,7 +126,13 @@ void check_small_gemm_batched(tinytc::transpose transA, tinytc::transpose transB T *data = x.data(); std::size_t n = x.size(); for (std::size_t i = 0; i < n; ++i) { - data[i] = static_cast(i % 101); + constexpr std::size_t prime = 101; + if constexpr (is_complex_v) { + data[i] = T{static_cast((2 * i) % prime), + static_cast((2 * i + 1) % prime)}; + } else { + data[i] = static_cast(i % prime); + } } }; diff --git a/test/tensor3.hpp b/test/tensor3.hpp index 29a09936..64e12331 100644 --- a/test/tensor3.hpp +++ b/test/tensor3.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -51,8 +52,8 @@ template bool compare(tensor3 const &A, tensor3 const &B) { for (std::uint32_t k = 0; k < A.shape(2); ++k) { for (std::uint32_t j = 0; j < A.shape(1); ++j) { for (std::uint32_t i = 0; i < A.shape(0); ++i) { - constexpr auto eps = 10.0 * std::numeric_limits::epsilon(); - REQUIRE(A(i, j, k) == doctest::Approx(B(i, j, k)).epsilon(eps)); + constexpr auto eps = 10.0 * std::numeric_limits::epsilon(); + REQUIRE(std::abs(A(i, j, k) - B(i, j, k)) == doctest::Approx(0.0).epsilon(eps)); } } } diff --git a/test/ze/smm.cpp b/test/ze/smm.cpp index ebdf2922..0a945ebc 100644 --- a/test/ze/smm.cpp +++ b/test/ze/smm.cpp @@ -7,6 +7,7 @@ #include "doctest/doctest.h" #include +#include using namespace tinytc; @@ -22,3 +23,24 @@ TEST_CASE_TEMPLATE("level zero packed alpha=1 beta=0", T, TEST_PRECISIONS) { check_small_gemm_batched( transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, 1.0, 0.0, howmany); } + +TEST_CASE_TEMPLATE("level zero packed complex alpha=1 beta=0", T, TEST_PRECISIONS) { + auto KK = std::vector{53}; + auto MM = std::vector{21, 42}; + auto NN = std::vector{7, 11}; + auto HH = std::vector{1, 101}; + + std::uint32_t M, N, K, howmany; + DOCTEST_TENSOR4_TEST(MM, NN, KK, HH); + + check_small_gemm_batched, level_zero_test_runtime>( + transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, 1.0, 0.0, howmany); +} + +TEST_CASE_TEMPLATE("level zero packed complex alpha=(-1,-2) beta=(2,3)", T, TEST_PRECISIONS) { + std::uint32_t M = 8, N = 16, K = 16, howmany = 5; + + check_small_gemm_batched, level_zero_test_runtime>( + transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, {-1.0, -2.0}, {2.0, 3.0}, + howmany); +} From 1d69dd2f3564691fcb70ab9b85d167d1bb30d724 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 22 Jul 2024 12:48:59 +0200 Subject: [PATCH 007/297] Update complex number support Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 25 ++++-------- src/codegen_tools.hpp | 10 ++--- src/gemm_generator.cpp | 3 +- src/visitor/opencl_ast.cpp | 80 +++++++++++++++++++++++++------------- test/codegen/atomic.ir | 6 ++- 5 files changed, 71 insertions(+), 53 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 60dc33a1..cdad4a86 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -37,10 +37,6 @@ expr multiply(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { return a * b; } -expr complex_mul(scalar_type ty, expr a, expr b) { - return a * b.s(0) + init_vector(to_clir_ty(ty), {-a.s(1), a.s(0)}) * b.s(1); -} - expr vload_helper(short vec_size, expr offset, expr ptr) { switch (vec_size) { case 1: @@ -111,24 +107,17 @@ expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, addre } void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, - expr value, expr beta) { + expr value, scalar_type beta_ty, expr beta) { if (is_atomic) { - atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), std::move(beta)); + atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), beta_ty, std::move(beta)); } else { - auto c_scaled = clir::expr{nullptr}; - if (is_complex_type(ty)) { - c_scaled = bb.declare_assign(to_clir_ty(ty), "c_scaled", dereference(dst)); - auto beta1 = bb.declare_assign(to_clir_ty(ty), "beta", beta); - bb.assign(c_scaled, complex_mul(ty, beta1, c_scaled)); - } else { - c_scaled = beta * dereference(dst); - } + const auto c_scaled = multiply(ty, beta_ty, dereference(dst), beta); bb.assign(dereference(dst), std::move(value) + std::move(c_scaled)); } } void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, address_space as, expr value, - expr beta) { + scalar_type beta_ty, expr beta) { int mode = -1; visit(overloaded{ [&](clir::internal::int_imm &c) { @@ -414,8 +403,8 @@ auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int } void write_matrix_block(block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, expr beta, - core_config const &core_cfg) { + matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, + expr beta, core_config const &core_cfg) { const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); @@ -425,7 +414,7 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, store_helper(bb, is_atomic, d.pointer + d.stride[0] * (get_sub_group_local_id() + m_block * core_cfg.subgroup_size), - d.ty, d.as, block.get(m_block, k), beta); + d.ty, d.as, block.get(m_block, k), beta_ty, beta); }; if (m_block >= first_m_block_with_check) { bb.add(if_selection_builder(d.condition(m_block, core_cfg.subgroup_size)) diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index b4b8f9f2..f27ee98b 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -24,16 +24,16 @@ short bits(scalar_type ty); clir::expr constant(scalar_type ty, std::int64_t value); clir::expr constant(scalar_type ty, double value); clir::expr multiply(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); -clir::expr complex_mul(scalar_type ty, clir::expr a, clir::expr b); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, scalar_type ty, clir::address_space as); void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); + clir::address_space as, clir::expr value, scalar_type beta_ty, clir::expr beta); void atomic_store_helper(clir::block_builder &bb, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); + clir::address_space as, clir::expr value, scalar_type beta_ty, + clir::expr beta); void dispatch_constant_dynamic(clir::expr e, std::function const &const_case, std::function const &dyn_case); @@ -118,8 +118,8 @@ auto read_matrix_block(clir::block_builder &bb, matrix_block_description const & // Write MbxKb block void write_matrix_block(clir::block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, clir::expr beta, - core_config const &core_cfg); + matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, + clir::expr beta, core_config const &core_cfg); } // namespace tinytc diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index a8ff249d..aa25060b 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -304,7 +304,8 @@ void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, auto c = c_block.get(m_block, 0); bb.assign(c, multiply(gemm_cfg.ty.alpha, gemm_cfg.ty.C, alpha, c)); } - write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, beta, core_cfg); + write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, gemm_cfg.ty.beta, beta, + core_cfg); }) .get_product()); } diff --git a/src/visitor/opencl_ast.cpp b/src/visitor/opencl_ast.cpp index 125caa88..9d7b241a 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/visitor/opencl_ast.cpp @@ -232,6 +232,8 @@ std::vector opencl_ast::operator()(alloca_inst const &a) { std::vector opencl_ast::operator()(axpby_inst const &inst) { auto at = get_memref_type(*inst.A()); auto bt = get_memref_type(*inst.B()); + auto alpha_ty = get_scalar_type(*inst.alpha()->ty()); + auto beta_ty = get_scalar_type(*inst.beta()->ty()); auto &adv = get_dope_vector(inst.A().get()); auto &bdv = get_dope_vector(inst.B().get()); @@ -249,8 +251,9 @@ std::vector opencl_ast::operator()(axpby_inst const &inst) { auto const inner_loop = [&](clir::block_builder &bb) { auto a = Ab[(block + m) * adv.stride(pA)]; auto b = bb.declare_assign((*this)(*bt), "b", Bb + (block + m) * bdv.stride(0)); + const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, std::move(a)); store_helper(bb, inst.atomic(), b, bt->element_ty(), bt->addrspace(), - alpha * std::move(a), beta); + std::move(a_scaled), beta_ty, beta); }; if (is_remainder) { bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) @@ -266,8 +269,9 @@ std::vector opencl_ast::operator()(axpby_inst const &inst) { auto B = visit(*this, *inst.B()); if (bt->dim() == 0) { auto bb = clir::block_builder{}; - store_helper(bb, inst.atomic(), B, bt->element_ty(), bt->addrspace(), - std::move(alpha) * A[0], std::move(beta)); + const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, A[0]); + store_helper(bb, inst.atomic(), B, bt->element_ty(), bt->addrspace(), std::move(a_scaled), + beta_ty, std::move(beta)); return {bb.get_product()}; } @@ -317,7 +321,7 @@ std::vector opencl_ast::operator()(arith_inst const &a) { case arithmetic::sub: return std::move(a) - std::move(b); case arithmetic::mul: - return std::move(a) * std::move(b); + return multiply(sty, sty, std::move(a), std::move(b)); case arithmetic::div: return std::move(a) / std::move(b); case arithmetic::rem: @@ -686,6 +690,7 @@ std::vector opencl_ast::operator()(gemv_inst const &g) { } std::vector opencl_ast::operator()(ger_inst const &g) { + auto at = get_memref_type(*g.A()); auto bt = get_memref_type(*g.B()); auto ct = get_memref_type(*g.C()); auto &adv = get_dope_vector(g.A().get()); @@ -694,6 +699,8 @@ std::vector opencl_ast::operator()(ger_inst const &g) { auto alpha = visit(*this, *g.alpha()); auto beta = visit(*this, *g.beta()); + auto alpha_ty = get_scalar_type(*g.alpha()->ty()); + auto beta_ty = get_scalar_type(*g.beta()->ty()); auto A = visit(*this, *g.A()); auto B = visit(*this, *g.B()); @@ -725,8 +732,14 @@ std::vector opencl_ast::operator()(ger_inst const &g) { auto a = A[(block + m) * adv.stride(0)]; auto c = bb.declare_assign((*this)(*ct), "c", Cb + (block + m) * cdv.stride(0)); + auto ab = bb.declare_assign( + to_clir_ty(ct->element_ty()), "ab", + multiply(at->element_ty(), bt->element_ty(), + std::move(a), b)); + const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), + alpha, std::move(ab)); store_helper(bb, g.atomic(), c, ct->element_ty(), - ct->addrspace(), alpha * std::move(a) * b, + ct->addrspace(), std::move(ab_scaled), beta_ty, beta); }; if (is_remainder) { @@ -778,6 +791,8 @@ std::vector opencl_ast::operator()(foreach_inst const &p) { } std::vector opencl_ast::operator()(hadamard_inst const &g) { + auto at = get_memref_type(*g.A()); + auto bt = get_memref_type(*g.B()); auto ct = get_memref_type(*g.C()); auto &adv = get_dope_vector(g.A().get()); auto &bdv = get_dope_vector(g.B().get()); @@ -785,6 +800,8 @@ std::vector opencl_ast::operator()(hadamard_inst const &g) { auto alpha = visit(*this, *g.alpha()); auto beta = visit(*this, *g.beta()); + auto alpha_ty = get_scalar_type(*g.alpha()->ty()); + auto beta_ty = get_scalar_type(*g.beta()->ty()); auto A = visit(*this, *g.A()); auto B = visit(*this, *g.B()); @@ -793,26 +810,31 @@ std::vector opencl_ast::operator()(hadamard_inst const &g) { auto bb = clir::block_builder{}; auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - tile_loop_by_sgs(bb, cdv.shape(0), core_cfg_.subgroup_size, - tiling_.m_tiles() * tiling_.n_tiles(), std::move(sg), - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto b = B[(block + m) * bdv.stride(0)]; - auto a = A[(block + m) * adv.stride(0)]; - auto c = bb.declare_assign((*this)(*ct), "c", - C + (block + m) * cdv.stride(0)); - store_helper(bb, g.atomic(), c, ct->element_ty(), ct->addrspace(), - alpha * std::move(a) * std::move(b), beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); + tile_loop_by_sgs( + bb, cdv.shape(0), core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), + std::move(sg), + [&](clir::block_builder &bb, clir::expr block, bool is_remainder, + clir::expr inner_trip_count) { + auto const inner_loop = [&](clir::block_builder &bb) { + auto b = B[(block + m) * bdv.stride(0)]; + auto a = A[(block + m) * adv.stride(0)]; + + auto c = bb.declare_assign((*this)(*ct), "c", C + (block + m) * cdv.stride(0)); + auto ab = bb.declare_assign( + to_clir_ty(ct->element_ty()), "ab", + multiply(at->element_ty(), bt->element_ty(), std::move(a), b)); + const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), alpha, std::move(ab)); + store_helper(bb, g.atomic(), c, ct->element_ty(), ct->addrspace(), + std::move(ab_scaled), beta_ty, beta); + }; + if (is_remainder) { + bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) + .then(inner_loop) + .get_product()); + } else { + inner_loop(bb); + } + }); return {bb.get_product()}; } @@ -919,6 +941,8 @@ std::vector opencl_ast::operator()(sum_inst const &inst) { auto alpha = visit(*this, *inst.alpha()); auto beta = visit(*this, *inst.beta()); + auto alpha_ty = get_scalar_type(*inst.alpha()->ty()); + auto beta_ty = get_scalar_type(*inst.beta()->ty()); auto zero = clir::expr(0.0, static_cast(size(at->element_ty()) * 8)); @@ -950,8 +974,9 @@ std::vector opencl_ast::operator()(sum_inst const &inst) { bb.add(clir::if_selection_builder(clir::get_sub_group_id() == 0 && clir::get_sub_group_local_id() == 0) .then([&](clir::block_builder &bb) { + const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, sum); store_helper(bb, inst.atomic(), B, bt->element_ty(), bt->addrspace(), - alpha * sum, beta); + std::move(sum_scaled), beta_ty, beta); }) .get_product()); } else if (bt->dim() == 1) { @@ -973,8 +998,9 @@ std::vector opencl_ast::operator()(sum_inst const &inst) { }) .get_product()); auto b = bb.declare_assign((*this)(*bt), "b", B + (block + m) * bdv.stride(0)); + const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, acc); store_helper(bb, inst.atomic(), b, bt->element_ty(), bt->addrspace(), - alpha * acc, beta); + std::move(sum_scaled), beta_ty, beta); }; if (is_remainder) { bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index 2cddc42c..3d7ebfe4 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -33,14 +33,16 @@ func @ger_atomic(%A: memref, %B: memref, %C: memref) { ger.atomic 1.0, %A, %B, 1.0, %C : f32, memref, memref, f32, memref ; CHECK: global float* c = Cb + (blck1 + m) * 1; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * A[(blck1 + m) * 1] * b, memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: float ab = A[(blck1 + m) * 1] * b; +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * ab, memory_order_relaxed, memory_scope_work_group); } func @hadamard_atomic(%A: memref, %B: memref, %C: memref) { hadamard.atomic 1.0, %A, %B, 1.0, %C : f32, memref, memref, f32, memref ; CHECK: global float* c = C + (blck + m) * 1; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * A[(blck + m) * 1] * B[(blck + m) * 1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: float ab = A[(blck + m) * 1] * B[(blck + m) * 1]; +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * ab, memory_order_relaxed, memory_scope_work_group); } func @sum_atomic(%A: memref, %B: memref) { From 469f4cd7b2125a714e8468dda62b5a4fceb4be15 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 1 Aug 2024 15:09:19 +0200 Subject: [PATCH 008/297] Update spec Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 848 +++++++++++++------------- docs/manual/tutorial_matrix_chain.rst | 6 +- 2 files changed, 435 insertions(+), 419 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 3f7bcdeb..ef2547e4 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -17,21 +17,19 @@ is called a **kernel**. Kernels are launched in batches, where each instance of the kernel is called a work-group. The kernel has access to its group id that is used to select the work done in the work group. Each work group consists of a fixed number of work-items that execute concurrently. -The language distinguishes between two kinds of instructions: *replicated* and *collective* instructions. -It is distinguished between *mixed* and *spmd* regions. -Mixed regions may contain replicated and collective instructions whereas spmd regions -may only contain replicated instructions. -A collective instruction distributes the work among the work-items. -The instruction is responsible to distribute the work in a sensible manner. +The language distinguishes between *collective*, *SPMD*, and *mixed* instructions. +A collective instruction distributes the work among the work-items in an implementation-defined manner. +Local variables passed to or returned from a collective instruction are always uniform, meaning +that each work-item holds the same value. +An SPMD instruction follows the OpenCL execution model, where local variables may have a different value +for each work-item. +Mixed instructions accept both varying and uniform local variables. -A replicated instruction replicates the work across all work-items. -In a mixed region, the replicated instructions always operate on the same data. -In spmd regions, the replicated instructions can operate on multiple data, -but in these regions collective instructions are prohibited. - -Mixed regions can be nested whereas spmd regions must not be nested. -A mixed region may be nested in a spmd region. +Regions come in two different kinds: collective and SPMD. +A collective instructions must only appear in a collective region, and an SPMD instruction +must only appear in a in a SPMD region. Mixed instructions might appear in both kinds of regions. +SPMD regions may be nested in collective regions but collective regions must not be nested in SPMD regions. Core rules ========== @@ -244,81 +242,303 @@ Instructions .. code:: abnf - value-instruction = local-identifier "=" (alloca-instruction - / arith-binary-instruction - / arith-unary-instruction - / cast-instruction - / comparison-instruction - / expand-instruction - / fuse-instruction - / group-id-instruction - / group-size-instruction - / load-instruction - / size-instruction - / subview-instruction) - multi-value-instruction = [local-identifier-list "="] if-instruction - local-identifier-list = local-identifier *("," local-identifier) - instruction = value-instruction - / multi-value-instruction - / axpby-instruction - / barrier-instruction - / for-instruction - / foreach-instruction - / lifetime-stop-instruction - / gemm-instruction - / gemv-instruction - / ger-instruction - / hadamard-product-instruction - / store-instruction - / sum-instruction - / yield-instruction + value-instruction-assignment = local-identifier "=" value-instruction + multi-value-instruction-assignment = [local-identifier-list "="] multi-value-instruction + local-identifier-list = local-identifier *("," local-identifier) + instruction = value-instruction-assignment + / multi-value-instruction-assignment + + +Collective instructions +----------------------- Alloca ------- +...... .. code:: abnf - alloca-instruction = "alloca" "->" memref-type + value-instruction = "alloca" "->" memref-type Overview -........ +~~~~~~~~ -*Collective instruction.* The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. Returns -....... +~~~~~~~ A memref of the memref-type. Restrictions -............ +~~~~~~~~~~~~ The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +Axpby +..... + +.. code:: abnf + + transpose = ".t" / ".n" + const-or-val = floating-constant / local-identifier + instruction =/ "axpby" transpose [".atomic"] + const-or-val "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +Axpby implements + +.. math:: + + B := \alpha \text{op}(A) + \beta B + +for vectors and matrices. +If the atomic flag is set, B is updated atomically. + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha`, and the third argument gives :math:`\beta`. +The second and the fourth argument must have memref type and give A and B, respectively. + +The transpose modifier defines :math:`\text{op}` as following: + +.. math:: + + \text{op}_i(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{modifier}_i= t \wedge \text{order}(X) = 2,\\ + X & \text{ else. } + \end{array} + \right. + +(Note that ".t" has no effect on vectors.) + +The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) +or 2 (matrix). + +Foreach +....... + +.. code:: abnf + + instruction =/ "foreach" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant + [":" integer-type] region + +Overview +~~~~~~~~ + +A foreach loop that executes the loop's range [from; to) without any sequence guarantee. +The region of a foreach is a *spmd region*. + +The loop's range [from; to) is given by the first integer constant and second integer constant, +and the trip count is stored in the local identifier. +The integer type of the loop variable is given after the colon. +The integer type of the loop variable and the loop bounds is given after the colon. +The default integer type is ``index``. + +GEMM +.... + +.. code:: abnf + + instruction =/ "gemm" transpose transpose [".atomic"] + "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +GEMM implements the well-known GEMM BLAS-3 operation. + +.. math:: + + C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C + +If the atomic flag is set, C is updated atomically. + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. +The second, the third, and the fifth argument must have memref type and give +A, B, and C, respectively. + +The first transpose modifier defines :math:`\text{op}_1` and the second transpose modifier +defines :math:`\text{op}_2` as following: + +.. math:: + + \text{op}_i(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{modifier}_i = t,\\ + X & \text{ if } & \text{modifier}_i = n. + \end{array} + \right. + + +If :math:`\text{op}_1(A)` has the shape MxK and +:math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. + +GEMV +.... + +.. code:: abnf + + instruction =/ "gemv" transpose [".atomic"] + "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +GEMV implements the well-known GEMM BLAS-2 operation. + +.. math:: + + c := \alpha \text{op}_1(A) b + \beta C + +If the atomic flag is set, c is updated atomically. + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. +The second, the third, and the fifth argument must have memref type and give +A, b, and c, respectively. + +The transpose modifier for A as in GEMM. + +:math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. + +GER +... + +.. code:: abnf + + instruction =/ "ger" [".atomic"] + const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +Computes the general rank-1 update: + +.. math:: + + C := \alpha a b^T + \beta C + +If the atomic flag is set, C is updated atomically. + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. +The second, the third, and the fifth argument must have memref type and give +a, b, and C, respectively. + +a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. + + +Hadamard product +................ + +.. code:: abnf + + instruction =/ "hadamard_product" [".atomic"] + const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +Computes the Hadamard product of two tensors. +That is, in index notation we have + +.. math:: + + c_{i} := \alpha a_{i} b_{i} + \beta c_{i} + +If the atomic flag is set, c is updated atomically. + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. +The second, the third, and the fifth argument must have memref type and give +a, b, and c, respectively. + +a, b, and c must be vectors and have equal shape. + +Sum +... + +.. code:: abnf + + instruction =/ "sum" transpose [".atomic"] + "," const-or-val "," local-identifier "," const-or-val "," local-identifier + ":" scalar-type "," memref-type "," scalar-type "," memref-type + +Overview +~~~~~~~~ + +Computes the matrix-vector product or the dot product of A with a vector of ones. +That is, for matrices we have + +.. math:: + + B := \alpha \text{op}(A) \vec{1} + \beta B + +and for vectors we have + +.. math:: + + b := \alpha \left + \beta b + +If the atomic flag is set, B is updated atomically. + + +Arguments +~~~~~~~~~ + +The first argument gives :math:`\alpha` and the third argument gives :math:`\beta`. +The second and the fourth argument must have memref type and give A and B, respectively. +If A is a matrix then B must be a vector. +The first mode size of :math:`\text{op}(A)` must match the size of B. +If A is a vector, then B must be a scalar memref. + +The transpose op is defined as in the axpby instruction. + + + +Mixed instructions +------------------ + Arithmetic (binary) -------------------- +................... .. code:: abnf - identifier-or-constant = local-identifier / integer-constant / floating-constant - arith-binary-type = ".add" / - ".sub" / - ".mul" / - ".div" / - ".rem" / - ".shl" / - ".shr" / - ".and" / - ".or" / - ".xor" - arith-binary-instruction = "arith" arith-binary-type - identifier-or-constant "," identifier-or-constant ":" scalar-type + identifier-or-constant = local-identifier / integer-constant / floating-constant + arith-binary-type = ".add" / + ".sub" / + ".mul" / + ".div" / + ".rem" / + ".shl" / + ".shr" / + ".and" / + ".or" / + ".xor" + value-instruction =/ "arith" arith-binary-type + identifier-or-constant "," identifier-or-constant ":" scalar-type Overview -........ +~~~~~~~~ -*Replicated instruction.* Binary arithmetic operation on scalars. Both operands, as well as the returned type, have the same scalar type. @@ -338,17 +558,16 @@ Op Allowed type Description ==== ============ ============================================================================== Arithmetic (unary) ------------------- +.................. .. code:: abnf - arith-unary-type = ".neg" / ".not" - arith-unary-instruction = "arith" arith-unary-type identifier-or-constant ":" scalar-type + arith-unary-type = ".neg" / ".not" + value-instruction =/ "arith" arith-unary-type identifier-or-constant ":" scalar-type Overview -........ +~~~~~~~~ -*Replicated instruction.* Unary arithmetic operation on scalars. The returned value has the same type as the operand. @@ -360,30 +579,28 @@ Op Allowed type Description ==== ============ ============================================================================== Cast ----- +.... .. code:: abnf - cast-instruction = "cast" identifier-or-constant ":" scalar-type "->" scalar-type + value-instruction =/ "cast" identifier-or-constant ":" scalar-type "->" scalar-type Overview -........ +~~~~~~~~ -*Replicated instruction.* Cast scalar values. Comparison ----------- +.......... .. code:: abnf - comparison-instruction = "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") - identifier-or-constant "," identifier-or-constant ":" scalar-type + value-instruction =/ "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") + identifier-or-constant "," identifier-or-constant ":" scalar-type Overview -........ +~~~~~~~~ -*Replicated instruction.* Scalar comparison. Both operands must have the same scalar type and the returned value is boolean. @@ -399,22 +616,21 @@ Cond Description ==== ===================== Expand ------- +...... .. code:: abnf - expand-instruction = "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type - expand-shape = constant-or-dynamic-or-identifier 1*("x" constant-or-dynamic-or-identifier) + value-instruction =/ "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type + expand-shape = constant-or-dynamic-or-identifier 1*("x" constant-or-dynamic-or-identifier) constant-or-dynamic-or-identifier = integer-constant / "?" / local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. Arguments -......... +~~~~~~~~~ The first argument must point to a value of memref type. The integer constant in square brackets gives the mode that shall be expanded. @@ -454,7 +670,7 @@ The output type is a memref type according to the following rules: expand %0[0->4x?] : memref> ; -> memref> Restrictions -............ +~~~~~~~~~~~~ At most one mode in expand-shape must be dynamic. @@ -462,20 +678,19 @@ The product of the expand shape must be the same as the mode size. If one entry in the expand shape is dynamic then the other must evenly divide the mode size. Fuse ----- +.... .. code:: abnf - fuse-instruction = "fuse" local-identifier "[" integer-constant "," integer-constant "]" ":" memref-type + value-instruction =& "fuse" local-identifier "[" integer-constant "," integer-constant "]" ":" memref-type Overview -........ +~~~~~~~~ -*Replicated instruction.* The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. Arguments -......... +~~~~~~~~~ The first argument must point to a value of memref type. The fused modes are specified as the interval [from, to], where from is given @@ -504,7 +719,7 @@ The output type is a memref type according to the following rules: fuse %0[0,1] : memref> ; -> memref> Restrictions -............ +~~~~~~~~~~~~ Let i be the first mode and j the last mode. The stride vector S and the shape vector s must satisify the following compatibility condition: @@ -523,57 +738,91 @@ is undefined beheaviour. Group id --------- +........ .. code:: abnf - group-id-instruction = "group_id" + value-instruction =/ "group_id" Overview -........ +~~~~~~~~ -*Replicated instruction.* Returns the group id, an integer of type "index" inbetween 0 and the group size - 1. Group size ----------- +.......... .. code:: abnf - group-size-instruction = "group_size" + value-instruction =/ "group_size" Overview -........ +~~~~~~~~ -*Replicated instruction.* Returns the group size, an integer of type "index". +If +.. + +.. code:: abnf + + multi-value-instruction = "if" identifier-or-int-constant ["->" "(" scalar-type-list ")"] + region ["else" region] + type-list = scalar-type *("," scalar-type) + +Overview +~~~~~~~~ + +An if statement. +Both regions are *mixed regions*. + +The condition must be of bool type. + +Arguments +~~~~~~~~~ + +The if instruction may return multiple values, where the number of values and the value types +are given by the scalar-type-list. +If values are returned, the last instruction in both the "then"-region and the "else"-region must +be a yield instruction (the "else"-region cannot be omitted). + +Example: + + .. code:: + + %1 = cmp.lt %0, 16 : i32 + %x = if %1 -> (i32) { + yield %0 : i32 + } else { + yield 16 : i32 + } + Load ----- +.... .. code:: abnf - load-instruction = "load" local-identifier "[" [index-list] "]" ":" memref-or-group-type - index-list = identifier-or-int-constant *("," identifier-or-int-constant) - identifier-or-int-constant = integer-constant / local-identifier - memref-or-group-type = memref-type / group-type + value-instruction =/ "load" local-identifier "[" [index-list] "]" ":" memref-or-group-type + index-list = identifier-or-int-constant *("," identifier-or-int-constant) + identifier-or-int-constant = integer-constant / local-identifier + memref-or-group-type = memref-type / group-type Overview -........ +~~~~~~~~ Load the element given by the index list from a memref or group. The number of indices must match the order of the memref and a single index must be given for a group. Arguments -......... +~~~~~~~~~ The first operand must have memref or group type. The indices must be of ``index`` type. Returns -....... +~~~~~~~ A value of the memref's element type or the group's memref type. Examples: @@ -583,22 +832,42 @@ Examples: #. ``load %0[%1] : group>`` returns a ``memref`` value. #. ``load %0[%1] : group, offset: ?>`` returns a ``memref`` value. -Size ----- +For +... .. code:: abnf - size-instruction = "size" local-identifier "[" integer-constant "]" ":" memref-type + instruction =/ "for" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant + ["," identifier-or-int-constant] [":" integer-type] region Overview -........ +~~~~~~~~ + +A for loop. +Instructions in the for loop execute sequentially and its region is a *mixed region*. + +The loop's range [from; to) is given by the first integer constant and second integer constant, +and the trip count is stored in the local identifier. +A step size can be given with the third integer constant. +The step size defaults to 1 if omitted. +The integer type of the loop variable and the loop bounds is given after the colon. +The default integer type is ``index``. + +Size +.... + +.. code:: abnf + + value-instruction =/ "size" local-identifier "[" integer-constant "]" ":" memref-type + +Overview +~~~~~~~~ -*Replicated instruction.* The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer constant in square brackets. Arguments -......... +~~~~~~~~~ The first argument must point to a value of memref type. The integer constant i gives the mode for which the size shall be returned. @@ -612,22 +881,21 @@ The local identifier must have the memref type specified last. The instruction returns an integer of index type. Subview -------- +....... .. code:: abnf - subview-instruction = "subview" local-identifier "[" [index-or-slice-list] "]" ":" memref-type - index-or-slice-list = index-or-slice *("," index-or-slice) - index-or-slice = identifier-or-int-constant [":" (identifier-or-int-constant / "?")] / ":" + value-instruction =/ "subview" local-identifier "[" [index-or-slice-list] "]" ":" memref-type + index-or-slice-list = index-or-slice *("," index-or-slice) + index-or-slice = identifier-or-int-constant [":" (identifier-or-int-constant / "?")] / ":" Overview -........ +~~~~~~~~ -*Replicated instruction.* The subview instruction returns a view on a tensor. Arguments -......... +~~~~~~~~~ The first argument must point to a value of memref type. The number of indices in square brackets must match the order of the memref. @@ -687,358 +955,106 @@ The output type is a memref type according to the following rules: subview %0[5:?] : memref ; Returns memref subview %0[%2:?] : memref ; Returns memref -If --- - -.. code:: abnf - - if-instruction = "if" identifier-or-int-constant ["->" "(" scalar-type-list ")"] - region ["else" region] - type-list = scalar-type *("," scalar-type) - -Overview -........ - -An if statement. -Both regions are *mixed regions*. - -The condition must be of bool type. - -Arguments -......... - -The if instruction may return multiple values, where the number of values and the value types -are given by the scalar-type-list. -If values are returned, the last instruction in both the "then"-region and the "else"-region must -be a yield instruction (the "else"-region cannot be omitted). - -Example: - - .. code:: - - %1 = cmp.lt %0, 16 : i32 - %x = if %1 -> (i32) { - yield %0 : i32 - } else { - yield 16 : i32 - } - -Axpby ------ - -.. code:: abnf - - transpose = ".t" / ".n" - const-or-val = floating-constant / local-identifier - axpby-instruction = "axpby" transpose [".atomic"] - const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type - -Overview -........ - -*Collective instruction.* -Axpby implements - -.. math:: - - B := \alpha \text{op}(A) + \beta B - -for vectors and matrices. -If the atomic flag is set, B is updated atomically. - -Arguments -......... - -The first argument gives :math:`\alpha`, and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. - -The transpose modifier defines :math:`\text{op}` as following: - -.. math:: - - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i= t \wedge \text{order}(X) = 2,\\ - X & \text{ else. } - \end{array} - \right. - -(Note that ".t" has no effect on vectors.) - -The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) -or 2 (matrix). - - -For ---- - -.. code:: abnf - - for-instruction = "for" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - ["," identifier-or-int-constant] [":" integer-type] region - -Overview -........ - -A for loop. -Instructions in the for loop execute sequentially and its region is a *mixed region*. - -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -A step size can be given with the third integer constant. -The step size defaults to 1 if omitted. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. - -Foreach -------- - -.. code:: abnf - - foreach-instruction = "foreach" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - [":" integer-type] region - -Overview -........ - -A foreach loop that executes the loop's range [from; to) without any sequence guarantee. -The region of a foreach is a *spmd region*. - -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -The integer type of the loop variable is given after the colon. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. - -GEMM ----- +Store +..... .. code:: abnf - gemm-instruction = "gemm" transpose transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "store" local-identifier "," local-identifier "[" [index-list] "]" ":" memref-type Overview -........ - -*Collective instruction.* -GEMM implements the well-known GEMM BLAS-3 operation. +~~~~~~~~ -.. math:: - - C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C +Store a scalar value in a memref at the position given by the index list. +The number of indices must match the order of the memref. -If the atomic flag is set, C is updated atomically. +*Note:* Store should only be used in SPMD regions as otherwise the same memory location is written +from all work-items. Arguments -......... +~~~~~~~~~ -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, B, and C, respectively. - -The first transpose modifier defines :math:`\text{op}_1` and the second transpose modifier -defines :math:`\text{op}_2` as following: - -.. math:: - - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i = t,\\ - X & \text{ if } & \text{modifier}_i = n. - \end{array} - \right. - - -If :math:`\text{op}_1(A)` has the shape MxK and -:math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. +The first operand must have the same scalar type as the memref type. +The indices must be of ``index`` type. -GEMV ----- +Yield +..... .. code:: abnf - gemv-instruction = "gemm" transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "yield" [local-identifier-list] ":" [scalar-type-list] + identifier-or-constant-list = identifier-or-constant *("," identifier-or-constant) Overview -........ - -*Collective instruction.* -GEMV implements the well-known GEMM BLAS-2 operation. +~~~~~~~~ -.. math:: - - c := \alpha \text{op}_1(A) b + \beta C - -If the atomic flag is set, c is updated atomically. +Yield returns values from an if or for instruction. Arguments -......... +~~~~~~~~~ -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, b, and c, respectively. - -The transpose modifier for A as in GEMM. - -:math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. +The length of the local identifier list must equal the length of the scalar type list. -GER ---- +Additional instructions +....................... .. code:: abnf - ger-instruction = "ger" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type - -Overview -........ - -Computes the general rank-1 update: - -.. math:: - - C := \alpha a b^T + \beta C - -If the atomic flag is set, C is updated atomically. - -Arguments -......... - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and C, respectively. - -a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. + barrier-instruction = "barrier" + lifetime-stop-instruction = "lifetime_stop" local-identifier +SPMD instructions +----------------- -Hadamard product ----------------- +Number of subgroups +................... .. code:: abnf - hadamard-product-instruction = "hadamard_product" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "num_subgroups" Overview -........ - -*Collective instruction.* -Computes the Hadamard product of two tensors. -That is, in index notation we have - -.. math:: - - c_{i} := \alpha a_{i} b_{i} + \beta c_{i} - -If the atomic flag is set, c is updated atomically. - -Arguments -......... - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and c, respectively. +~~~~~~~~ -a, b, and c must be vectors and have equal shape. +Returns the number of subgroups the work-group is divided in; i32 integer. - -Store ------ +Subgroup id +........... .. code:: abnf - store-instruction = "store" local-identifier "," local-identifier "[" [index-list] "]" ":" memref-type + value-instruction =/ "subgroup_id" Overview -........ +~~~~~~~~ -*Replicated instruction.* -Store a scalar value in a memref at the position given by the index list. -The number of indices must match the order of the memref. +Returns the subgroup id; i32 integer from 0 to num_subgroups - 1. -*Note:* Store should only be used in SPMD regions as otherwise the same memory location is written -from all work-items. - -Arguments -......... - -The first operand must have the same scalar type as the memref type. -The indices must be of ``index`` type. - -Sum ---- +Subgroup local id +................. .. code:: abnf - sum-instruction = "sum" transpose [".atomic"] - "," const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "subgroup_local_id" Overview -........ - -*Collective instruction.* -Computes the matrix-vector product or the dot product of A with a vector of ones. -That is, for matrices we have +~~~~~~~~ -.. math:: - - B := \alpha \text{op}(A) \vec{1} + \beta B +Returns the work-item id within the sub-group; i32 integer from 0 to subgroup_size - 1. -and for vectors we have - -.. math:: - - b := \alpha \left + \beta b - -If the atomic flag is set, B is updated atomically. - - -Arguments -......... - -The first argument gives :math:`\alpha` and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. -If A is a matrix then B must be a vector. -The first mode size of :math:`\text{op}(A)` must match the size of B. -If A is a vector, then B must be a scalar memref. - -The transpose op is defined as in the axpby instruction. - -Yield ------ +Subgroup size +............. .. code:: abnf - yield-instruction = "yield" [local-identifier-list] ":" [scalar-type-list] - identifier-or-constant-list = identifier-or-constant *("," identifier-or-constant) + value-instruction =/ "subgroup_size" Overview -........ +~~~~~~~~ -Yield returns values from an if or for instruction. - -Arguments -......... - -The length of the local identifier list must equal the length of the scalar type list. +Returns the subgroup size; i32 integer. -Additional instructions ------------------------ - -.. code:: abnf - - barrier-instruction = "barrier" - lifetime-stop-instruction = "lifetime_stop" local-identifier - Sample code =========== diff --git a/docs/manual/tutorial_matrix_chain.rst b/docs/manual/tutorial_matrix_chain.rst index e9c552d6..53141413 100644 --- a/docs/manual/tutorial_matrix_chain.rst +++ b/docs/manual/tutorial_matrix_chain.rst @@ -50,10 +50,10 @@ Compilation with the Tiny Tensor Compiler generates the following OpenCL-C code kernel __attribute__((reqd_work_group_size(64,1,1))) __attribute__((intel_reqd_sub_group_size(32))) - fused_kernel(global float *K, global float *P, uint P_shape2, global float *global *A, - global float *Q, uint Q_shape2) { + fused_kernel(global float *K, global float *P, long P_shape2, global float *global *A, + global float *Q, long Q_shape2) { local uchar stack[2016] __attribute__((aligned(64))); - uint gid = get_global_id(2); + long gid = get_global_id(2); global float *p = P + 0ll * 1 + 0ll * 56 + gid * 504; global float *a = *(A + gid); global float *q = Q + 0ll * 1 + 0ll * 56 + gid * 504; From a74b8e504c268fcd232f54821dc6d49cea6e5f87 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 1 Aug 2024 17:10:30 +0200 Subject: [PATCH 009/297] Add new instructions Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 35 ++++++++++ docs/api/builder_capi.yaml | 5 ++ docs/api/builder_cxxapi.rst | 35 ++++++++++ docs/api/builder_cxxapi.yaml | 5 ++ docs/manual/tensor-ir.rst | 64 ++++++++++------- include/tinytc/tinytc.h | 69 ++++++++++++++++++ include/tinytc/tinytc.hpp | 66 ++++++++++++++++++ include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/error.cpp | 5 +- src/inst.cpp | 45 ++++++++++++ src/node/inst_node.hpp | 112 +++++++++++++++++++++++------- src/parser/lexer.re | 4 ++ src/parser/parser_impl.yy | 47 +++++++++++-- src/visitor/alias_analysis.cpp | 2 + src/visitor/alias_analysis.hpp | 1 + src/visitor/check_ir.cpp | 11 ++- src/visitor/check_ir.hpp | 1 + src/visitor/dump_ir.cpp | 25 +++++++ src/visitor/dump_ir.hpp | 5 ++ src/visitor/insert_barrier.cpp | 4 ++ src/visitor/insert_barrier.hpp | 1 + src/visitor/lifetime_analysis.cpp | 13 ++++ src/visitor/lifetime_analysis.hpp | 2 + src/visitor/opencl_ast.cpp | 32 +++++++++ src/visitor/opencl_ast.hpp | 5 ++ src/visitor/slot_tracker.cpp | 2 + src/visitor/slot_tracker.hpp | 1 + src/visitor/stack.cpp | 7 ++ src/visitor/stack.hpp | 1 + src/visitor/work_group_size.cpp | 1 + src/visitor/work_group_size.hpp | 1 + test/codegen/nesting2.ir | 8 +++ test/codegen/nesting3.ir | 11 +++ test/codegen/subgroup.ir | 16 +++++ 35 files changed, 584 insertions(+), 60 deletions(-) create mode 100644 test/codegen/nesting2.ir create mode 100644 test/codegen/nesting3.ir create mode 100644 test/codegen/subgroup.ir diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 6af8efdd..07205980 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -356,10 +356,20 @@ Instruction * :ref:`tinytc_load_inst_create` + * :ref:`tinytc_num_subgroups_inst_create` + + * :ref:`tinytc_parallel_inst_create` + * :ref:`tinytc_size_inst_create` * :ref:`tinytc_store_inst_create` + * :ref:`tinytc_subgroup_id_inst_create` + + * :ref:`tinytc_subgroup_local_id_inst_create` + + * :ref:`tinytc_subgroup_size_inst_create` + * :ref:`tinytc_subview_inst_create` * :ref:`tinytc_sum_inst_create` @@ -467,6 +477,16 @@ tinytc_load_inst_create .. doxygenfunction:: tinytc_load_inst_create +tinytc_num_subgroups_inst_create +................................ + +.. doxygenfunction:: tinytc_num_subgroups_inst_create + +tinytc_parallel_inst_create +........................... + +.. doxygenfunction:: tinytc_parallel_inst_create + tinytc_size_inst_create ....................... @@ -477,6 +497,21 @@ tinytc_store_inst_create .. doxygenfunction:: tinytc_store_inst_create +tinytc_subgroup_id_inst_create +.............................. + +.. doxygenfunction:: tinytc_subgroup_id_inst_create + +tinytc_subgroup_local_id_inst_create +.................................... + +.. doxygenfunction:: tinytc_subgroup_local_id_inst_create + +tinytc_subgroup_size_inst_create +................................ + +.. doxygenfunction:: tinytc_subgroup_size_inst_create + tinytc_subview_inst_create .......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 4412c745..9447ba9b 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -69,8 +69,13 @@ Builder C-API: - tinytc_hadamard_inst_create - tinytc_if_inst_create - tinytc_load_inst_create + - tinytc_num_subgroups_inst_create + - tinytc_parallel_inst_create - tinytc_size_inst_create - tinytc_store_inst_create + - tinytc_subgroup_id_inst_create + - tinytc_subgroup_local_id_inst_create + - tinytc_subgroup_size_inst_create - tinytc_subview_inst_create - tinytc_sum_inst_create - tinytc_yield_inst_create diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 57668d4a..41f0cb28 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -306,10 +306,20 @@ Instruction * :ref:`make_load` + * :ref:`make_num_subgroups` + + * :ref:`make_parallel` + * :ref:`make_size` * :ref:`make_store` + * :ref:`make_subgroup_id` + + * :ref:`make_subgroup_local_id` + + * :ref:`make_subgroup_size` + * :ref:`make_subview` * :ref:`make_sum` @@ -413,6 +423,16 @@ make_load .. doxygenfunction:: tinytc::make_load +make_num_subgroups +.................. + +.. doxygenfunction:: tinytc::make_num_subgroups + +make_parallel +............. + +.. doxygenfunction:: tinytc::make_parallel + make_size ......... @@ -423,6 +443,21 @@ make_store .. doxygenfunction:: tinytc::make_store +make_subgroup_id +................ + +.. doxygenfunction:: tinytc::make_subgroup_id + +make_subgroup_local_id +...................... + +.. doxygenfunction:: tinytc::make_subgroup_local_id + +make_subgroup_size +.................. + +.. doxygenfunction:: tinytc::make_subgroup_size + make_subview ............ diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index b2358d2a..7b5f1eeb 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -63,8 +63,13 @@ Builder C++-API: - tinytc::make_hadamard - tinytc::make_if - tinytc::make_load + - tinytc::make_num_subgroups + - tinytc::make_parallel - tinytc::make_size - tinytc::make_store + - tinytc::make_subgroup_id + - tinytc::make_subgroup_local_id + - tinytc::make_subgroup_size - tinytc::make_subview - tinytc::make_sum - tinytc::make_yield diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index ef2547e4..51377663 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -28,7 +28,7 @@ Mixed instructions accept both varying and uniform local variables. Regions come in two different kinds: collective and SPMD. A collective instructions must only appear in a collective region, and an SPMD instruction -must only appear in a in a SPMD region. Mixed instructions might appear in both kinds of regions. +must only appear in a SPMD region. Mixed instructions might appear in both kinds of regions. SPMD regions may be nested in collective regions but collective regions must not be nested in SPMD regions. Core rules @@ -473,6 +473,18 @@ a, b, and c, respectively. a, b, and c must be vectors and have equal shape. +Parallel +........ + +.. code:: abnf + + instruction =/ "parallel" region + +Overview +~~~~~~~~ + +Opens an *spmd region*. + Sum ... @@ -832,6 +844,18 @@ Examples: #. ``load %0[%1] : group>`` returns a ``memref`` value. #. ``load %0[%1] : group, offset: ?>`` returns a ``memref`` value. +Number of subgroups +................... + +.. code:: abnf + + value-instruction =/ "num_subgroups" + +Overview +~~~~~~~~ + +Returns the number of subgroups the work-group is divided in; i32 integer. + For ... @@ -880,6 +904,19 @@ It is required that The local identifier must have the memref type specified last. The instruction returns an integer of index type. +Subgroup size +............. + +.. code:: abnf + + value-instruction =/ "subgroup_size" + +Overview +~~~~~~~~ + +Returns the subgroup size; i32 integer. + + Subview ....... @@ -1006,18 +1043,6 @@ Additional instructions SPMD instructions ----------------- -Number of subgroups -................... - -.. code:: abnf - - value-instruction =/ "num_subgroups" - -Overview -~~~~~~~~ - -Returns the number of subgroups the work-group is divided in; i32 integer. - Subgroup id ........... @@ -1042,19 +1067,6 @@ Overview Returns the work-item id within the sub-group; i32 integer from 0 to subgroup_size - 1. -Subgroup size -............. - -.. code:: abnf - - value-instruction =/ "subgroup_size" - -Overview -~~~~~~~~ - -Returns the subgroup size; i32 integer. - - Sample code =========== diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index fb70475c..1d9200c1 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -504,6 +504,36 @@ TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( tinytc_inst_t *instr, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, const tinytc_location_t *loc); +/** + * @brief Create num_subgroups instruction + * + * @code %value = num_subgroups @endcode + * + * @param instr [out] pointer to the inst object created + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc); + +/** + * @brief Create parallel region + * + * @code + * parallel { %body } + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param body [in] loop body + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, + tinytc_region_t body, + const tinytc_location_t *loc); + /** * @brief Create size instruction * @@ -519,6 +549,45 @@ TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, const tinytc_location_t *loc); +/** + * @brief Create subgroup_id instruction + * + * @code %value = subgroup_id @endcode + * + * @param instr [out] pointer to the inst object created + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc); + +/** + * @brief Create subgroup_local_id instruction + * + * @code %value = subgroup_local_id @endcode + * + * @param instr [out] pointer to the inst object created + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc); + +/** + * @brief Create subgroup_size instruction + * + * @code %value = subgroup_size @endcode + * + * @param instr [out] pointer to the inst object created + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc); + /** * @brief Create subview instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 6c828935..018bfef3 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1008,6 +1008,33 @@ inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value return inst(instr); } +/** + * @brief Make num_subgroups instruction + * + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_num_subgroups(location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_num_subgroups_inst_create(&instr, &loc), loc); + return inst(instr); +} + +/** + * @brief Make parallel region + * + * @param body Loop body + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_parallel(region const &body, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, body.get(), &loc), loc); + return inst(instr); +} + /** * @brief Make size instruction * @@ -1023,6 +1050,45 @@ inline inst make_size(value const &a, std::int64_t mode, location const &loc = { return inst(instr); } +/** + * @brief Make subgroup_id instruction + * + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_subgroup_id(location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_subgroup_id_inst_create(&instr, &loc), loc); + return inst(instr); +} + +/** + * @brief Make subgroup_local_id instruction + * + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_subgroup_local_id(location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_subgroup_local_id_inst_create(&instr, &loc), loc); + return inst(instr); +} + +/** + * @brief Make subgroup_size instruction + * + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_subgroup_size(location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_subgroup_size_inst_create(&instr, &loc), loc); + return inst(instr); +} + /** * @brief Make subview instruction * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index f9f1f7ac..8d0170bf 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -60,6 +60,7 @@ typedef enum { tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape tinytc_status_ir_collective_called_from_spmd = 0x111, ///< Collective instruction from SPMD tinytc_status_ir_fp_unsupported = 0x112, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x113, ///< SPMD instruction from collective // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 91ed0eb6..6465805a 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -70,6 +70,7 @@ enum class status { ir_expand_shape_mismatch = tinytc_status_ir_expand_shape_mismatch, ir_collective_called_from_spmd = tinytc_status_ir_collective_called_from_spmd, ir_fp_unsupported = tinytc_status_ir_fp_unsupported, + ir_spmd_called_from_collective = tinytc_status_ir_spmd_called_from_collective, // Level Zero errors ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, diff --git a/src/error.cpp b/src/error.cpp index d315e54c..67486bdd 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -113,7 +113,8 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_unsupported_device: return "Unsupported device"; case tinytc_status_invalid_core_info: - return "Invalid core info object (e.g. max work group size is 0 or subgroup sizes vector is empty)"; + return "Invalid core info object (e.g. max work group size is 0 or subgroup sizes vector " + "is empty)"; // IR case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; @@ -153,6 +154,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Collective instruction must not be called from SPMD region"; case tinytc_status_ir_fp_unsupported: return "Floating point type unsupported for instruction"; + case tinytc_status_ir_spmd_called_from_collective: + return "SPMD instruction must not be called from collective region"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/inst.cpp b/src/inst.cpp index 06f41460..2a8685eb 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -284,6 +284,25 @@ tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t }); } +tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, tinytc_region_t body, + const tinytc_location_t *loc) { + if (instr == nullptr || body == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(region(body, true), get_optional(loc)).release(); + }); +} + tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, const tinytc_location_t *loc) { if (instr == nullptr) { @@ -294,6 +313,32 @@ tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, }); } +tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); +} + tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t slice_list_size, tinytc_value_t *offset_list, tinytc_value_t *size_list, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 1dc54846..6822b93e 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -20,18 +20,20 @@ namespace tinytc { //! Instruction classification enum class inst_kind { - replicated, ///< replicated instruction executed in every work-item - collective ///< collective instruction distributed among work-items + mixed, ///< mixed instruction on uniform or varying data + collective, ///< collective instruction on uniform data, distributed among work-items + spmd ///< SPMD instruction on varying data + }; -using inst_nodes = - clir::virtual_type_list; +using inst_nodes = clir::virtual_type_list< + class alloca_inst, class axpby_inst, class barrier_inst, class arith_inst, + class arith_unary_inst, class cast_inst, class compare_inst, class expand_inst, class fuse_inst, + class load_inst, class group_id_inst, class group_size_inst, class lifetime_stop_inst, + class gemm_inst, class gemv_inst, class ger_inst, class for_inst, class foreach_inst, + class hadamard_inst, class if_inst, class num_subgroups_inst, class parallel_inst, + class size_inst, class subview_inst, class store_inst, class subgroup_id_inst, + class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, class yield_inst>; } // namespace tinytc @@ -148,7 +150,7 @@ class arith_inst : public clir::visitable { inline auto a() const -> value const & { return a_; } inline auto b() const -> value const & { return b_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: arithmetic op_; @@ -162,7 +164,7 @@ class arith_unary_inst : public clir::visitable { inline arithmetic_unary op() const { return op_; } inline auto a() const -> value const & { return a_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: arithmetic_unary op_; @@ -180,7 +182,7 @@ class cast_inst : public clir::visitable { cast_inst(value a, scalar_type to_ty, location const &lc = {}); inline auto a() const -> value const & { return a_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value a_, result_; @@ -194,7 +196,7 @@ class compare_inst : public clir::visitable { inline auto a() const -> value const & { return a_; } inline auto b() const -> value const & { return b_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: cmp_condition cond_; @@ -211,7 +213,7 @@ class expand_inst : public clir::visitable { inline auto expand_shape() const -> std::vector const & { return expand_shape_; } inline auto expand_shape(std::int64_t i) const -> value const & { return expand_shape_[i]; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value op_, result_; @@ -227,7 +229,7 @@ class fuse_inst : public clir::visitable { inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value op_, result_; @@ -241,7 +243,7 @@ class load_inst : public clir::visitable { inline auto operand() const -> value const & { return op_; } inline auto index_list() const -> std::vector const & { return index_list_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value op_; @@ -255,7 +257,7 @@ class group_id_inst : public clir::visitable { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value result_; @@ -267,7 +269,7 @@ class group_size_inst : public clir::visitable { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value result_; @@ -320,7 +322,7 @@ class for_inst : public clir::visitable { public: using super = clir::visitable; using super::super; - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } }; class foreach_inst : public clir::visitable { @@ -352,7 +354,7 @@ class if_inst : public clir::visitable { inline auto num_results() const -> std::size_t override { return results_.size(); } inline auto results_ref() -> std::vector & { return results_; } inline auto results_ref() const -> std::vector const & { return results_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value condition_; @@ -360,6 +362,30 @@ class if_inst : public clir::visitable { std::vector results_; }; +class num_subgroups_inst : public clir::visitable { + public: + inline num_subgroups_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + loc(lc); + } + inline value result() const override { return result_; } + inline inst_kind kind() const override { return inst_kind::mixed; } + + private: + value result_; +}; + +class parallel_inst : public clir::visitable { + public: + using super = clir::visitable; + inline parallel_inst(region body, location const &lc = {}) : body_(std::move(body)) { loc(lc); } + inline auto body() const -> region const & { return body_; } + inline inst_kind kind() const override { return inst_kind::collective; } + inline value result() const override { return value{}; } + + private: + region body_; +}; + class size_inst : public clir::visitable { public: size_inst(value op, std::int64_t mode, location const &lc = {}); @@ -367,13 +393,49 @@ class size_inst : public clir::visitable { inline auto operand() const -> value const & { return op_; } inline std::int64_t mode() const { return mode_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value op_, result_; std::int64_t mode_; }; +class subgroup_id_inst : public clir::visitable { + public: + inline subgroup_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + loc(lc); + } + inline value result() const override { return result_; } + inline inst_kind kind() const override { return inst_kind::spmd; } + + private: + value result_; +}; + +class subgroup_local_id_inst : public clir::visitable { + public: + inline subgroup_local_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + loc(lc); + } + inline value result() const override { return result_; } + inline inst_kind kind() const override { return inst_kind::spmd; } + + private: + value result_; +}; + +class subgroup_size_inst : public clir::visitable { + public: + inline subgroup_size_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + loc(lc); + } + inline value result() const override { return result_; } + inline inst_kind kind() const override { return inst_kind::mixed; } + + private: + value result_; +}; + class subview_inst : public clir::visitable { public: subview_inst(value op, std::vector slices, location const &lc = {}); @@ -381,7 +443,7 @@ class subview_inst : public clir::visitable { inline auto slices() const -> std::vector const & { return slices_; } inline auto operand() const -> value const & { return op_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value op_; @@ -397,7 +459,7 @@ class store_inst : public clir::visitable { inline auto operand() const -> value const & { return op_; } inline auto index_list() const -> std::vector const & { return index_list_; } inline value result() const override { return {}; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: value val_, op_; @@ -423,7 +485,7 @@ class yield_inst : public clir::visitable { } inline value result() const override { return value{}; } inline auto vals() const -> std::vector const & { return vals_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline inst_kind kind() const override { return inst_kind::mixed; } private: std::vector vals_; diff --git a/src/parser/lexer.re b/src/parser/lexer.re index b59772ca..66187da1 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -139,8 +139,12 @@ lex: "for" { adv_loc(); return parser::make_FOR(loc_); } "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } "if" { adv_loc(); return parser::make_IF(loc_); } + "num_subgroups" { adv_loc(); return parser::make_NUM_SUBGROUPS(loc_); } + "parallel" { adv_loc(); return parser::make_PARALLEL(loc_); } "else" { adv_loc(); return parser::make_ELSE(loc_); } "size" { adv_loc(); return parser::make_SIZE(loc_); } + "subgroup_id" { adv_loc(); return parser::make_SUBGROUP_ID(loc_); } + "subgroup_local_id" { adv_loc(); return parser::make_SUBGROUP_LOCAL_ID(loc_); } "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } "store" { adv_loc(); return parser::make_STORE(loc_); } "sum" { adv_loc(); return parser::make_SUM(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 07ff319e..8af96bf8 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -123,11 +123,15 @@ LOAD "load" FOR "for" FOREACH "foreach" - IF "if" - ELSE "else" GROUP_ID "group_id" GROUP_SIZE "group_size" + IF "if" + ELSE "else" + NUM_SUBGROUPS "num_subgroups" + PARALLEL "parallel" SIZE "size" + SUBGROUP_ID "subgroup_id" + SUBGROUP_LOCAL_ID "subgroup_local_id" SUBVIEW "subview" STORE "store" SUM "sum" @@ -203,7 +207,12 @@ %nterm <::tinytc::value> index_identifier_or_const %nterm group_id_inst %nterm group_size_inst +%nterm num_subgroups_inst +%nterm parallel_inst %nterm size_inst +%nterm subgroup_id_inst +%nterm subgroup_local_id_inst +%nterm subgroup_size_inst %nterm store_inst %nterm subview_inst %nterm > optional_slice_list @@ -395,6 +404,7 @@ instruction: | foreach_inst | hadamard_inst | if_inst + | parallel_inst | var_definition | store_inst | sum_inst @@ -674,11 +684,15 @@ valued_inst: | compare_inst | expand_inst | fuse_inst - | if_inst - | load_inst | group_id_inst | group_size_inst + | if_inst + | load_inst + | num_subgroups_inst | size_inst + | subgroup_id_inst + | subgroup_local_id_inst + | subgroup_size_inst | subview_inst ; @@ -873,11 +887,11 @@ store_inst: ; group_id_inst: - GROUP_ID { $$ = inst{std::make_unique().release()}; } + GROUP_ID { $$ = inst{std::make_unique(@GROUP_ID).release()}; } ; group_size_inst: - GROUP_SIZE { $$ = inst{std::make_unique().release()}; } + GROUP_SIZE { $$ = inst{std::make_unique(@GROUP_SIZE).release()}; } ; if_inst: @@ -911,6 +925,15 @@ scalar_type_list: | scalar_type_list COMMA scalar_type { $$ = std::move($1); $$.push_back($scalar_type); } ; +num_subgroups_inst: + NUM_SUBGROUPS { $$ = inst{std::make_unique(@NUM_SUBGROUPS).release()}; } +; + +parallel_inst: + PARALLEL region { + $$ = inst{std::make_unique(std::move($region), @parallel_inst) .release()}; + } +; size_inst: SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON memref_type { @@ -928,6 +951,18 @@ size_inst: } ; +subgroup_id_inst: + SUBGROUP_ID { $$ = inst{std::make_unique(@SUBGROUP_ID).release()}; } +; + +subgroup_local_id_inst: + SUBGROUP_LOCAL_ID { $$ = inst{std::make_unique(@SUBGROUP_LOCAL_ID).release()}; } +; + +subgroup_size_inst: + SUBGROUP_SIZE { $$ = inst{std::make_unique(@SUBGROUP_SIZE).release()}; } +; + subview_inst: SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type { if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { diff --git a/src/visitor/alias_analysis.cpp b/src/visitor/alias_analysis.cpp index 25e78ff8..1d0abb1d 100644 --- a/src/visitor/alias_analysis.cpp +++ b/src/visitor/alias_analysis.cpp @@ -48,6 +48,8 @@ void alias_analyser::operator()(if_inst const &in) { } } +void alias_analyser::operator()(parallel_inst const &p) { visit(*this, *p.body()); } + void alias_analyser::operator()(subview_inst const &s) { value_node const *source = s.operand().get(); while (alias_.find(source) != alias_.end()) { diff --git a/src/visitor/alias_analysis.hpp b/src/visitor/alias_analysis.hpp index d1ce730c..183b5c7b 100644 --- a/src/visitor/alias_analysis.hpp +++ b/src/visitor/alias_analysis.hpp @@ -23,6 +23,7 @@ class alias_analyser { void operator()(expand_inst const &e); void operator()(fuse_inst const &f); void operator()(if_inst const &in); + void operator()(parallel_inst const &p); void operator()(subview_inst const &s); /* Region nodes */ diff --git a/src/visitor/check_ir.cpp b/src/visitor/check_ir.cpp index 33a1365e..ab107ff7 100644 --- a/src/visitor/check_ir.cpp +++ b/src/visitor/check_ir.cpp @@ -16,9 +16,10 @@ namespace tinytc { /* Stmt nodes */ void ir_checker::operator()(inst_node const &in) { - bool ok = in.kind() != inst_kind::collective || !inside_spmd_region_; - if (!ok) { + if (in.kind() == inst_kind::collective && inside_spmd_region_) { throw compilation_error(in.loc(), status::ir_collective_called_from_spmd); + } else if (in.kind() == inst_kind::spmd && !inside_spmd_region_) { + throw compilation_error(in.loc(), status::ir_spmd_called_from_collective); } } void ir_checker::operator()(for_inst const &p) { return visit(*this, *p.body()); } @@ -34,6 +35,12 @@ void ir_checker::operator()(if_inst const &in) { visit(*this, *in.otherwise()); } } +void ir_checker::operator()(parallel_inst const &p) { + this->operator()(static_cast(p)); + inside_spmd_region_ = true; + visit(*this, *p.body()); + inside_spmd_region_ = false; +} /* Region nodes */ void ir_checker::operator()(rgn const &b) { diff --git a/src/visitor/check_ir.hpp b/src/visitor/check_ir.hpp index da135768..0b09eaa3 100644 --- a/src/visitor/check_ir.hpp +++ b/src/visitor/check_ir.hpp @@ -18,6 +18,7 @@ class ir_checker { void operator()(for_inst const &p); void operator()(foreach_inst const &p); void operator()(if_inst const &in); + void operator()(parallel_inst const &p); /* Region nodes */ void operator()(rgn const &b); diff --git a/src/visitor/dump_ir.cpp b/src/visitor/dump_ir.cpp index c99df024..b462ff8c 100644 --- a/src/visitor/dump_ir.cpp +++ b/src/visitor/dump_ir.cpp @@ -271,6 +271,16 @@ void ir_dumper::operator()(if_inst const &in) { } } +void ir_dumper::operator()(num_subgroups_inst const &sg) { + visit(*this, *sg.result()); + os_ << " = num_subgroups"; +} + +void ir_dumper::operator()(parallel_inst const &p) { + os_ << "parallel "; + visit(*this, *p.body()); +} + void ir_dumper::operator()(size_inst const &s) { visit(*this, *s.result()); os_ << " = size "; @@ -280,6 +290,21 @@ void ir_dumper::operator()(size_inst const &s) { visit(*this, *s.operand()->ty()); } +void ir_dumper::operator()(subgroup_id_inst const &sg) { + visit(*this, *sg.result()); + os_ << " = subgroup_id"; +} + +void ir_dumper::operator()(subgroup_local_id_inst const &sg) { + visit(*this, *sg.result()); + os_ << " = subgroup_local_id"; +} + +void ir_dumper::operator()(subgroup_size_inst const &sg) { + visit(*this, *sg.result()); + os_ << " = subgroup_size"; +} + void ir_dumper::operator()(subview_inst const &s) { visit(*this, *s.result()); os_ << " = subview "; diff --git a/src/visitor/dump_ir.hpp b/src/visitor/dump_ir.hpp index c54388a3..6912fde3 100644 --- a/src/visitor/dump_ir.hpp +++ b/src/visitor/dump_ir.hpp @@ -53,7 +53,12 @@ class ir_dumper { void operator()(foreach_inst const &p); void operator()(hadamard_inst const &g); void operator()(if_inst const &in); + void operator()(num_subgroups_inst const &sg); + void operator()(parallel_inst const &p); void operator()(size_inst const &s); + void operator()(subgroup_id_inst const &sg); + void operator()(subgroup_local_id_inst const &sg); + void operator()(subgroup_size_inst const &sg); void operator()(subview_inst const &s); void operator()(store_inst const &s); void operator()(sum_inst const &s); diff --git a/src/visitor/insert_barrier.cpp b/src/visitor/insert_barrier.cpp index d4be98cf..4cc21435 100644 --- a/src/visitor/insert_barrier.cpp +++ b/src/visitor/insert_barrier.cpp @@ -85,6 +85,10 @@ std::unordered_set insert_barrier::operator()(if_inst &in) { std::unordered_set insert_barrier::operator()(lifetime_stop_inst &) { return {}; } +std::unordered_set insert_barrier::operator()(parallel_inst &p) { + return visit(*this, *p.body()); +} + std::unordered_set insert_barrier::operator()(size_inst &) { return {}; } std::unordered_set insert_barrier::operator()(store_inst &s) { diff --git a/src/visitor/insert_barrier.hpp b/src/visitor/insert_barrier.hpp index ce87904a..0fe15094 100644 --- a/src/visitor/insert_barrier.hpp +++ b/src/visitor/insert_barrier.hpp @@ -41,6 +41,7 @@ class insert_barrier { std::unordered_set operator()(load_inst &e); std::unordered_set operator()(if_inst &in); std::unordered_set operator()(lifetime_stop_inst &); + std::unordered_set operator()(parallel_inst &p); std::unordered_set operator()(size_inst &s); std::unordered_set operator()(store_inst &s); std::unordered_set operator()(subview_inst &s); diff --git a/src/visitor/lifetime_analysis.cpp b/src/visitor/lifetime_analysis.cpp index 13cb7734..fa11a1b2 100644 --- a/src/visitor/lifetime_analysis.cpp +++ b/src/visitor/lifetime_analysis.cpp @@ -26,6 +26,15 @@ value find_alloca::operator()(for_inst &p) { } return value{}; } +value find_alloca::operator()(if_inst &in) { + if (recursive_) { + visit(*this, *in.then()); + if (in.otherwise()) { + visit(*this, *in.otherwise()); + } + } + return value{}; +} /* Region nodes */ value find_alloca::operator()(rgn &b) { @@ -87,6 +96,10 @@ auto lifetime_inserter::operator()(lifetime_stop_inst &ls) return {ls.object().get()}; } +auto lifetime_inserter::operator()(parallel_inst &p) -> std::unordered_set { + return visit(*this, *p.body()); +} + auto lifetime_inserter::operator()(size_inst &s) -> std::unordered_set { return std::unordered_set{s.operand().get()}; } diff --git a/src/visitor/lifetime_analysis.hpp b/src/visitor/lifetime_analysis.hpp index 5e3a3528..2137eaab 100644 --- a/src/visitor/lifetime_analysis.hpp +++ b/src/visitor/lifetime_analysis.hpp @@ -25,6 +25,7 @@ class find_alloca { value operator()(inst_node &); value operator()(alloca_inst &a); value operator()(for_inst &p); + value operator()(if_inst &p); /* Region nodes */ value operator()(rgn &); @@ -50,6 +51,7 @@ class lifetime_inserter { auto operator()(load_inst &e) -> std::unordered_set<::tinytc_value const *>; auto operator()(if_inst &in) -> std::unordered_set<::tinytc_value const *>; auto operator()(lifetime_stop_inst &) -> std::unordered_set<::tinytc_value const *>; + auto operator()(parallel_inst &p) -> std::unordered_set<::tinytc_value const *>; auto operator()(size_inst &s) -> std::unordered_set<::tinytc_value const *>; auto operator()(store_inst &s) -> std::unordered_set<::tinytc_value const *>; auto operator()(subview_inst &s) -> std::unordered_set<::tinytc_value const *>; diff --git a/src/visitor/opencl_ast.cpp b/src/visitor/opencl_ast.cpp index 9d7b241a..9a347ec2 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/visitor/opencl_ast.cpp @@ -856,6 +856,17 @@ std::vector opencl_ast::operator()(if_inst const &in) { return clinst; } +std::vector opencl_ast::operator()(num_subgroups_inst const &sg) { + auto rhs = clir::get_num_sub_groups(); + auto lhs = declare(*sg.result()); + return { + declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; +} + +std::vector opencl_ast::operator()(parallel_inst const &p) { + return {visit(*this, *p.body())}; +} + std::vector opencl_ast::operator()(size_inst const &s) { auto v = declare(*s.result()); auto &dv = get_dope_vector(s.operand().get()); @@ -864,6 +875,27 @@ std::vector opencl_ast::operator()(size_inst const &s) { dv.shape(s.mode()))}; } +std::vector opencl_ast::operator()(subgroup_id_inst const &sg) { + auto rhs = clir::get_sub_group_id(); + auto lhs = declare(*sg.result()); + return { + declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; +} + +std::vector opencl_ast::operator()(subgroup_local_id_inst const &sg) { + auto rhs = clir::get_sub_group_local_id(); + auto lhs = declare(*sg.result()); + return { + declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; +} + +std::vector opencl_ast::operator()(subgroup_size_inst const &sg) { + auto rhs = clir::get_sub_group_size(); + auto lhs = declare(*sg.result()); + return { + declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; +} + std::vector opencl_ast::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); auto t = get_memref_type(*s.operand()); diff --git a/src/visitor/opencl_ast.hpp b/src/visitor/opencl_ast.hpp index e76a8c45..94f54271 100644 --- a/src/visitor/opencl_ast.hpp +++ b/src/visitor/opencl_ast.hpp @@ -91,7 +91,12 @@ class opencl_ast { std::vector operator()(foreach_inst const &in); std::vector operator()(hadamard_inst const &g); std::vector operator()(if_inst const &in); + std::vector operator()(num_subgroups_inst const &sg); + std::vector operator()(parallel_inst const &p); std::vector operator()(size_inst const &s); + std::vector operator()(subgroup_id_inst const &sg); + std::vector operator()(subgroup_local_id_inst const &sg); + std::vector operator()(subgroup_size_inst const &sg); std::vector operator()(subview_inst const &s); std::vector operator()(store_inst const &s); std::vector operator()(sum_inst const &s); diff --git a/src/visitor/slot_tracker.cpp b/src/visitor/slot_tracker.cpp index ec7191a9..663e9162 100644 --- a/src/visitor/slot_tracker.cpp +++ b/src/visitor/slot_tracker.cpp @@ -37,6 +37,8 @@ void slot_tracker::operator()(if_inst const &in) { } } +void slot_tracker::operator()(parallel_inst const &p) { return visit(*this, *p.body()); } + /* Region nodes */ void slot_tracker::operator()(rgn const &b) { for (auto const &s : b.insts()) { diff --git a/src/visitor/slot_tracker.hpp b/src/visitor/slot_tracker.hpp index c92400b3..3c93e683 100644 --- a/src/visitor/slot_tracker.hpp +++ b/src/visitor/slot_tracker.hpp @@ -21,6 +21,7 @@ class slot_tracker { void operator()(inst_node const &in); void operator()(loop_inst const &p); void operator()(if_inst const &in); + void operator()(parallel_inst const &p); /* Region nodes */ void operator()(rgn const &b); diff --git a/src/visitor/stack.cpp b/src/visitor/stack.cpp index 1702507d..aa646bc5 100644 --- a/src/visitor/stack.cpp +++ b/src/visitor/stack.cpp @@ -52,6 +52,13 @@ void stack_ptr::operator()(lifetime_stop_inst &s) { } void stack_ptr::operator()(for_inst &p) { visit(*this, *p.body()); } +void stack_ptr::operator()(if_inst &in) { + visit(*this, *in.then()); + if (in.otherwise()) { + visit(*this, *in.otherwise()); + } +} + /* Region nodes */ void stack_ptr::operator()(rgn &b) { for (auto &s : b.insts()) { diff --git a/src/visitor/stack.hpp b/src/visitor/stack.hpp index d360be46..f515d87a 100644 --- a/src/visitor/stack.hpp +++ b/src/visitor/stack.hpp @@ -22,6 +22,7 @@ class stack_ptr { void operator()(alloca_inst &a); void operator()(lifetime_stop_inst &s); void operator()(for_inst &p); + void operator()(if_inst &in); /* Region nodes */ void operator()(rgn &b); diff --git a/src/visitor/work_group_size.cpp b/src/visitor/work_group_size.cpp index 9e6c3b70..c67cfb57 100644 --- a/src/visitor/work_group_size.cpp +++ b/src/visitor/work_group_size.cpp @@ -61,6 +61,7 @@ void work_group_size::operator()(if_inst &in) { } } void work_group_size::operator()(loop_inst &in) { visit(*this, *in.body()); } +void work_group_size::operator()(parallel_inst &p) { visit(*this, *p.body()); } /* Region nodes */ void work_group_size::operator()(rgn &b) { diff --git a/src/visitor/work_group_size.hpp b/src/visitor/work_group_size.hpp index b7dec02f..db350be1 100644 --- a/src/visitor/work_group_size.hpp +++ b/src/visitor/work_group_size.hpp @@ -25,6 +25,7 @@ class work_group_size { void operator()(blas_a3_inst &in); void operator()(if_inst &in); void operator()(loop_inst &in); + void operator()(parallel_inst &p); /* Region nodes */ void operator()(rgn &b); diff --git a/test/codegen/nesting2.ir b/test/codegen/nesting2.ir new file mode 100644 index 00000000..e336529d --- /dev/null +++ b/test/codegen/nesting2.ir @@ -0,0 +1,8 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +func @illegal_nesting() { + %0 = subgroup_id +; CHECK: 6.10-20: SPMD instruction must not be called from collective region +} diff --git a/test/codegen/nesting3.ir b/test/codegen/nesting3.ir new file mode 100644 index 00000000..9e3151b2 --- /dev/null +++ b/test/codegen/nesting3.ir @@ -0,0 +1,11 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +func @illegal_nesting() { + parallel { + foreach %j=1,16 { + } +; CHECK: 7.9-8.9: Collective instruction must not be called from SPMD region + } +} diff --git a/test/codegen/subgroup.ir b/test/codegen/subgroup.ir new file mode 100644 index 00000000..a834a55b --- /dev/null +++ b/test/codegen/subgroup.ir @@ -0,0 +1,16 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc < %s | filecheck %s +func @t1() { + parallel { + %0 = num_subgroups + %1 = subgroup_id + %2 = subgroup_local_id + %3 = subgroup_size + } +; CHECK: int x0 = get_num_sub_groups(); +; CHECK-NEXT: int x1 = get_sub_group_id(); +; CHECK-NEXT: int x2 = get_sub_group_local_id(); +; CHECK-NEXT: int x3 = get_sub_group_size(); +} From d95fed67a59c3a0d83eb477f287bb103408e7cd1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 6 Aug 2024 06:24:12 -0700 Subject: [PATCH 010/297] fix faulty core features logic Signed-off-by: Carsten Uphoff --- examples/tall_and_skinny/args.cpp | 4 ++++ examples/tall_and_skinny/args.hpp | 1 + examples/tall_and_skinny/main.cpp | 4 +++- src/device_info.cpp | 20 ++++++++------------ src/device_info.hpp | 2 +- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/tall_and_skinny/args.cpp b/examples/tall_and_skinny/args.cpp index 6aa1268a..faadc5e0 100644 --- a/examples/tall_and_skinny/args.cpp +++ b/examples/tall_and_skinny/args.cpp @@ -14,6 +14,7 @@ args arg_parser::parse_args(int argc, char **argv) { a.beta = 0.0; a.specialize_M = false; a.specialize_ld = false; + a.large_GRF = true; auto num = std::vector(3); for (int i = 1; i < argc; ++i) { if (argv[i][0] == '-') { @@ -29,6 +30,8 @@ args arg_parser::parse_args(int argc, char **argv) { a.specialize_M = true; } else if (std::strcmp(argv[i], "--specialize-ld") == 0) { a.specialize_ld = true; + } else if (std::strcmp(argv[i], "--small-grf") == 0) { + a.large_GRF = false; } else if (i + 1 < argc) { if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { ++i; @@ -80,5 +83,6 @@ optional arguments: -v, --verify Verify optimized implementation --specialize-M Specialize M instead of using dynamic value --specialize-ld Specialize ldA, ldB, ldC instead of using dynamic value + --small-grf Request small GRF mode instead of large GRF mode )HELP"; } diff --git a/examples/tall_and_skinny/args.hpp b/examples/tall_and_skinny/args.hpp index 701710fe..67edb5dc 100644 --- a/examples/tall_and_skinny/args.hpp +++ b/examples/tall_and_skinny/args.hpp @@ -22,6 +22,7 @@ struct args { double beta; bool specialize_M; bool specialize_ld; + bool large_GRF; }; class arg_parser { diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 8538cf0e..646b8f2a 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -104,7 +104,9 @@ template void test(queue q, args &a) { try { source_ctx = make_source_context(); auto info = make_core_info(q.get_device()); - info.set_core_features(tinytc_core_feature_flag_large_register_file); + if (a.large_GRF) { + info.set_core_features(tinytc_core_feature_flag_large_register_file); + } std::int64_t M = a.specialize_M ? c.m : dynamic; std::int64_t ldA = dynamic, ldB = dynamic, ldC = dynamic; diff --git a/src/device_info.cpp b/src/device_info.cpp index 5d239378..424c986c 100644 --- a/src/device_info.cpp +++ b/src/device_info.cpp @@ -51,7 +51,6 @@ core_info_intel::core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_ if (ip_version_ >= static_cast(intel_gpu_architecture::pvc)) { register_size_ = 64; } - num_registers_per_thread_ = num_reg_small_grf(); } auto core_info_intel::num_reg_small_grf() const -> std::int32_t { return 128; } @@ -62,29 +61,26 @@ auto core_info_intel::num_reg_large_grf() const -> std::int32_t { : num_reg_small_grf(); } +auto core_info_intel::num_reg() const -> std::int32_t { + return core_features_ & tinytc_core_feature_flag_large_register_file ? num_reg_large_grf() + : num_reg_small_grf(); +} + auto core_info_intel::subgroup_sizes() const -> std::vector const & { return subgroup_sizes_; } -auto core_info_intel::register_space() const -> std::int32_t { - return register_size_ * num_registers_per_thread_; -} +auto core_info_intel::register_space() const -> std::int32_t { return register_size_ * num_reg(); } auto core_info_intel::core_features() const -> tinytc_core_feature_flags_t { return core_features_; } -void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { - if (flags & tinytc_core_feature_flag_large_register_file) { - num_registers_per_thread_ = num_reg_large_grf(); - } else { - num_registers_per_thread_ = num_reg_small_grf(); - } -} +void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { core_features_ = flags; } auto core_info_intel::max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t { auto const num_threads_per_eu_due_to_register_use = - num_threads_per_eu_ * num_reg_small_grf() / num_registers_per_thread_; + num_threads_per_eu_ * num_reg_small_grf() / num_reg(); auto const num_threads_per_eu_due_to_subgroup_size = num_threads_per_eu_ * subgroup_sizes_.front() / subgroup_size; auto const num_threads_per_eu = diff --git a/src/device_info.hpp b/src/device_info.hpp index a0e638ba..7c0e58d9 100644 --- a/src/device_info.hpp +++ b/src/device_info.hpp @@ -91,6 +91,7 @@ class core_info_intel : public ::tinytc_core_info { private: auto num_reg_small_grf() const -> std::int32_t; auto num_reg_large_grf() const -> std::int32_t; + auto num_reg() const -> std::int32_t; auto max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t; std::uint32_t ip_version_; @@ -98,7 +99,6 @@ class core_info_intel : public ::tinytc_core_info { std::int32_t num_threads_per_eu_; std::vector subgroup_sizes_; std::int32_t register_size_; - std::int32_t num_registers_per_thread_; tinytc_core_feature_flags_t core_features_; }; From 43297878fb43c6779356d1024b7bc383f07f7b77 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 2 Sep 2024 15:03:17 +0200 Subject: [PATCH 011/297] Change inst node to make operands iterable Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.hpp | 10 +- src/inst.cpp | 12 +- src/node/inst_node.cpp | 233 +++++++++++++++-------------- src/node/inst_node.hpp | 234 ++++++++++++++++++------------ src/parser/parser_impl.yy | 26 ++-- src/slice.hpp | 23 --- src/value.cpp | 9 +- src/visitor/dump_ir.cpp | 15 +- src/visitor/insert_barrier.cpp | 4 +- src/visitor/insert_barrier.hpp | 2 +- src/visitor/lifetime_analysis.cpp | 8 +- src/visitor/lifetime_analysis.hpp | 2 +- src/visitor/opencl_ast.cpp | 34 ++--- 13 files changed, 333 insertions(+), 279 deletions(-) delete mode 100644 src/slice.hpp diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 018bfef3..6fc3f516 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1465,7 +1465,9 @@ class region_builder { } /** - * @brief Build for-loop with functor f(region_builder&) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type * @param loop_var_ty Type of loop variable @@ -1482,7 +1484,9 @@ class region_builder { std::forward(f), name, loc); } /** - * @brief Build for-loop with functor f(region_builder&) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type * @param loop_var_ty Type of loop variable @@ -1501,7 +1505,7 @@ class region_builder { loop_var.name(name); } auto bb = region_builder{}; - f(bb); + f(bb, loop_var); add(::tinytc::make_for(std::move(loop_var), from, to, step, bb.get_product(), loc)); } /** diff --git a/src/inst.cpp b/src/inst.cpp index 2a8685eb..dedf2308 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -4,7 +4,6 @@ #include "error.hpp" #include "location.hpp" #include "node/inst_node.hpp" -#include "slice.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" @@ -348,13 +347,16 @@ tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto slice_vec = std::vector(); - slice_vec.reserve(slice_list_size); + auto offset_vec = std::vector(); + auto size_vec = std::vector(); + offset_vec.reserve(slice_list_size); + size_vec.reserve(slice_list_size); for (uint32_t i = 0; i < slice_list_size; ++i) { - slice_vec.emplace_back(value(offset_list[i], true), value(size_list[i], true)); + offset_vec.emplace_back(value(offset_list[i], true)); + size_vec.emplace_back(value(size_list[i], true)); } *instr = - std::make_unique(value(a, true), std::move(slice_vec), get_optional(loc)) + std::make_unique(value(a, true), offset_vec, size_vec, get_optional(loc)) .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 0fa368ea..fd101276 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -18,7 +18,7 @@ namespace tinytc { -scalar_data_type *get_scalar_type(location const &loc, value &v) { +scalar_data_type *get_scalar_type(location const &loc, value const &v) { auto m = dynamic_cast(v->ty().get()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_scalar); @@ -26,7 +26,7 @@ scalar_data_type *get_scalar_type(location const &loc, value &v) { return m; } -memref_data_type *get_memref_type(location const &loc, value &v) { +memref_data_type *get_memref_type(location const &loc, value const &v) { auto m = dynamic_cast(v->ty().get()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_memref); @@ -35,27 +35,27 @@ memref_data_type *get_memref_type(location const &loc, value &v) { } blas_a2_inst::blas_a2_inst(value alpha, value A, value beta, value B, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), beta_(std::move(beta)), B_(std::move(B)), + : standard_inst{std::move(alpha), std::move(A), std::move(beta), std::move(B)}, atomic_(atomic) {} blas_a3_inst::blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), B_(std::move(B)), beta_(std::move(beta)), - C_(std::move(C)), atomic_(atomic) {} + : standard_inst{std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C)}, + atomic_(atomic) {} loop_inst::loop_inst(value loop_var, value from, value to, region body, location const &lc) : loop_inst(std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), lc) {} -loop_inst::loop_inst(value loop_var, value from, value to, value step, region body, +loop_inst::loop_inst(value loop_var0, value from0, value to0, value step0, region body, location const &lc) - : loop_var_(std::move(loop_var)), from_(std::move(from)), to_(std::move(to)), - step_(std::move(step)), body_(std::move(body)) { + : standard_inst{std::move(loop_var0), std::move(from0), std::move(to0), std::move(step0)}, + body_(std::move(body)) { loc(lc); - auto lvt = get_scalar_type(loc(), loop_var_); - auto fromt = get_scalar_type(loc(), from_); - auto tot = get_scalar_type(loc(), to_); + auto lvt = get_scalar_type(loc(), loop_var()); + auto fromt = get_scalar_type(loc(), from()); + auto tot = get_scalar_type(loc(), to()); bool step_ok = true; - if (step_) { - auto stept = get_scalar_type(loc(), step_); + if (step()) { + auto stept = get_scalar_type(loc(), step()); step_ok = lvt->ty() == stept->ty(); } @@ -74,12 +74,12 @@ alloca_inst::alloca_inst(data_type ty, location const &lc) memref->addrspace(clir::address_space::local_t); } -axpby_inst::axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, +axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { + : super(std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); bool shape_equal = false; if (tA_ == transpose::T && a->dim() == 2 && b->dim() == 2) { @@ -97,18 +97,18 @@ axpby_inst::axpby_inst(transpose tA, value alpha, value A, value beta, value B, } } -arith_inst::arith_inst(arithmetic op, value a, value b, location const &lc) - : op_(op), a_(std::move(a)), b_(std::move(b)) { +arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const &lc) + : super{std::move(a0), std::move(b0)}, operation_(operation) { loc(lc); - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); + auto at = get_scalar_type(loc(), a()); + auto bt = get_scalar_type(loc(), b()); if (at->ty() != bt->ty()) { throw compilation_error(loc(), status::ir_scalar_mismatch); } bool inst_supports_fp = false; - switch (op) { + switch (operation) { case arithmetic::add: case arithmetic::sub: case arithmetic::mul: @@ -130,13 +130,13 @@ arith_inst::arith_inst(arithmetic op, value a, value b, location const &lc) result_ = make_value(at->ty()); } -arith_unary_inst::arith_unary_inst(arithmetic_unary op, value a, location const &lc) - : op_(op), a_(std::move(a)) { +arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, location const &lc) + : super{std::move(a0)}, operation_(operation) { loc(lc); - auto at = get_scalar_type(loc(), a_); + auto at = get_scalar_type(loc(), a()); bool inst_supports_fp = false; - switch (op) { + switch (operation) { case arithmetic_unary::neg: inst_supports_fp = true; break; @@ -151,30 +151,31 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary op, value a, location const } cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) - : a_(std::move(a)), result_{make_value(to_ty)} { + : super{std::move(a)}, result_{make_value(to_ty)} { loc(lc); } -compare_inst::compare_inst(cmp_condition cond, value a, value b, location const &lc) - : cond_(cond), a_(std::move(a)), b_(std::move(b)), result_{make_value(scalar_type::i1)} { +compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location const &lc) + : super{std::move(a0), std::move(b0)}, cond_(cond), result_{make_value(scalar_type::i1)} { loc(lc); - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); + auto at = get_scalar_type(loc(), a()); + auto bt = get_scalar_type(loc(), b()); if (at->ty() != bt->ty()) { throw compilation_error(loc(), status::ir_scalar_mismatch); } } -gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), +gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B0, value beta0, + value C0, bool atomic, location const &lc) + : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), + atomic), tA_(tA), tB_(tB) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { throw compilation_error(loc(), status::ir_expected_vector_or_matrix, @@ -196,14 +197,15 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, } } -gemv_inst::gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), +gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0, value C0, + bool atomic, location const &lc) + : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), + atomic), tA_(tA) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { throw compilation_error(loc(), status::ir_expected_vector_or_matrix, @@ -223,13 +225,14 @@ gemv_inst::gemv_inst(transpose tA, value alpha, value A, value B, value beta, va } } -ger_inst::ger_inst(value alpha, value A, value B, value beta, value C, bool atomic, +ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { + : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), + atomic) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { throw compilation_error(loc(), status::ir_expected_vector_or_matrix, @@ -248,13 +251,14 @@ ger_inst::ger_inst(value alpha, value A, value B, value beta, value C, bool atom } } -hadamard_inst::hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic, +hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { + : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), + atomic) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { throw compilation_error(loc(), status::ir_expected_vector_or_matrix, @@ -272,25 +276,27 @@ hadamard_inst::hadamard_inst(value alpha, value A, value B, value beta, value C, } } -expand_inst::expand_inst(value op, std::int64_t mode, std::vector expand_shape, +expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const &expand_shape0, location const &lc) - : op_(std::move(op)), mode_(mode), expand_shape_(std::move(expand_shape)) { + : super{std::move(op0)}, mode_(mode) { loc(lc); - auto m = get_memref_type(loc(), op_); + ops().insert(ops().end(), expand_shape0.begin(), expand_shape0.end()); + + auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - if (expand_shape_.size() < 2) { + if (expand_shape().size() < 2) { throw compilation_error(loc(), status::ir_expand_shape_order_too_small); } auto known_expand_shape = std::vector(); - known_expand_shape.reserve(expand_shape_.size()); + known_expand_shape.reserve(expand_shape().size()); std::size_t dyn_count = 0, non_imm_count = 0; - for (auto &s : expand_shape_) { + for (auto &s : expand_shape()) { visit(overloaded{[&](int_imm &i) { if (is_dynamic_value(i.value())) { known_expand_shape.push_back(dynamic); @@ -328,7 +334,7 @@ expand_inst::expand_inst(value op, std::int64_t mode, std::vector expand_ if (dyn_mode >= 0) { std::int64_t const s = size / prod; known_expand_shape[dyn_mode] = s; - expand_shape_[dyn_mode] = make_imm(s); + expand_shape()[dyn_mode] = make_imm(s); prod *= s; } if (prod != size) { @@ -363,10 +369,10 @@ expand_inst::expand_inst(value op, std::int64_t mode, std::vector expand_ result_ = make_value(data_type(r.release())); } -fuse_inst::fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc) - : op_(std::move(op)), from_(from), to_(to) { +fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location const &lc) + : super{std::move(op0)}, from_(from), to_(to) { loc(lc); - auto m = get_memref_type(loc(), op_); + auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); @@ -402,37 +408,40 @@ fuse_inst::fuse_inst(value op, std::int64_t from, std::int64_t to, location cons if_inst::if_inst(value condition, region then, region otherwise, std::vector const &return_types, location const &lc) - : condition_(std::move(condition)), then_(std::move(then)), otherwise_(std::move(otherwise)) { + : super{std::move(condition)}, then_(std::move(then)), otherwise_(std::move(otherwise)) { loc(lc); for (auto &ty : return_types) { results_.push_back(make_value(ty)); } } -load_inst::load_inst(value op, std::vector index_list, location const &lc) - : op_(std::move(op)), index_list_(std::move(index_list)) { +load_inst::load_inst(value op0, std::vector const &index_list0, location const &lc) + : super{std::move(op0)} { loc(lc); + + ops().insert(ops().end(), index_list0.begin(), index_list0.end()); + visit(overloaded{ [&](group_data_type &g) { - if (static_cast(index_list_.size()) != 1) { + if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } result_ = make_value(g.ty()); }, [&](memref_data_type &m) { - if (m.dim() != static_cast(index_list_.size())) { + if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } result_ = make_value(m.element_ty()); }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, - *op_->ty()); + *operand()->ty()); } -size_inst::size_inst(value op, std::int64_t mode, location const &lc) - : op_(std::move(op)), mode_(mode) { +size_inst::size_inst(value op0, std::int64_t mode, location const &lc) + : super{std::move(op0)}, mode_(mode) { loc(lc); - auto m = get_memref_type(loc(), op_); + auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); @@ -441,52 +450,60 @@ size_inst::size_inst(value op, std::int64_t mode, location const &lc) result_ = make_value(scalar_type::index); } -subview_inst::subview_inst(value op, std::vector slices, location const &lc) - : op_(std::move(op)), slices_(std::move(slices)) { +subview_inst::subview_inst(value op0, std::vector const &offset_list0, + std::vector const &size_list0, location const &lc) + : super{std::move(op0)} { + loc(lc); - auto m = get_memref_type(loc(), op_); - if (m->dim() != static_cast(slices_.size())) { + + auto m = get_memref_type(loc(), operand()); + if (m->dim() != static_cast(offset_list0.size()) || + m->dim() != static_cast(size_list0.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } + ops().insert(ops().end(), offset_list0.begin(), offset_list0.end()); + ops().insert(ops().end(), size_list0.begin(), size_list0.end()); auto shape = std::vector{}; auto stride = std::vector{}; shape.reserve(m->dim()); stride.reserve(m->dim()); for (std::int64_t i = 0; i < m->dim(); ++i) { - auto &slice = slices_[i]; + auto &offset = offset_list()[i]; + auto &size = size_list()[i]; visit(overloaded{[&](int_imm &i) { if (i.value() < 0) { throw compilation_error(loc(), status::ir_invalid_slice); } }, [](auto &) {}}, - *slice.first); - if (slice.second) { // if size is given + *offset); + if (size) { // if size is given visit(overloaded{[&](int_imm &i) { if (i.value() < 1 && !is_dynamic_value(i.value())) { throw compilation_error(loc(), status::ir_invalid_slice); } }, [](auto &) {}}, - *slice.second); - auto size = visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return is_dynamic_value(m->shape(i)) - ? dynamic - : m->shape(i) - offset.value(); - } - return size.value(); - }, - [&](val &, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return dynamic; - } - return size.value(); - }, - [](auto &, auto &) -> std::int64_t { return dynamic; }}, - *slice.first, *slice.second); - shape.push_back(size); + *size); + auto size_value = + visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { + if (is_dynamic_value(size.value())) { + return is_dynamic_value(m->shape(i)) + ? dynamic + : m->shape(i) - offset.value(); + } + return size.value(); + }, + [&](val &, int_imm &size) -> std::int64_t { + if (is_dynamic_value(size.value())) { + return dynamic; + } + return size.value(); + }, + [](auto &, auto &) -> std::int64_t { return dynamic; }}, + *offset, *size); + shape.push_back(size_value); stride.push_back(m->stride(i)); } } @@ -496,27 +513,31 @@ subview_inst::subview_inst(value op, std::vector slices, location const & result_ = make_value(data_type(r.release())); } -store_inst::store_inst(value val, value op, std::vector index_list, location const &lc) - : val_(std::move(val)), op_(std::move(op)), index_list_(std::move(index_list)) { +store_inst::store_inst(value val0, value op0, std::vector const &index_list0, + location const &lc) + : super{std::move(val0), std::move(op0)} { loc(lc); - auto v = get_scalar_type(loc(), val_); - auto o = get_memref_type(loc(), op_); + + ops().insert(ops().end(), index_list0.begin(), index_list0.end()); + + auto v = get_scalar_type(loc(), val()); + auto o = get_memref_type(loc(), operand()); if (v->ty() != o->element_ty()) { throw compilation_error(loc(), status::ir_scalar_mismatch); } - if (o->dim() != static_cast(index_list_.size())) { + if (o->dim() != static_cast(index_list0.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } } -sum_inst::sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, +sum_inst::sum_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { + : super(std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); bool const size_ok = (a->dim() == 2 && b->dim() == 1) || (a->dim() == 1 && b->dim() == 0); if (!size_ok) { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 6822b93e..e3213573 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -5,14 +5,15 @@ #define INST_NODE_20230327_HPP #include "reference_counted.hpp" -#include "slice.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include +#include #include #include +#include #include #include @@ -42,6 +43,14 @@ struct tinytc_inst : tinytc::reference_counted, tinytc::inst_nodes { inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + // Iterator over operands + virtual auto begin() -> tinytc::value * = 0; + virtual auto end() -> tinytc::value * = 0; + virtual auto cbegin() const -> tinytc::value const * = 0; + virtual auto cend() const -> tinytc::value const * = 0; + inline auto begin() const -> tinytc::value const * { return cbegin(); } + inline auto end() const -> tinytc::value const * { return cend(); } + virtual tinytc::value result() const = 0; inline virtual auto results() const -> std::vector { if (auto r = result(); r) { @@ -60,63 +69,92 @@ namespace tinytc { using inst_node = ::tinytc_inst; -class scalar_inst : public inst_node {}; +template class standard_inst : public inst_node { + public: + template inline standard_inst(Ts &&...ts) : ops_{std::forward(ts)...} {} + + inline auto begin() -> tinytc::value * override { return ops_.data(); } + inline auto end() -> tinytc::value * override { return ops_.data() + ops_.size(); } + inline auto cbegin() const -> tinytc::value const * override { return ops_.data(); } + inline auto cend() const -> tinytc::value const * override { return ops_.data() + ops_.size(); } + + inline auto op(std::size_t pos) -> value & { return ops_[pos]; } + inline auto op(std::size_t pos) const -> value const & { return ops_[pos]; } + + private: + std::array ops_; +}; +class standard_variadic_inst : public inst_node { + public: + template + inline standard_variadic_inst(Ts &&...ts) : ops_{std::forward(ts)...} {} + + inline auto begin() -> tinytc::value * override { return ops_.data(); } + inline auto end() -> tinytc::value * override { return ops_.data() + ops_.size(); } + inline auto cbegin() const -> tinytc::value const * override { return ops_.data(); } + inline auto cend() const -> tinytc::value const * override { return ops_.data() + ops_.size(); } + + inline auto op(std::size_t pos) -> value & { return ops_[pos]; } + inline auto op(std::size_t pos) const -> value const & { return ops_[pos]; } + inline auto ops() -> std::vector & { return ops_; } + inline auto ops() const -> std::vector const & { return ops_; } + + private: + std::vector ops_; +}; -class blas_a2_inst : public inst_node { +class blas_a2_inst : public standard_inst<4u> { public: blas_a2_inst(value alpha, value A, value beta, value B, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto beta() const -> value const & { return beta_; } - inline auto B() const -> value const & { return B_; } + inline auto alpha() const -> value const & { return op(0); } + inline auto A() const -> value const & { return op(1); } + inline auto beta() const -> value const & { return op(2); } + inline auto B() const -> value const & { return op(3); } inline value result() const override { return value{}; } inline inst_kind kind() const override { return inst_kind::collective; } protected: - value alpha_, A_, beta_, B_; bool atomic_; }; -class blas_a3_inst : public inst_node { +class blas_a3_inst : public standard_inst<5u> { public: blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto B() const -> value const & { return B_; } - inline auto beta() const -> value const & { return beta_; } - inline auto C() const -> value const & { return C_; } + inline auto alpha() const -> value const & { return op(0); } + inline auto A() const -> value const & { return op(1); } + inline auto B() const -> value const & { return op(2); } + inline auto beta() const -> value const & { return op(3); } + inline auto C() const -> value const & { return op(4); } inline value result() const override { return value{}; } inline inst_kind kind() const override { return inst_kind::collective; } protected: - value alpha_, A_, B_, beta_, C_; bool atomic_; }; -class loop_inst : public inst_node { +class loop_inst : public standard_inst<4u> { public: loop_inst(value loop_var, value from, value to, region body, location const &loc = {}); loop_inst(value loop_var, value from, value to, value step, region body, location const &loc = {}); - inline auto loop_var() const -> value const & { return loop_var_; } - inline auto from() const -> value const & { return from_; } - inline auto to() const -> value const & { return to_; } - inline auto step() const -> value const & { return step_; } + inline auto loop_var() const -> value const & { return op(0); } + inline auto from() const -> value const & { return op(1); } + inline auto to() const -> value const & { return op(2); } + inline auto step() const -> value const & { return op(3); } inline auto body() const -> region const & { return body_; } inline value result() const override { return value{}; } private: - value loop_var_, from_, to_, step_; region body_; }; -class alloca_inst : public clir::visitable { +class alloca_inst : public clir::visitable> { public: alloca_inst(data_type ty, location const &loc = {}); @@ -142,116 +180,121 @@ class axpby_inst : public clir::visitable { transpose tA_; }; -class arith_inst : public clir::visitable { +class arith_inst : public clir::visitable> { public: + using super = clir::visitable>; arith_inst(arithmetic op, value a, value b, location const &lc = {}); - inline arithmetic op() const { return op_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } + inline arithmetic operation() const { return operation_; } + inline auto a() const -> value const & { return op(0); } + inline auto b() const -> value const & { return op(1); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - arithmetic op_; - value a_, b_, result_; + arithmetic operation_; + value result_; }; -class arith_unary_inst : public clir::visitable { +class arith_unary_inst : public clir::visitable> { public: + using super = clir::visitable>; arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); - inline arithmetic_unary op() const { return op_; } - inline auto a() const -> value const & { return a_; } + inline arithmetic_unary operation() const { return operation_; } + inline auto a() const -> value const & { return op(0); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - arithmetic_unary op_; - value a_, result_; + arithmetic_unary operation_; + value result_; }; -class barrier_inst : public clir::visitable { +class barrier_inst : public clir::visitable> { public: inline value result() const override { return value{}; } inline inst_kind kind() const override { return inst_kind::collective; } }; -class cast_inst : public clir::visitable { +class cast_inst : public clir::visitable> { public: + using super = clir::visitable>; cast_inst(value a, scalar_type to_ty, location const &lc = {}); - inline auto a() const -> value const & { return a_; } + inline auto a() const -> value const & { return op(0); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value a_, result_; + value result_; }; -class compare_inst : public clir::visitable { +class compare_inst : public clir::visitable> { public: + using super = clir::visitable>; compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } + inline auto a() const -> value const & { return op(0); } + inline auto b() const -> value const & { return op(1); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: cmp_condition cond_; - value a_, b_, result_; + value result_; }; -class expand_inst : public clir::visitable { +class expand_inst : public clir::visitable { public: - expand_inst(value op, std::int64_t mode, std::vector expand_shape, + using super = clir::visitable; + expand_inst(value op, std::int64_t mode, std::vector const &expand_shape, location const &lc = {}); - inline auto operand() const -> value const & { return op_; } + inline auto operand() const -> value const & { return op(0); } inline std::int64_t mode() const { return mode_; } - inline auto expand_shape() const -> std::vector const & { return expand_shape_; } - inline auto expand_shape(std::int64_t i) const -> value const & { return expand_shape_[i]; } + inline auto expand_shape() { return ops() | std::views::drop(1); } + inline auto expand_shape() const { return ops() | std::views::drop(1); } + inline auto expand_shape(std::int64_t i) const -> value const & { return op(i + 1); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value op_, result_; + value result_; std::int64_t mode_; - std::vector expand_shape_; }; -class fuse_inst : public clir::visitable { +class fuse_inst : public clir::visitable> { public: + using super = clir::visitable>; fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); - inline auto operand() const -> value const & { return op_; } + inline auto operand() const -> value const & { return op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value op_, result_; + value result_; std::int64_t from_, to_; }; -class load_inst : public clir::visitable { +class load_inst : public clir::visitable { public: - load_inst(value op, std::vector index_list, location const &lc = {}); + using super = clir::visitable; + load_inst(value op, std::vector const &index_list, location const &lc = {}); - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } + inline auto operand() const -> value const & { return op(0); } + inline auto index_list() const { return ops() | std::views::drop(1); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value op_; - std::vector index_list_; value result_; }; -class group_id_inst : public clir::visitable { +class group_id_inst : public clir::visitable> { public: inline group_id_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { loc(lc); @@ -263,7 +306,7 @@ class group_id_inst : public clir::visitable { value result_; }; -class group_size_inst : public clir::visitable { +class group_size_inst : public clir::visitable> { public: inline group_size_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { loc(lc); @@ -275,15 +318,13 @@ class group_size_inst : public clir::visitable { value result_; }; -class lifetime_stop_inst : public clir::visitable { +class lifetime_stop_inst : public clir::visitable> { public: - inline lifetime_stop_inst(value obj) : obj_(std::move(obj)) {} - inline auto object() const -> value const & { return obj_; } + using super = clir::visitable>; + inline lifetime_stop_inst(value obj) : super{std::move(obj)} {} + inline auto object() const -> value const & { return op(0); } inline value result() const override { return value{}; } inline inst_kind kind() const override { return inst_kind::collective; } - - private: - value obj_; }; class gemm_inst : public clir::visitable { @@ -340,11 +381,12 @@ class hadamard_inst : public clir::visitable { location const &lc = {}); }; -class if_inst : public clir::visitable { +class if_inst : public clir::visitable> { public: + using super = clir::visitable>; if_inst(value condition, region then, region otherwise = {}, std::vector const &return_types = {}, location const &lc = {}); - inline auto condition() const -> value const & { return condition_; } + inline auto condition() const -> value const & { return op(0); } inline auto then() const -> region const & { return then_; } inline auto otherwise() const -> region const & { return otherwise_; } inline value result() const override { @@ -357,12 +399,11 @@ class if_inst : public clir::visitable { inline inst_kind kind() const override { return inst_kind::mixed; } private: - value condition_; region then_, otherwise_; std::vector results_; }; -class num_subgroups_inst : public clir::visitable { +class num_subgroups_inst : public clir::visitable> { public: inline num_subgroups_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { loc(lc); @@ -374,7 +415,7 @@ class num_subgroups_inst : public clir::visitable { +class parallel_inst : public clir::visitable> { public: using super = clir::visitable; inline parallel_inst(region body, location const &lc = {}) : body_(std::move(body)) { loc(lc); } @@ -386,21 +427,22 @@ class parallel_inst : public clir::visitable { region body_; }; -class size_inst : public clir::visitable { +class size_inst : public clir::visitable> { public: + using super = clir::visitable>; size_inst(value op, std::int64_t mode, location const &lc = {}); - inline auto operand() const -> value const & { return op_; } + inline auto operand() const -> value const & { return op(0); } inline std::int64_t mode() const { return mode_; } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value op_, result_; + value result_; std::int64_t mode_; }; -class subgroup_id_inst : public clir::visitable { +class subgroup_id_inst : public clir::visitable> { public: inline subgroup_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { loc(lc); @@ -412,7 +454,7 @@ class subgroup_id_inst : public clir::visitable { value result_; }; -class subgroup_local_id_inst : public clir::visitable { +class subgroup_local_id_inst : public clir::visitable> { public: inline subgroup_local_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { loc(lc); @@ -424,7 +466,7 @@ class subgroup_local_id_inst : public clir::visitable { +class subgroup_size_inst : public clir::visitable> { public: inline subgroup_size_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { loc(lc); @@ -436,34 +478,36 @@ class subgroup_size_inst : public clir::visitable { +class subview_inst : public clir::visitable { public: - subview_inst(value op, std::vector slices, location const &lc = {}); + using super = clir::visitable; + subview_inst(value op, std::vector const &offset_list, + std::vector const &size_list, location const &lc = {}); - inline auto slices() const -> std::vector const & { return slices_; } - inline auto operand() const -> value const & { return op_; } + inline auto operand() const -> value const & { return op(0); } + // We have ops().size() = 1 + 2 * num_indices() + inline auto num_indices() const { return (ops().size() - 1) / 2; } + inline auto offset_list() const { + return ops() | std::views::drop(1) | std::views::take(num_indices()); + } + inline auto size_list() const { return ops() | std::views::drop(1 + num_indices()); } inline value result() const override { return result_; } inline inst_kind kind() const override { return inst_kind::mixed; } private: - value op_; - std::vector slices_; value result_; }; -class store_inst : public clir::visitable { +class store_inst : public clir::visitable { public: - store_inst(value val, value op, std::vector index_list, location const &lc = {}); + using super = clir::visitable; + store_inst(value val, value op, std::vector const &index_list, location const &lc = {}); - inline auto val() const -> value const & { return val_; } - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } + inline auto val() const -> value const & { return op(0); } + inline auto operand() const -> value const & { return op(1); } + inline auto index_list() const { return ops() | std::views::drop(2); } inline value result() const override { return {}; } inline inst_kind kind() const override { return inst_kind::mixed; } - - private: - value val_, op_; - std::vector index_list_; }; class sum_inst : public clir::visitable { @@ -478,13 +522,15 @@ class sum_inst : public clir::visitable { transpose tA_; }; -class yield_inst : public clir::visitable { +class yield_inst : public clir::visitable { public: - inline yield_inst(std::vector vals, location const &lc = {}) : vals_(std::move(vals)) { + using super = clir::visitable; + inline yield_inst(std::vector const &vals, location const &lc = {}) { loc(lc); + ops().insert(ops().end(), vals.begin(), vals.end()); } inline value result() const override { return value{}; } - inline auto vals() const -> std::vector const & { return vals_; } + inline auto vals() const -> std::vector const & { return ops(); } inline inst_kind kind() const override { return inst_kind::mixed; } private: diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 8af96bf8..07b01187 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -6,7 +6,6 @@ %code requires { #include "node/function_node.hpp" - #include "slice.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include @@ -215,9 +214,9 @@ %nterm subgroup_size_inst %nterm store_inst %nterm subview_inst -%nterm > optional_slice_list -%nterm > slice_list -%nterm slice +%nterm , std::vector<::tinytc::value>>> optional_slice_list +%nterm , std::vector<::tinytc::value>>> slice_list +%nterm > slice %nterm <::tinytc::value> slice_size %% @@ -972,8 +971,8 @@ subview_inst: } try { $$ = inst { - std::make_unique(std::move($var), std::move($optional_slice_list), - @subview_inst) + std::make_unique(std::move($var), $optional_slice_list.first, + $optional_slice_list.second, @subview_inst) .release() }; } catch (compilation_error const &e) { @@ -989,13 +988,20 @@ optional_slice_list: ; slice_list: - slice { $$.push_back($slice); } - | slice_list COMMA slice { $$ = std::move($1); $$.push_back($slice); } + slice { + $$.first.emplace_back(std::move($slice.first)); + $$.second.emplace_back(std::move($slice.second)); + } + | slice_list COMMA slice { + $$ = std::move($1); + $$.first.emplace_back(std::move($slice.first)); + $$.second.emplace_back(std::move($slice.second)); + } ; slice: - COLON { $$ = slice(make_index(0), make_dynamic()); } - | index_identifier_or_const slice_size { $$ = slice(std::move($1), std::move($2)); } + COLON { $$ = std::make_pair(make_index(0), make_dynamic()); } + | index_identifier_or_const slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } ; slice_size: diff --git a/src/slice.hpp b/src/slice.hpp deleted file mode 100644 index d545242e..00000000 --- a/src/slice.hpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SLICE_20240412_HPP -#define SLICE_20240412_HPP - -#include "tinytc/tinytc.hpp" - -#include - -namespace tinytc { - -//! Slice storing offset:size -class slice : public std::pair { - public: - //! ctor - inline slice(value offset = {}, value size = {}) - : std::pair{std::move(offset), std::move(size)} {} -}; - -} // namespace tinytc - -#endif // SLICE_20240412_HPP diff --git a/src/value.cpp b/src/value.cpp index ac29c800..ac96580d 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "error.hpp" +#include "location.hpp" #include "node/value_node.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" @@ -37,12 +38,8 @@ tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, if (vl == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *vl = std::make_unique(data_type(type, true)).release(); - if (lc) { - (*vl)->loc(*lc); - } - }); + return exception_to_status_code( + [&] { *vl = std::make_unique(data_type(type, true), get_optional(lc)).release(); }); } tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, tinytc_scalar_type_t type, diff --git a/src/visitor/dump_ir.cpp b/src/visitor/dump_ir.cpp index b462ff8c..46ebff7d 100644 --- a/src/visitor/dump_ir.cpp +++ b/src/visitor/dump_ir.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "visitor/dump_ir.hpp" -#include "slice.hpp" #include "tinytc/tinytc.hpp" #include @@ -122,7 +121,7 @@ void ir_dumper::operator()(axpby_inst const &a) { void ir_dumper::operator()(arith_inst const &a) { visit(*this, *a.result()); - os_ << " = arith." << to_string(a.op()) << " "; + os_ << " = arith." << to_string(a.operation()) << " "; visit(*this, *a.a()); os_ << ", "; visit(*this, *a.b()); @@ -132,7 +131,7 @@ void ir_dumper::operator()(arith_inst const &a) { void ir_dumper::operator()(arith_unary_inst const &a) { visit(*this, *a.result()); - os_ << " = arith." << to_string(a.op()) << " "; + os_ << " = arith." << to_string(a.operation()) << " "; visit(*this, *a.a()); os_ << " : "; visit(*this, *a.a()->ty()); @@ -310,11 +309,13 @@ void ir_dumper::operator()(subview_inst const &s) { os_ << " = subview "; visit(*this, *s.operand()); os_ << "["; - do_with_infix(s.slices().begin(), s.slices().end(), [this](auto const &i) { - visit(*this, *i.first); - if (i.second) { + auto irange = std::ranges::iota_view{std::size_t{0}, s.offset_list().size()}; + do_with_infix(irange.begin(), irange.end(), [&](auto const &i) { + visit(*this, *s.offset_list()[i]); + auto &size = s.size_list()[i]; + if (size) { os_ << ":"; - visit(*this, *i.second); + visit(*this, *size); } }); os_ << "]"; diff --git a/src/visitor/insert_barrier.cpp b/src/visitor/insert_barrier.cpp index 4cc21435..d7d8c9e5 100644 --- a/src/visitor/insert_barrier.cpp +++ b/src/visitor/insert_barrier.cpp @@ -35,6 +35,8 @@ value_node *insert_barrier::operator()(val &v) { } /* Inst nodes */ +std::unordered_set insert_barrier::operator()(inst_node &) { return {}; } + std::unordered_set insert_barrier::operator()(blas_a2_inst &g) { auto rw = std::unordered_set{}; rw.emplace(visit(*this, *g.A())); @@ -54,8 +56,6 @@ std::unordered_set insert_barrier::operator()(loop_inst &p) { return visit(*this, *p.body()); } -std::unordered_set insert_barrier::operator()(scalar_inst &) { return {}; } - std::unordered_set insert_barrier::operator()(alloca_inst &) { return {}; } std::unordered_set insert_barrier::operator()(barrier_inst &) { diff --git a/src/visitor/insert_barrier.hpp b/src/visitor/insert_barrier.hpp index 0fe15094..d7524b84 100644 --- a/src/visitor/insert_barrier.hpp +++ b/src/visitor/insert_barrier.hpp @@ -30,10 +30,10 @@ class insert_barrier { value_node *operator()(val &v); /* Stmt nodes */ + std::unordered_set operator()(inst_node &inst); std::unordered_set operator()(blas_a2_inst &inst); std::unordered_set operator()(blas_a3_inst &inst); std::unordered_set operator()(loop_inst &p); - std::unordered_set operator()(scalar_inst &inst); std::unordered_set operator()(alloca_inst &a); std::unordered_set operator()(barrier_inst &b); std::unordered_set operator()(expand_inst &e); diff --git a/src/visitor/lifetime_analysis.cpp b/src/visitor/lifetime_analysis.cpp index fa11a1b2..24e9c376 100644 --- a/src/visitor/lifetime_analysis.cpp +++ b/src/visitor/lifetime_analysis.cpp @@ -47,6 +47,10 @@ value find_alloca::operator()(rgn &b) { std::vector find_alloca::allocas() const { return alloca_; } /* Inst nodes */ +auto lifetime_inserter::operator()(inst_node &) -> std::unordered_set { + return {}; +} + auto lifetime_inserter::operator()(blas_a2_inst &a) -> std::unordered_set { return {a.A().get(), a.B().get()}; } @@ -59,10 +63,6 @@ auto lifetime_inserter::operator()(loop_inst &p) -> std::unordered_set std::unordered_set { - return {}; -} - auto lifetime_inserter::operator()(alloca_inst &a) -> std::unordered_set { return {a.result().get()}; } diff --git a/src/visitor/lifetime_analysis.hpp b/src/visitor/lifetime_analysis.hpp index 2137eaab..1cbb26a9 100644 --- a/src/visitor/lifetime_analysis.hpp +++ b/src/visitor/lifetime_analysis.hpp @@ -40,10 +40,10 @@ class find_alloca { class lifetime_inserter { public: /* Inst nodes */ + auto operator()(inst_node &inst) -> std::unordered_set<::tinytc_value const *>; auto operator()(blas_a2_inst &inst) -> std::unordered_set<::tinytc_value const *>; auto operator()(blas_a3_inst &inst) -> std::unordered_set<::tinytc_value const *>; auto operator()(loop_inst &p) -> std::unordered_set<::tinytc_value const *>; - auto operator()(scalar_inst &inst) -> std::unordered_set<::tinytc_value const *>; auto operator()(alloca_inst &a) -> std::unordered_set<::tinytc_value const *>; auto operator()(barrier_inst &b) -> std::unordered_set<::tinytc_value const *>; auto operator()(expand_inst &e) -> std::unordered_set<::tinytc_value const *>; diff --git a/src/visitor/opencl_ast.cpp b/src/visitor/opencl_ast.cpp index 9a347ec2..b0dd4f80 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/visitor/opencl_ast.cpp @@ -6,7 +6,6 @@ #include "error.hpp" #include "gemm_generator.hpp" #include "scalar_type.hpp" -#include "slice.hpp" #include "tinytc/tinytc.hpp" #include "util.hpp" @@ -222,7 +221,7 @@ std::vector opencl_ast::operator()(alloca_inst const &a) { stack_high_water_mark_ = std::max(stack_high_water_mark_, static_cast(a.stack_ptr()) + t->size_in_bytes()); - // no declarations are neceesary as alloca only accepts fixed-size memrefs + // no declarations are necessary as alloca only accepts fixed-size memrefs set_dope_vector(a.result().get(), dope_vector::from_value(*a.result(), [](clir::data_type, clir::var, dope_vector::type, std::int64_t) {})); @@ -350,8 +349,9 @@ std::vector opencl_ast::operator()(arith_inst const &a) { }; auto sty = get_scalar_type(*a.a()->ty()); auto v = declare(*a.result()); - return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.op(), visit(*this, *a.a()), visit(*this, *a.b()), sty))}; + return {declaration_assignment( + visit(*this, *a.result()->ty()), std::move(v), + make(a.operation(), visit(*this, *a.a()), visit(*this, *a.b()), sty))}; } std::vector opencl_ast::operator()(arith_unary_inst const &a) { @@ -370,7 +370,7 @@ std::vector opencl_ast::operator()(arith_unary_inst const &a) { auto sty = get_scalar_type(*a.a()->ty()); auto v = declare(*a.result()); return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.op(), visit(*this, *a.a()), sty))}; + make(a.operation(), visit(*this, *a.a()), sty))}; } std::vector opencl_ast::operator()(cast_inst const &c) { @@ -407,7 +407,7 @@ std::vector opencl_ast::operator()(expand_inst const &e) { auto result_var = declare(*e.result()); auto m = get_memref_type(*e.operand()); auto &dv = get_dope_vector(e.operand().get()); - auto &eshape = e.expand_shape(); + auto eshape = e.expand_shape(); auto rhs = visit(*this, *e.operand()); auto clinst = std::vector{}; @@ -899,7 +899,7 @@ std::vector opencl_ast::operator()(subgroup_size_inst const &sg) { std::vector opencl_ast::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); auto t = get_memref_type(*s.operand()); - if (t->dim() != static_cast(s.slices().size())) { + if (t->dim() != static_cast(s.num_indices())) { throw compilation_error(s.loc(), status::ir_invalid_number_of_indices); } @@ -911,23 +911,23 @@ std::vector opencl_ast::operator()(subview_inst const &s) { auto stride_out = std::vector{}; shape_out.reserve(t->dim()); stride_out.reserve(t->dim()); - auto &slices = s.slices(); - for (auto &slice : slices) { - auto offset = visit(*this, *slice.first); - rhs = rhs + std::move(offset) * dv.stride(j); - if (slice.second) { + for (std::int64_t i = 0; i < t->dim(); ++i) { + auto &offset = s.offset_list()[i]; + auto &size = s.size_list()[i]; + rhs = rhs + visit(*this, *offset) * dv.stride(j); + if (size) { bool is_size_unknown = visit(overloaded{[&](int_imm const &size) -> bool { return is_dynamic_value(size.value()); }, [](auto const &) -> bool { return false; }}, - *slice.second); - auto size = clir::expr{}; + *size); + auto size_value = clir::expr{}; if (is_size_unknown) { - size = dv.shape(j) - visit(*this, *slice.first); + size_value = dv.shape(j) - visit(*this, *offset); } else { - size = visit(*this, *slice.second); + size_value = visit(*this, *size); } - shape_out.emplace_back(size); + shape_out.emplace_back(size_value); stride_out.emplace_back(dv.stride(j)); } ++j; From c268de9688ca16bd3f48270d8788608f4b9115b1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 6 Sep 2024 16:19:39 +0200 Subject: [PATCH 012/297] Refactor rtti Signed-off-by: Carsten Uphoff --- examples/benchmark/main.cpp | 2 +- src/binary.cpp | 2 +- src/data_type.cpp | 2 +- src/func.cpp | 5 +- src/inst.cpp | 2 +- src/node/data_type_node.cpp | 3 +- src/node/data_type_node.hpp | 37 ++- src/node/function_node.hpp | 27 +- src/node/inst_node.cpp | 358 ++++++++++++---------- src/node/inst_node.hpp | 477 +++++++++++++++++++----------- src/node/program_node.hpp | 21 +- src/node/region_node.hpp | 21 +- src/node/value_node.hpp | 40 ++- src/parser/parser_impl.yy | 7 +- src/recipe/small_gemm_batched.cpp | 2 +- src/recipe/tall_and_skinny.cpp | 2 +- src/support/casting.hpp | 47 +++ src/support/type_list.hpp | 42 +++ src/{ => support}/util.hpp | 11 + src/support/visit.hpp | 103 +++++++ src/value.cpp | 2 +- src/visitor/alias_analysis.cpp | 8 +- src/visitor/check_ir.cpp | 9 +- src/visitor/dump_ir.cpp | 10 +- src/visitor/equal.cpp | 5 +- src/visitor/insert_barrier.cpp | 7 +- src/visitor/lifetime_analysis.cpp | 5 +- src/visitor/metadata.cpp | 5 +- src/visitor/opencl_ast.cpp | 30 +- src/visitor/slot_tracker.cpp | 5 +- src/visitor/stack.cpp | 8 +- src/visitor/work_group_size.cpp | 8 +- src/ze/kernel.cpp | 2 +- 33 files changed, 864 insertions(+), 451 deletions(-) create mode 100644 src/support/casting.hpp create mode 100644 src/support/type_list.hpp rename src/{ => support}/util.hpp (61%) create mode 100644 src/support/visit.hpp diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index e414c3d6..2e86dfe8 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -77,7 +77,7 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t auto c = bb.add(make_load(C, {gid}, my_loc())); bb.for_loop( scalar_type::index, make_index(0, my_loc()), make_index(repetitions, my_loc()), - [&](region_builder &bb) { + [&](region_builder &bb, value const &) { bb.add(make_gemm(tA, tB, atomic, make_imm(1.0, ty, my_loc()), a, b, make_imm(beta, ty, my_loc()), c, my_loc())); }, diff --git a/src/binary.cpp b/src/binary.cpp index 360fc0ab..df09ca92 100644 --- a/src/binary.cpp +++ b/src/binary.cpp @@ -3,9 +3,9 @@ #include "binary.hpp" #include "error.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" -#include "util.hpp" #include #include diff --git a/src/data_type.cpp b/src/data_type.cpp index cb9827b3..96e0a281 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -8,7 +8,7 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" +#include "support/util.hpp" #include #include diff --git a/src/func.cpp b/src/func.cpp index 53d5fcfa..5cb2ea46 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -4,6 +4,7 @@ #include "error.hpp" #include "location.hpp" #include "node/function_node.hpp" +#include "support/casting.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" @@ -48,7 +49,7 @@ tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototy } tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { - function *f = dynamic_cast(fun); + function *f = dyn_cast(fun); if (f == nullptr) { return tinytc_status_invalid_arguments; } @@ -56,7 +57,7 @@ tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x } tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { - function *f = dynamic_cast(fun); + function *f = dyn_cast(fun); if (f == nullptr) { return tinytc_status_invalid_arguments; } diff --git a/src/inst.cpp b/src/inst.cpp index dedf2308..220c06d9 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -4,11 +4,11 @@ #include "error.hpp" #include "location.hpp" #include "node/inst_node.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" #include #include diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index df570f75..4fefc97d 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -12,7 +12,8 @@ namespace tinytc { memref_data_type::memref_data_type(scalar_type type, std::vector shape, std::vector stride, location const &lc) - : element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)) { + : data_type_node(DTK_memref), element_ty_(std::move(type)), shape_(std::move(shape)), + stride_(std::move(stride)) { loc(lc); for (auto const &s : shape_) { if (s < 0 && !is_dynamic_value(s)) { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 8321554c..cc0929e2 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -6,12 +6,12 @@ #include "reference_counted.hpp" #include "scalar_type.hpp" +#include "support/type_list.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include #include -#include #include #include @@ -19,16 +19,23 @@ #include namespace tinytc { -using data_type_nodes = clir::virtual_type_list; +using data_type_nodes = type_list; } -struct tinytc_data_type : tinytc::reference_counted, tinytc::data_type_nodes { +struct tinytc_data_type : tinytc::reference_counted { public: + enum data_type_kind { DTK_group, DTK_memref, DTK_scalar, DTK_void }; + using leaves = tinytc::data_type_nodes; + + inline tinytc_data_type(std::int64_t tid) : tid_(tid) {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: + std::int64_t tid_; tinytc::location loc_; }; @@ -36,10 +43,11 @@ namespace tinytc { using data_type_node = ::tinytc_data_type; -class group_data_type : public clir::visitable { +class group_data_type : public data_type_node { public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_group; } inline group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}) - : ty_(std::move(ty)), offset_(offset) { + : data_type_node(DTK_group), ty_(std::move(ty)), offset_(offset) { loc(lc); } @@ -51,10 +59,15 @@ class group_data_type : public clir::visitable std::int64_t offset_; }; -class void_data_type : public clir::visitable {}; +class void_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_void; } + inline void_data_type() : data_type_node(DTK_void) {} +}; -class memref_data_type : public clir::visitable { +class memref_data_type : public data_type_node { public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_memref; } memref_data_type(scalar_type type, std::vector shape, std::vector stride = {}, location const &lc = {}); @@ -91,9 +104,13 @@ class memref_data_type : public clir::visitable { +class scalar_data_type : public data_type_node { public: - inline scalar_data_type(scalar_type type, location const &lc) : ty_(type) { loc(lc); } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_scalar; } + inline scalar_data_type(scalar_type type, location const &lc) + : data_type_node(DTK_scalar), ty_(type) { + loc(lc); + } inline scalar_type ty() const { return ty_; } inline clir::data_type clir_ty() const { return to_clir_ty(ty_); } diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index aa964558..aa6cc567 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -6,10 +6,9 @@ #include "location.hpp" #include "reference_counted.hpp" +#include "support/type_list.hpp" #include "tinytc/tinytc.hpp" -#include - #include #include #include @@ -17,17 +16,25 @@ #include namespace tinytc { -using function_nodes = clir::virtual_type_list; +using function_nodes = type_list; } -struct tinytc_func : tinytc::reference_counted, tinytc::function_nodes { +struct tinytc_func : tinytc::reference_counted { public: + enum function_kind { FK_function, FK_prototype }; + using leaves = tinytc::function_nodes; + + inline tinytc_func(std::int64_t tid) : tid_(tid) {} + inline virtual ~tinytc_func() {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } virtual auto name() const -> std::string_view = 0; private: + std::int64_t tid_; tinytc::location loc_; }; @@ -35,10 +42,11 @@ namespace tinytc { using function_node = ::tinytc_func; -class prototype : public clir::visitable { +class prototype : public function_node { public: + inline static bool classof(function_node const &f) { return f.type_id() == FK_prototype; } inline prototype(std::string name, std::vector args = {}, location const &lc = {}) - : name_(std::move(name)), args_(std::move(args)) { + : function_node(FK_prototype), name_(std::move(name)), args_(std::move(args)) { loc(lc); } @@ -50,11 +58,12 @@ class prototype : public clir::visitable { std::vector args_; }; -class function : public clir::visitable { +class function : public function_node { public: + inline static bool classof(function_node const &f) { return f.type_id() == FK_function; } inline function(func prototype, region body, location const &lc = {}) - : prototype_(std::move(prototype)), body_(std::move(body)), work_group_size_{0, 0}, - subgroup_size_{0} { + : function_node(FK_function), prototype_(std::move(prototype)), body_(std::move(body)), + work_group_size_{0, 0}, subgroup_size_{0} { loc(lc); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index fd101276..b20928be 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -6,11 +6,12 @@ #include "node/data_type_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" +#include "support/casting.hpp" +#include "support/util.hpp" +#include "support/visit.hpp" #include "tinytc/types.hpp" -#include "util.hpp" #include -#include #include #include @@ -19,7 +20,7 @@ namespace tinytc { scalar_data_type *get_scalar_type(location const &loc, value const &v) { - auto m = dynamic_cast(v->ty().get()); + auto m = dyn_cast(v->ty().get()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_scalar); } @@ -27,28 +28,39 @@ scalar_data_type *get_scalar_type(location const &loc, value const &v) { } memref_data_type *get_memref_type(location const &loc, value const &v) { - auto m = dynamic_cast(v->ty().get()); + auto m = dyn_cast(v->ty().get()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_memref); } return m; } -blas_a2_inst::blas_a2_inst(value alpha, value A, value beta, value B, bool atomic) - : standard_inst{std::move(alpha), std::move(A), std::move(beta), std::move(B)}, - atomic_(atomic) {} +blas_a2_inst::blas_a2_inst(std::int64_t tid, value alpha, value A, value beta, value B, bool atomic) + : standard_inst{tid}, atomic_(atomic) { + op(op_alpha) = std::move(alpha); + op(op_A) = std::move(A); + op(op_beta) = std::move(beta); + op(op_B) = std::move(B); +} -blas_a3_inst::blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic) - : standard_inst{std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C)}, - atomic_(atomic) {} +blas_a3_inst::blas_a3_inst(std::int64_t tid, value alpha, value A, value B, value beta, value C, + bool atomic) + : standard_inst{tid}, atomic_(atomic) { + op(op_alpha) = std::move(alpha); + op(op_A) = std::move(A); + op(op_B) = std::move(B); + op(op_beta) = std::move(beta); + op(op_C) = std::move(C); +} -loop_inst::loop_inst(value loop_var, value from, value to, region body, location const &lc) - : loop_inst(std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), lc) {} +loop_inst::loop_inst(std::int64_t tid, value loop_var0, value from0, value to0, value step0, + region body, location const &lc) + : standard_inst{tid}, body_(std::move(body)) { + op(op_loop_var) = std::move(loop_var0); + op(op_from) = std::move(from0); + op(op_to) = std::move(to0); + op(op_step) = std::move(step0); -loop_inst::loop_inst(value loop_var0, value from0, value to0, value step0, region body, - location const &lc) - : standard_inst{std::move(loop_var0), std::move(from0), std::move(to0), std::move(step0)}, - body_(std::move(body)) { loc(lc); auto lvt = get_scalar_type(loc(), loop_var()); auto fromt = get_scalar_type(loc(), from()); @@ -65,9 +77,9 @@ loop_inst::loop_inst(value loop_var0, value from0, value to0, value step0, regio } alloca_inst::alloca_inst(data_type ty, location const &lc) - : result_{make_value(std::move(ty))}, stack_ptr_{-1} { + : standard_inst{IK_alloca}, result_{make_value(std::move(ty))}, stack_ptr_{-1} { loc(lc); - auto memref = dynamic_cast(result_->ty().get()); + auto memref = dyn_cast(result_->ty().get()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } @@ -76,7 +88,9 @@ alloca_inst::alloca_inst(data_type ty, location const &lc) axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { + : blas_a2_inst(IK_axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + std::move(B0), atomic), + tA_(tA) { loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); @@ -98,7 +112,9 @@ axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value } arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const &lc) - : super{std::move(a0), std::move(b0)}, operation_(operation) { + : standard_inst{IK_arith}, operation_(operation) { + op(op_a) = std::move(a0); + op(op_b) = std::move(b0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -131,7 +147,8 @@ arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, location const &lc) - : super{std::move(a0)}, operation_(operation) { + : standard_inst{IK_arith_unary}, operation_(operation) { + op(op_a) = std::move(a0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -151,12 +168,15 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, locatio } cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) - : super{std::move(a)}, result_{make_value(to_ty)} { + : standard_inst{IK_cast}, result_{make_value(to_ty)} { + op(op_a) = std::move(a); loc(lc); } compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location const &lc) - : super{std::move(a0), std::move(b0)}, cond_(cond), result_{make_value(scalar_type::i1)} { + : standard_inst{IK_compare}, cond_(cond), result_{make_value(scalar_type::i1)} { + op(op_a) = std::move(a0); + op(op_b) = std::move(b0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -167,122 +187,15 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons } } -gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B0, value beta0, - value C0, bool atomic, location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), - atomic), - tA_(tA), tB_(tB) { - loc(lc); - auto a = get_memref_type(loc(), A()); - auto b = get_memref_type(loc(), B()); - auto c = get_memref_type(loc(), C()); - - if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemm only supported for memref of order 2 (matrices)"); - } - - auto ak = tA_ == transpose::T ? 0 : 1; - auto bk = tB_ == transpose::T ? 1 : 0; - auto M = c->shape(0); - auto N = c->shape(1); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0, value C0, - bool atomic, location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), - atomic), - tA_(tA) { - loc(lc); - auto a = get_memref_type(loc(), A()); - auto b = get_memref_type(loc(), B()); - auto c = get_memref_type(loc(), C()); - - if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemv only supports matrix-vector products"); - } - - auto ak = tA_ == transpose::T ? 0 : 1; - auto M = c->shape(0); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(0) != K) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, - location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), - atomic) { - loc(lc); - auto a = get_memref_type(loc(), A()); - auto b = get_memref_type(loc(), B()); - auto c = get_memref_type(loc(), C()); - - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "ger requires two vectors as input and one matrix as output"); - } - - auto M = c->shape(0); - auto N = c->shape(1); - if (a->shape(0) != M || b->shape(0) != N) { - std::ostringstream oss; - oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, - location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), - atomic) { - loc(lc); - auto a = get_memref_type(loc(), A()); - auto b = get_memref_type(loc(), B()); - auto c = get_memref_type(loc(), C()); - - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "hadamard requires two vectors as input and one vector as output"); - } - - auto M = c->shape(0); - if (a->shape(0) != M || b->shape(0) != M) { - std::ostringstream oss; - oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const &expand_shape0, location const &lc) - : super{std::move(op0)}, mode_(mode) { + : standard_inst{IK_expand, static_cast(1 + expand_shape0.size())}, mode_(mode) { + op(0) = std::move(op0); + for (std::size_t i = 0; i < expand_shape0.size(); ++i) { + op(1 + i) = expand_shape0[i]; + } loc(lc); - ops().insert(ops().end(), expand_shape0.begin(), expand_shape0.end()); - auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); if (!range_ok) { @@ -370,7 +283,8 @@ expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const } fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location const &lc) - : super{std::move(op0)}, from_(from), to_(to) { + : standard_inst{IK_fuse}, from_(from), to_(to) { + op(0) = std::move(op0); loc(lc); auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); @@ -406,21 +320,14 @@ fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location con result_ = make_value(data_type(r.release())); } -if_inst::if_inst(value condition, region then, region otherwise, - std::vector const &return_types, location const &lc) - : super{std::move(condition)}, then_(std::move(then)), otherwise_(std::move(otherwise)) { - loc(lc); - for (auto &ty : return_types) { - results_.push_back(make_value(ty)); - } -} - load_inst::load_inst(value op0, std::vector const &index_list0, location const &lc) - : super{std::move(op0)} { + : standard_inst{IK_load, static_cast(1 + index_list0.size())} { + op(0) = std::move(op0); + for (std::size_t i = 0; i < index_list0.size(); ++i) { + op(1 + i) = index_list0[i]; + } loc(lc); - ops().insert(ops().end(), index_list0.begin(), index_list0.end()); - visit(overloaded{ [&](group_data_type &g) { if (static_cast(index_list().size()) != 1) { @@ -438,8 +345,129 @@ load_inst::load_inst(value op0, std::vector const &index_list0, location *operand()->ty()); } +gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B0, value beta0, + value C0, bool atomic, location const &lc) + : blas_a3_inst(IK_gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic), + tA_(tA), tB_(tB) { + loc(lc); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { + throw compilation_error(loc(), status::ir_expected_vector_or_matrix, + "gemm only supported for memref of order 2 (matrices)"); + } + + auto ak = tA_ == transpose::T ? 0 : 1; + auto bk = tB_ == transpose::T ? 1 : 0; + auto M = c->shape(0); + auto N = c->shape(1); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + } +} + +gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0, value C0, + bool atomic, location const &lc) + : blas_a3_inst(IK_gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic), + tA_(tA) { + loc(lc); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { + throw compilation_error(loc(), status::ir_expected_vector_or_matrix, + "gemv only supports matrix-vector products"); + } + + auto ak = tA_ == transpose::T ? 0 : 1; + auto M = c->shape(0); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(0) != K) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + } +} + +ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, + location const &lc) + : blas_a3_inst(IK_ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic) { + loc(lc); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { + throw compilation_error(loc(), status::ir_expected_vector_or_matrix, + "ger requires two vectors as input and one matrix as output"); + } + + auto M = c->shape(0); + auto N = c->shape(1); + if (a->shape(0) != M || b->shape(0) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + } +} + +hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, + location const &lc) + : blas_a3_inst(IK_hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic) { + loc(lc); + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { + throw compilation_error(loc(), status::ir_expected_vector_or_matrix, + "hadamard requires two vectors as input and one vector as output"); + } + + auto M = c->shape(0); + if (a->shape(0) != M || b->shape(0) != M) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + } +} + +if_inst::if_inst(value condition, region then, region otherwise, + std::vector const &return_types, location const &lc) + : standard_inst{IK_if, 1, static_cast(return_types.size())}, then_(std::move(then)), + otherwise_(std::move(otherwise)) { + op(0) = std::move(condition); + loc(lc); + for (auto &ty : return_types) { + results_.push_back(make_value(ty)); + } +} + size_inst::size_inst(value op0, std::int64_t mode, location const &lc) - : super{std::move(op0)}, mode_(mode) { + : standard_inst{IK_size}, mode_(mode) { + op(0) = std::move(op0); loc(lc); auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); @@ -452,8 +480,18 @@ size_inst::size_inst(value op0, std::int64_t mode, location const &lc) subview_inst::subview_inst(value op0, std::vector const &offset_list0, std::vector const &size_list0, location const &lc) - : super{std::move(op0)} { - + : standard_inst{IK_subview, + static_cast(1 + offset_list0.size() + size_list0.size())} { + op(0) = std::move(op0); + { + std::size_t i = 1; + for (auto const &val : offset_list0) { + op(i++) = val; + } + for (auto const &val : size_list0) { + op(i++) = val; + } + } loc(lc); auto m = get_memref_type(loc(), operand()); @@ -461,8 +499,6 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, m->dim() != static_cast(size_list0.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - ops().insert(ops().end(), offset_list0.begin(), offset_list0.end()); - ops().insert(ops().end(), size_list0.begin(), size_list0.end()); auto shape = std::vector{}; auto stride = std::vector{}; @@ -515,11 +551,17 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, store_inst::store_inst(value val0, value op0, std::vector const &index_list0, location const &lc) - : super{std::move(val0), std::move(op0)} { + : standard_inst{IK_store, static_cast(2 + index_list0.size())} { + op(op_val) = std::move(val0); + op(op_operand) = std::move(op0); + { + std::size_t i = op_operand; + for (auto const &val : index_list0) { + op(++i) = val; + } + } loc(lc); - ops().insert(ops().end(), index_list0.begin(), index_list0.end()); - auto v = get_scalar_type(loc(), val()); auto o = get_memref_type(loc(), operand()); @@ -534,7 +576,9 @@ store_inst::store_inst(value val0, value op0, std::vector const &index_li sum_inst::sum_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : super(std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { + : blas_a2_inst(IK_sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + std::move(B0), atomic), + tA_(tA) { loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index e3213573..f1de7416 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -4,12 +4,13 @@ #ifndef INST_NODE_20230327_HPP #define INST_NODE_20230327_HPP +#include "error.hpp" #include "reference_counted.hpp" +#include "support/type_list.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include - #include #include #include @@ -20,36 +21,94 @@ namespace tinytc { //! Instruction classification -enum class inst_kind { +enum class inst_execution_kind { mixed, ///< mixed instruction on uniform or varying data collective, ///< collective instruction on uniform data, distributed among work-items spmd ///< SPMD instruction on varying data }; -using inst_nodes = clir::virtual_type_list< - class alloca_inst, class axpby_inst, class barrier_inst, class arith_inst, - class arith_unary_inst, class cast_inst, class compare_inst, class expand_inst, class fuse_inst, - class load_inst, class group_id_inst, class group_size_inst, class lifetime_stop_inst, - class gemm_inst, class gemv_inst, class ger_inst, class for_inst, class foreach_inst, - class hadamard_inst, class if_inst, class num_subgroups_inst, class parallel_inst, - class size_inst, class subview_inst, class store_inst, class subgroup_id_inst, - class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, class yield_inst>; +using inst_nodes = + type_list; + +using op_range = iterator_range_wrapper; +using const_op_range = iterator_range_wrapper; +using result_range = iterator_range_wrapper; +using const_result_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_inst : tinytc::reference_counted, tinytc::inst_nodes { - public: +struct tinytc_inst : tinytc::reference_counted { + public: + enum inst_kind { + IK_alloca, + IK_arith, + IK_arith_unary, + IK_barrier, + IK_cast, + IK_compare, + IK_expand, + IK_fuse, + IK_load, + IK_group_id, + IK_group_size, + IK_lifetime_stop, + IK_if, + IK_num_subgroups, + IK_parallel, + IK_size, + IK_subgroup_id, + IK_subgroup_local_id, + IK_subgroup_size, + IK_subview, + IK_store, + IK_yield, + // blas a2 + IK_blas_a2, + IK_axpby_blas_a2, + IK_sum_blas_a2, + IK_last_blas_a2, + // blas a3 + IK_blas_a3, + IK_gemm_blas_a3, + IK_gemv_blas_a3, + IK_ger_blas_a3, + IK_hadamard_blas_a3, + IK_last_blas_a3, + // loop inst + IK_loop, + IK_for_loop, + IK_foreach_loop, + IK_last_loop + }; + using leaves = tinytc::inst_nodes; + + inline tinytc_inst(std::int64_t tid) : tid_(tid), op_begin_(nullptr), op_end_(nullptr) {} + inline virtual ~tinytc_inst() {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } // Iterator over operands - virtual auto begin() -> tinytc::value * = 0; - virtual auto end() -> tinytc::value * = 0; - virtual auto cbegin() const -> tinytc::value const * = 0; - virtual auto cend() const -> tinytc::value const * = 0; - inline auto begin() const -> tinytc::value const * { return cbegin(); } - inline auto end() const -> tinytc::value const * { return cend(); } + inline auto op_begin() -> tinytc::value * { return op_begin_; } + inline auto op_end() -> tinytc::value * { return op_end_; } + inline auto operands() -> tinytc::op_range { return tinytc::op_range{op_begin_, op_end_}; } + inline auto op_begin() const -> tinytc::value const * { return op_begin_; } + inline auto op_end() const -> tinytc::value const * { return op_end_; } + inline auto operands() const -> tinytc::const_op_range { + return tinytc::const_op_range{op_begin_, op_end_}; + } + inline auto op(std::size_t pos) -> tinytc::value & { return op_begin_[pos]; } + inline auto op(std::size_t pos) const -> tinytc::value const & { return op_begin_[pos]; } + inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } virtual tinytc::value result() const = 0; inline virtual auto results() const -> std::vector { @@ -59,94 +118,123 @@ struct tinytc_inst : tinytc::reference_counted, tinytc::inst_nodes { return {}; } inline virtual auto num_results() const -> std::size_t { return result() ? 1u : 0u; } - virtual tinytc::inst_kind kind() const = 0; + virtual tinytc::inst_execution_kind kind() const = 0; + + protected: + inline auto op_range(tinytc::value *begin, tinytc::value *end) { + op_begin_ = begin; + op_end_ = end; + } private: + std::int64_t tid_; tinytc::location loc_; + tinytc::value *op_begin_, *op_end_; }; namespace tinytc { using inst_node = ::tinytc_inst; -template class standard_inst : public inst_node { +template class value_container { public: - template inline standard_inst(Ts &&...ts) : ops_{std::forward(ts)...} {} - - inline auto begin() -> tinytc::value * override { return ops_.data(); } - inline auto end() -> tinytc::value * override { return ops_.data() + ops_.size(); } - inline auto cbegin() const -> tinytc::value const * override { return ops_.data(); } - inline auto cend() const -> tinytc::value const * override { return ops_.data() + ops_.size(); } - - inline auto op(std::size_t pos) -> value & { return ops_[pos]; } - inline auto op(std::size_t pos) const -> value const & { return ops_[pos]; } + value_container(std::int64_t num_values) { + if (num_values != NumValues) { + throw internal_compiler_error(); + } + } + inline auto get() -> tinytc::value * { + if constexpr (NumValues == 0) { + return nullptr; + } + return ops_.data(); + } private: - std::array ops_; + std::array ops_; }; -class standard_variadic_inst : public inst_node { + +template <> class value_container { public: - template - inline standard_variadic_inst(Ts &&...ts) : ops_{std::forward(ts)...} {} + value_container(std::int64_t num_values) : ops_{std::make_unique(num_values)} {} - inline auto begin() -> tinytc::value * override { return ops_.data(); } - inline auto end() -> tinytc::value * override { return ops_.data() + ops_.size(); } - inline auto cbegin() const -> tinytc::value const * override { return ops_.data(); } - inline auto cend() const -> tinytc::value const * override { return ops_.data() + ops_.size(); } + auto get() -> tinytc::value * { return ops_.get(); } - inline auto op(std::size_t pos) -> value & { return ops_[pos]; } - inline auto op(std::size_t pos) const -> value const & { return ops_[pos]; } - inline auto ops() -> std::vector & { return ops_; } - inline auto ops() const -> std::vector const & { return ops_; } + private: + std::unique_ptr ops_; +}; + +template +class standard_inst : public inst_node { + public: + standard_inst(std::int64_t tid, std::int64_t num_operands = NumOperands, + std::int64_t num_results = NumResults) + : inst_node{tid}, ops_{num_operands}, results_{num_results} { + if (num_operands > 0) { + op_range(ops_.get(), ops_.get() + num_operands); + } + } private: - std::vector ops_; + value_container ops_; + value_container results_; }; -class blas_a2_inst : public standard_inst<4u> { +class blas_a2_inst : public standard_inst<4, 1> { public: - blas_a2_inst(value alpha, value A, value beta, value B, bool atomic); + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK_blas_a2 && i.type_id() <= IK_last_blas_a2; + } + enum op_number { op_alpha = 0, op_A = 1, op_beta = 2, op_B = 3 }; + blas_a2_inst(std::int64_t tid, value alpha, value A, value beta, value B, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return op(0); } - inline auto A() const -> value const & { return op(1); } - inline auto beta() const -> value const & { return op(2); } - inline auto B() const -> value const & { return op(3); } + inline auto alpha() const -> value const & { return op(op_alpha); } + inline auto A() const -> value const & { return op(op_A); } + inline auto beta() const -> value const & { return op(op_beta); } + inline auto B() const -> value const & { return op(op_B); } inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } protected: bool atomic_; }; -class blas_a3_inst : public standard_inst<5u> { +class blas_a3_inst : public standard_inst<5, 1> { public: - blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic); + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK_blas_a3 && i.type_id() <= IK_last_blas_a3; + } + enum op_number { op_alpha = 0, op_A = 1, op_B = 2, op_beta = 3, op_C = 4 }; + blas_a3_inst(std::int64_t tid, value alpha, value A, value B, value beta, value C, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return op(0); } - inline auto A() const -> value const & { return op(1); } - inline auto B() const -> value const & { return op(2); } - inline auto beta() const -> value const & { return op(3); } - inline auto C() const -> value const & { return op(4); } + inline auto alpha() const -> value const & { return op(op_alpha); } + inline auto A() const -> value const & { return op(op_A); } + inline auto B() const -> value const & { return op(op_B); } + inline auto beta() const -> value const & { return op(op_beta); } + inline auto C() const -> value const & { return op(op_C); } inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } protected: bool atomic_; }; -class loop_inst : public standard_inst<4u> { +class loop_inst : public standard_inst<4, 1> { public: - loop_inst(value loop_var, value from, value to, region body, location const &loc = {}); - loop_inst(value loop_var, value from, value to, value step, region body, + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK_loop && i.type_id() <= IK_last_loop; + } + enum op_number { op_loop_var = 0, op_from = 1, op_to = 2, op_step = 3 }; + loop_inst(std::int64_t tid, value loop_var, value from, value to, value step, region body, location const &loc = {}); - inline auto loop_var() const -> value const & { return op(0); } - inline auto from() const -> value const & { return op(1); } - inline auto to() const -> value const & { return op(2); } - inline auto step() const -> value const & { return op(3); } + inline auto loop_var() const -> value const & { return op(op_loop_var); } + inline auto from() const -> value const & { return op(op_from); } + inline auto to() const -> value const & { return op(op_to); } + inline auto step() const -> value const & { return op(op_step); } inline auto body() const -> region const & { return body_; } inline value result() const override { return value{}; } @@ -154,23 +242,24 @@ class loop_inst : public standard_inst<4u> { region body_; }; -class alloca_inst : public clir::visitable> { +class alloca_inst : public standard_inst<0, 1> { public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK_alloca; } alloca_inst(data_type ty, location const &loc = {}); inline value result() const override { return result_; } inline std::int64_t stack_ptr() const { return stack_ptr_; } inline void stack_ptr(std::int64_t ptr) { stack_ptr_ = ptr; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } private: value result_; std::int64_t stack_ptr_; }; -class axpby_inst : public clir::visitable { +class axpby_inst : public blas_a2_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_axpby_blas_a2; } axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, location const &lc = {}); @@ -180,156 +269,169 @@ class axpby_inst : public clir::visitable { transpose tA_; }; -class arith_inst : public clir::visitable> { +class arith_inst : public standard_inst<2, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_arith; } + enum op_number { op_a = 0, op_b = 1 }; arith_inst(arithmetic op, value a, value b, location const &lc = {}); inline arithmetic operation() const { return operation_; } - inline auto a() const -> value const & { return op(0); } - inline auto b() const -> value const & { return op(1); } + inline auto a() const -> value const & { return op(op_a); } + inline auto b() const -> value const & { return op(op_b); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: arithmetic operation_; value result_; }; -class arith_unary_inst : public clir::visitable> { +class arith_unary_inst : public standard_inst<1, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_arith_unary; } + enum op_number { op_a = 0 }; arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } - inline auto a() const -> value const & { return op(0); } + inline auto a() const -> value const & { return op(op_a); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: arithmetic_unary operation_; value result_; }; -class barrier_inst : public clir::visitable> { +class barrier_inst : public standard_inst<0, 0> { public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK_barrier; } + inline barrier_inst() : standard_inst{IK_barrier} {} + inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } }; -class cast_inst : public clir::visitable> { +class cast_inst : public standard_inst<1, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_cast; } + enum op_number { op_a = 0 }; cast_inst(value a, scalar_type to_ty, location const &lc = {}); - inline auto a() const -> value const & { return op(0); } + inline auto a() const -> value const & { return op(op_a); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class compare_inst : public clir::visitable> { +class compare_inst : public standard_inst<2, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_compare; } + enum op_number { op_a = 0, op_b = 1 }; compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } - inline auto a() const -> value const & { return op(0); } - inline auto b() const -> value const & { return op(1); } + inline auto a() const -> value const & { return op(op_a); } + inline auto b() const -> value const & { return op(op_b); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: cmp_condition cond_; value result_; }; -class expand_inst : public clir::visitable { +class expand_inst : public standard_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_expand; } expand_inst(value op, std::int64_t mode, std::vector const &expand_shape, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline std::int64_t mode() const { return mode_; } - inline auto expand_shape() { return ops() | std::views::drop(1); } - inline auto expand_shape() const { return ops() | std::views::drop(1); } + inline auto expand_shape() { return operands() | std::views::drop(1); } + inline auto expand_shape() const { return operands() | std::views::drop(1); } inline auto expand_shape(std::int64_t i) const -> value const & { return op(i + 1); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; std::int64_t mode_; }; -class fuse_inst : public clir::visitable> { +class fuse_inst : public standard_inst<1, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_fuse; } fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; std::int64_t from_, to_; }; -class load_inst : public clir::visitable { +class load_inst : public standard_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_load; } load_inst(value op, std::vector const &index_list, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } - inline auto index_list() const { return ops() | std::views::drop(1); } + inline auto index_list() const { return operands() | std::views::drop(1); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class group_id_inst : public clir::visitable> { +class group_id_inst : public standard_inst<0, 1> { public: - inline group_id_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_group_id; } + inline group_id_inst(location const &lc = {}) + : standard_inst{IK_group_id}, result_{make_value(scalar_type::index)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class group_size_inst : public clir::visitable> { +class group_size_inst : public standard_inst<0, 1> { public: - inline group_size_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_group_size; } + inline group_size_inst(location const &lc = {}) + : standard_inst{IK_group_size}, result_{make_value(scalar_type::index)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class lifetime_stop_inst : public clir::visitable> { +class lifetime_stop_inst : public standard_inst<1, 1> { public: - using super = clir::visitable>; - inline lifetime_stop_inst(value obj) : super{std::move(obj)} {} + inline static bool classof(inst_node const &i) { return i.type_id() == IK_lifetime_stop; } + inline lifetime_stop_inst(value obj) : standard_inst{IK_lifetime_stop} { + op(0) = std::move(obj); + } inline auto object() const -> value const & { return op(0); } inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } }; -class gemm_inst : public clir::visitable { +class gemm_inst : public blas_a3_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_gemm_blas_a3; } gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); @@ -340,9 +442,9 @@ class gemm_inst : public clir::visitable { transpose tA_, tB_; }; -class gemv_inst : public clir::visitable { +class gemv_inst : public blas_a3_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_gemv_blas_a3; } gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); @@ -352,38 +454,56 @@ class gemv_inst : public clir::visitable { transpose tA_; }; -class ger_inst : public clir::visitable { +class ger_inst : public blas_a3_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_ger_blas_a3; } ger_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); }; -class for_inst : public clir::visitable { +class for_inst : public loop_inst { public: - using super = clir::visitable; - using super::super; - inline inst_kind kind() const override { return inst_kind::mixed; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK_for_loop; } + inline for_inst(value loop_var, value from, value to, region body, location const &loc = {}) + : loop_inst{ + IK_for_loop, std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), + loc} {} + inline for_inst(value loop_var, value from, value to, value step, region body, + location const &loc = {}) + : loop_inst{IK_for_loop, + std::move(loop_var), + std::move(from), + std::move(to), + std::move(step), + std::move(body), + loc} {} + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; -class foreach_inst : public clir::visitable { +class foreach_inst : public loop_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_foreach_loop; } inline foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : super(std::move(loop_var), std::move(from), std::move(to), std::move(body), loc) {} - inline inst_kind kind() const override { return inst_kind::collective; } + : loop_inst{IK_foreach_loop, + std::move(loop_var), + std::move(from), + std::move(to), + {}, + std::move(body), + loc} {} + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } }; -class hadamard_inst : public clir::visitable { +class hadamard_inst : public blas_a3_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_hadamard_blas_a3; } hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); }; -class if_inst : public clir::visitable> { +class if_inst : public standard_inst<1, dynamic> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_if; } if_inst(value condition, region then, region otherwise = {}, std::vector const &return_types = {}, location const &lc = {}); inline auto condition() const -> value const & { return op(0); } @@ -396,123 +516,135 @@ class if_inst : public clir::visitable> { inline auto num_results() const -> std::size_t override { return results_.size(); } inline auto results_ref() -> std::vector & { return results_; } inline auto results_ref() const -> std::vector const & { return results_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: region then_, otherwise_; std::vector results_; }; -class num_subgroups_inst : public clir::visitable> { +class num_subgroups_inst : public standard_inst<0, 1> { public: - inline num_subgroups_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_num_subgroups; } + inline num_subgroups_inst(location const &lc = {}) + : standard_inst{IK_num_subgroups}, result_{make_value(scalar_type::i32)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class parallel_inst : public clir::visitable> { +class parallel_inst : public standard_inst<0, 0> { public: - using super = clir::visitable; - inline parallel_inst(region body, location const &lc = {}) : body_(std::move(body)) { loc(lc); } + inline static bool classof(inst_node const &i) { return i.type_id() == IK_parallel; } + inline parallel_inst(region body, location const &lc = {}) + : standard_inst{IK_parallel}, body_(std::move(body)) { + loc(lc); + } inline auto body() const -> region const & { return body_; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } inline value result() const override { return value{}; } private: region body_; }; -class size_inst : public clir::visitable> { +class size_inst : public standard_inst<1, 1> { public: - using super = clir::visitable>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_size; } size_inst(value op, std::int64_t mode, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline std::int64_t mode() const { return mode_; } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; std::int64_t mode_; }; -class subgroup_id_inst : public clir::visitable> { +class subgroup_id_inst : public standard_inst<0, 1> { public: - inline subgroup_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_id; } + inline subgroup_id_inst(location const &lc = {}) + : standard_inst{IK_subgroup_id}, result_{make_value(scalar_type::i32)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::spmd; } + inline inst_execution_kind kind() const override { return inst_execution_kind::spmd; } private: value result_; }; -class subgroup_local_id_inst : public clir::visitable> { +class subgroup_local_id_inst : public standard_inst<0, 1> { public: - inline subgroup_local_id_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_local_id; } + inline subgroup_local_id_inst(location const &lc = {}) + : standard_inst{IK_subgroup_local_id}, result_{make_value(scalar_type::i32)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::spmd; } + inline inst_execution_kind kind() const override { return inst_execution_kind::spmd; } private: value result_; }; -class subgroup_size_inst : public clir::visitable> { +class subgroup_size_inst : public standard_inst<0, 1> { public: - inline subgroup_size_inst(location const &lc = {}) : result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_size; } + inline subgroup_size_inst(location const &lc = {}) + : standard_inst{IK_subgroup_size}, result_{make_value(scalar_type::i32)} { loc(lc); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class subview_inst : public clir::visitable { +class subview_inst : public standard_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_subview; } subview_inst(value op, std::vector const &offset_list, std::vector const &size_list, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } - // We have ops().size() = 1 + 2 * num_indices() - inline auto num_indices() const { return (ops().size() - 1) / 2; } + // We have num_operands() = 1 + 2 * num_indices() + inline auto num_indices() const { return (num_operands() - 1) / 2; } inline auto offset_list() const { - return ops() | std::views::drop(1) | std::views::take(num_indices()); + return operands() | std::views::drop(1) | std::views::take(num_indices()); } - inline auto size_list() const { return ops() | std::views::drop(1 + num_indices()); } + inline auto size_list() const { return operands() | std::views::drop(1 + num_indices()); } inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: value result_; }; -class store_inst : public clir::visitable { +class store_inst : public standard_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_store; } + enum op_number { op_val = 0, op_operand = 1 }; store_inst(value val, value op, std::vector const &index_list, location const &lc = {}); - inline auto val() const -> value const & { return op(0); } - inline auto operand() const -> value const & { return op(1); } - inline auto index_list() const { return ops() | std::views::drop(2); } + inline auto val() const -> value const & { return op(op_val); } + inline auto operand() const -> value const & { return op(op_operand); } + inline auto index_list() const { return operands() | std::views::drop(2); } inline value result() const override { return {}; } - inline inst_kind kind() const override { return inst_kind::mixed; } + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; -class sum_inst : public clir::visitable { +class sum_inst : public blas_a2_inst { public: - using super = clir::visitable; + inline static bool classof(inst_node const &i) { return i.type_id() == IK_sum_blas_a2; } sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, location const &lc = {}); @@ -522,19 +654,18 @@ class sum_inst : public clir::visitable { transpose tA_; }; -class yield_inst : public clir::visitable { +class yield_inst : public standard_inst { public: - using super = clir::visitable; - inline yield_inst(std::vector const &vals, location const &lc = {}) { + inline static bool classof(inst_node const &i) { return i.type_id() == IK_yield; } + inline yield_inst(std::vector const &vals, location const &lc = {}) + : standard_inst{IK_yield, static_cast(vals.size())} { loc(lc); - ops().insert(ops().end(), vals.begin(), vals.end()); + for (std::size_t i = 0; i < vals.size(); ++i) { + op(i) = vals[i]; + } } inline value result() const override { return value{}; } - inline auto vals() const -> std::vector const & { return ops(); } - inline inst_kind kind() const override { return inst_kind::mixed; } - - private: - std::vector vals_; + inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; } // namespace tinytc diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 372c41d7..7136f872 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -6,23 +6,30 @@ #include "location.hpp" #include "reference_counted.hpp" +#include "support/type_list.hpp" #include "tinytc/tinytc.hpp" -#include - +#include #include #include namespace tinytc { -using program_nodes = clir::virtual_type_list; +using program_nodes = type_list; } -struct tinytc_prog : tinytc::reference_counted, tinytc::program_nodes { +struct tinytc_prog : tinytc::reference_counted { public: + enum prog_kind { PK_prog }; + using leaves = tinytc::program_nodes; + + inline tinytc_prog(std::int64_t tid) : tid_(tid) {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: + std::int64_t tid_; tinytc::location loc_; }; @@ -30,9 +37,11 @@ namespace tinytc { using program_node = ::tinytc_prog; -class program : public clir::visitable { +class program : public program_node { public: - inline program(std::vector decls, location const &lc = {}) : decls_(std::move(decls)) { + inline static bool classof(program_node const &p) { return p.type_id() == PK_prog; } + inline program(std::vector decls, location const &lc = {}) + : program_node(PK_prog), decls_(std::move(decls)) { loc(lc); } inline auto declarations() -> std::vector & { return decls_; } diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index bcf1d983..c100b54c 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -5,23 +5,30 @@ #define REGION_NODE_20230908_HPP #include "reference_counted.hpp" +#include "support/type_list.hpp" #include "tinytc/tinytc.hpp" -#include - +#include #include #include namespace tinytc { -using region_nodes = clir::virtual_type_list; +using region_nodes = type_list; } -struct tinytc_region : tinytc::reference_counted, tinytc::region_nodes { +struct tinytc_region : tinytc::reference_counted { public: + enum region_kind { RK_rgn }; + using leaves = tinytc::region_nodes; + + inline tinytc_region(std::int64_t tid) : tid_(tid) {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: + std::int64_t tid_; tinytc::location loc_; }; @@ -29,9 +36,11 @@ namespace tinytc { using region_node = ::tinytc_region; -class rgn : public clir::visitable { +class rgn : public region_node { public: - inline rgn(std::vector insts = {}, location const &lc = {}) : insts_(std::move(insts)) { + inline static bool classof(region_node const &r) { return r.type_id() == RK_rgn; } + inline rgn(std::vector insts = {}, location const &lc = {}) + : region_node(RK_rgn), insts_(std::move(insts)) { loc(lc); } diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 461bbf8d..de342b35 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -5,20 +5,26 @@ #define VALUE_NODE_20230309_HPP #include "reference_counted.hpp" +#include "support/type_list.hpp" #include "tinytc/tinytc.hpp" -#include - #include #include #include namespace tinytc { -using value_nodes = clir::virtual_type_list; +using value_nodes = type_list; } -struct tinytc_value : tinytc::reference_counted, tinytc::value_nodes { +struct tinytc_value : tinytc::reference_counted { public: + enum value_kind { VK_float, VK_int, VK_val }; + using leaves = tinytc::value_nodes; + + inline tinytc_value(std::int64_t tid) : tid_(tid) {} + inline virtual ~tinytc_value() {} + inline auto type_id() const -> std::int64_t { return tid_; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -29,6 +35,7 @@ struct tinytc_value : tinytc::reference_counted, tinytc::value_nodes { virtual auto has_name() const -> bool = 0; private: + std::int64_t tid_; tinytc::location loc_; }; @@ -36,10 +43,13 @@ namespace tinytc { using value_node = ::tinytc_value; -class float_imm : public clir::visitable { +class float_imm : public value_node { public: - inline float_imm(double v, scalar_type ty = scalar_type::f64) - : ty_{make_scalar(ty)}, value_(v) {} + inline static bool classof(value_node const &v) { return v.type_id() == VK_float; } + inline float_imm(double v, scalar_type ty = scalar_type::f64, location const &lc = {}) + : value_node(VK_float), ty_{make_scalar(ty)}, value_(v) { + loc(lc); + } inline data_type ty() const override { return ty_; } inline void ty(data_type ty) override { ty_ = std::move(ty); } @@ -54,10 +64,13 @@ class float_imm : public clir::visitable { double value_; }; -class int_imm : public clir::visitable { +class int_imm : public value_node { public: - inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64) - : ty_{make_scalar(ty)}, value_(v) {} + inline static bool classof(value_node const &v) { return v.type_id() == VK_int; } + inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64, location const &lc = {}) + : value_node(VK_int), ty_{make_scalar(ty)}, value_(v) { + loc(lc); + } inline data_type ty() const override { return ty_; } inline void ty(data_type ty) override { ty_ = std::move(ty); } @@ -72,9 +85,12 @@ class int_imm : public clir::visitable { std::int64_t value_; }; -class val : public clir::visitable { +class val : public value_node { public: - inline val(data_type ty) : ty_(std::move(ty)) {} + inline static bool classof(value_node const &v) { return v.type_id() == VK_val; } + inline val(data_type ty, location const &lc = {}) : value_node(VK_val), ty_(std::move(ty)) { + loc(lc); + } inline data_type ty() const override { return ty_; } inline void ty(data_type ty) override { ty_ = std::move(ty); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 07b01187..41ac60c7 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -30,9 +30,8 @@ #include "parser/lexer.hpp" #include "parser/parse_context.hpp" #include "passes.hpp" - #include "util.hpp" - - #include + #include "support/util.hpp" + #include "support/visit.hpp" #include #include @@ -44,7 +43,7 @@ namespace tinytc { void check_scalar_type(value & val, scalar_type const& sty, location & loc1, location & loc2) { - clir::visit( + visit( overloaded{[&](int_imm &i) { i.ty(make_scalar(sty)); }, [&](float_imm &i) { i.ty(make_scalar(sty)); }, [&](auto &) { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 0c01ddfc..d74557af 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -6,11 +6,11 @@ #include "parser.hpp" #include "recipe.hpp" #include "reference_counted.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" #include #include diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 46ec8109..521bf295 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -7,10 +7,10 @@ #include "parser.hpp" #include "recipe.hpp" #include "reference_counted.hpp" +#include "support/util.hpp" #include "tiling.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" -#include "util.hpp" #include #include diff --git a/src/support/casting.hpp b/src/support/casting.hpp new file mode 100644 index 00000000..3ddbaa2f --- /dev/null +++ b/src/support/casting.hpp @@ -0,0 +1,47 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CASTING_20240903_HPP +#define CASTING_20240903_HPP + +#include + +// LLVM-style RTTI, cf. https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html + +namespace tinytc { + +template struct copy_cv { + private: + using T1 = std::conditional_t::value, std::add_const_t, T>; + using T2 = std::conditional_t::value, std::add_volatile_t, T1>; + + public: + using type = T2; +}; +template +using copy_cv_t = typename copy_cv::type; + +template +requires(std::is_base_of_v, std::decay_t>) +auto isa(From const &obj) -> bool { + return To::classof(obj); +} + +template +requires(std::is_base_of_v, std::decay_t>) +auto cast(From *obj) { + return (copy_cv_t *)obj; +} + +template +requires(std::is_base_of_v, std::decay_t>) +auto dyn_cast(From *obj) -> To * { + if (obj != nullptr && isa(*obj)) { + return cast(obj); + } + return nullptr; +} + +} // namespace tinytc + +#endif // CASTING_20240903_HPP diff --git a/src/support/type_list.hpp b/src/support/type_list.hpp new file mode 100644 index 00000000..19fdb6b5 --- /dev/null +++ b/src/support/type_list.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef TYPE_LIST_20240903_HPP +#define TYPE_LIST_20240903_HPP + +#include + +namespace tinytc { + +template class type_at { + private: + static_assert(Index < sizeof...(Types), "Type index out of bounds"); + + template struct impl; + template struct impl { + using type = typename impl::type; + }; + template struct impl<0, Head, Tail...> { + using type = Head; + }; + + public: + using type = typename impl::type; +}; + +template +using type_at_t = typename type_at::type; + +/** + * @brief Simple type list that allows to query the type at index 0,...,number_of_types()-1 + * + * @tparam Types + */ +template struct type_list { + template using type_at = type_at_t; + static constexpr auto number_of_types() { return sizeof...(Types); } +}; + +} // namespace tinytc + +#endif // TYPE_LIST_20240903_HPP diff --git a/src/util.hpp b/src/support/util.hpp similarity index 61% rename from src/util.hpp rename to src/support/util.hpp index f0f42720..c9a4b275 100644 --- a/src/util.hpp +++ b/src/support/util.hpp @@ -5,6 +5,7 @@ #define UTIL_20240201_HPP #include +#include namespace tinytc { @@ -17,6 +18,16 @@ template auto enum_cast(V val) { return T{std::underlying_type_t(val)}; } +template class iterator_range_wrapper { + public: + iterator_range_wrapper(ItT begin, ItT end) : begin_(std::move(begin)), end_(std::move(end)) {} + ItT begin() const { return begin_; } + ItT end() const { return end_; } + + private: + ItT begin_, end_; +}; + } // namespace tinytc #endif // UTIL_20240201_HPP diff --git a/src/support/visit.hpp b/src/support/visit.hpp new file mode 100644 index 00000000..2c4c9231 --- /dev/null +++ b/src/support/visit.hpp @@ -0,0 +1,103 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef VISIT_20240903_HPP +#define VISIT_20240903_HPP + +#include "casting.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +namespace detail { +/** + * @brief Computes \prod_{i=0}^{MaxMode-1} Size_i, where Size_0 = Head, and Size_i = Tail_i for i > + * 0. Always returns 1 for MaxMode = 0. + */ +template +constexpr auto partial_product(std::index_sequence) -> std::size_t { + if constexpr (MaxMode == 0) { + return 1; + } else { + return Head * partial_product(std::index_sequence{}); + } +} + +/** + * In an ND-tensor, the ND-index (i_0, ..., i_n) is flattend via + * Index = i_0 + i_1 * Size_0 + ... + i_n * Size_0 * ... * Size_{n-1} + * This function takes the compile-time flat Index and recovers the ND-index, and returns the + * ND-index as index_sequence. + */ +template +constexpr auto unflatten(std::index_sequence) { + return [](std::index_sequence) { + return std::index_sequence<(Index / partial_product(std::index_sequence{}) % + Size)...>{}; + }(std::make_index_sequence{}); +} +} // namespace detail + +template +concept visitable = requires(T ty) { + typename T::leaves; + { ty.type_id() } -> std::integral; +}; + +/** + * Multiple dispatch for class hierachies that implement LLVM-style RTTI and have + * a "leaves" type-list that lists all leaf classes (i.e. classes that have no children). + * + * The function works as following: + * + * 1. Input are one or multiple objects with LLVM-style RTTI and leaves type-list (T &...). + * The leaf type list can be different for each t. + * 2. Enumerate all potential cases of type combinations (see table_size). + * 3. Call the compile_time_switch lambda for all case numbers. + * 4. Use the unflatten function to get the type ND-index, i.e. the position in the type list for + * each t. + * 5. Use the "is-a all" lambda to check whether t... is of the type combination covered in the + * case. + * 6. If the case number matches the actual type combination of t..., then call the visitor, casting + * each t... to its actual type. + */ +template auto visit(Visitor &&visitor, T &...t) { + auto compile_time_switch = [&](std::index_sequence) { + const auto isa_all = [&](std::index_sequence) -> bool { + return (isa::leaves::template type_at>(t) && ...); + }; + const auto dispatch = [&](std::index_sequence) { + return visitor(*cast::leaves::template type_at>(&t)...); + }; + + using size = std::index_sequence::leaves::number_of_types()...>; + + using return_type = + std::common_type_t(size{})))...>; + + if constexpr (std::is_same_v) { + [[maybe_unused]] int discard = ((isa_all(detail::unflatten(size{})) + ? dispatch(detail::unflatten(size{})), + 0 : 0) || + ...); + } else { + return_type ret = {}; + [[maybe_unused]] int discard = ((isa_all(detail::unflatten(size{})) + ? (ret = dispatch(detail::unflatten(size{}))), + 0 : 0) || + ...); + return ret; + } + }; + + constexpr std::size_t table_size = (std::decay_t::leaves::number_of_types() * ...); + return compile_time_switch(std::make_index_sequence{}); +} + +} // namespace tinytc + +#endif // VISIT_20240903_HPP diff --git a/src/value.cpp b/src/value.cpp index ac96580d..c87f73dd 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -4,11 +4,11 @@ #include "error.hpp" #include "location.hpp" #include "node/value_node.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" #include #include diff --git a/src/visitor/alias_analysis.cpp b/src/visitor/alias_analysis.cpp index 1d0abb1d..891647dd 100644 --- a/src/visitor/alias_analysis.cpp +++ b/src/visitor/alias_analysis.cpp @@ -4,21 +4,19 @@ #include "visitor/alias_analysis.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include - #include -using clir::visit; - namespace tinytc { /* Stmt nodes */ void alias_analyser::operator()(inst_node const &) {} void alias_analyser::operator()(alloca_inst const &a) { - auto t = dynamic_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty().get()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } diff --git a/src/visitor/check_ir.cpp b/src/visitor/check_ir.cpp index ab107ff7..e3eea4aa 100644 --- a/src/visitor/check_ir.cpp +++ b/src/visitor/check_ir.cpp @@ -3,22 +3,19 @@ #include "visitor/check_ir.hpp" #include "error.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include - #include -using clir::visit; - namespace tinytc { /* Stmt nodes */ void ir_checker::operator()(inst_node const &in) { - if (in.kind() == inst_kind::collective && inside_spmd_region_) { + if (in.kind() == inst_execution_kind::collective && inside_spmd_region_) { throw compilation_error(in.loc(), status::ir_collective_called_from_spmd); - } else if (in.kind() == inst_kind::spmd && !inside_spmd_region_) { + } else if (in.kind() == inst_execution_kind::spmd && !inside_spmd_region_) { throw compilation_error(in.loc(), status::ir_spmd_called_from_collective); } } diff --git a/src/visitor/dump_ir.cpp b/src/visitor/dump_ir.cpp index 46ebff7d..d02d1ec8 100644 --- a/src/visitor/dump_ir.cpp +++ b/src/visitor/dump_ir.cpp @@ -2,17 +2,14 @@ // SPDX-License-Identifier: BSD-3-Clause #include "visitor/dump_ir.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include - #include #include #include #include -using clir::visit; - namespace tinytc { ir_dumper::ir_dumper(std::ostream &os) : os_(os) {} @@ -345,10 +342,9 @@ void ir_dumper::operator()(sum_inst const &a) { void ir_dumper::operator()(yield_inst const &y) { os_ << "yield "; - do_with_infix(y.vals().begin(), y.vals().end(), [this](auto const &i) { visit(*this, *i); }); + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }); os_ << " : "; - do_with_infix(y.vals().begin(), y.vals().end(), - [this](auto const &i) { visit(*this, *i->ty()); }); + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }); } /* Region nodes */ diff --git a/src/visitor/equal.cpp b/src/visitor/equal.cpp index d84f99c2..403331cf 100644 --- a/src/visitor/equal.cpp +++ b/src/visitor/equal.cpp @@ -2,14 +2,11 @@ // SPDX-License-Identifier: BSD-3-Clause #include "visitor/equal.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include - #include -using clir::visit; - namespace tinytc { bool equal::operator()(data_type_node const &, data_type_node const &) { return false; } diff --git a/src/visitor/insert_barrier.cpp b/src/visitor/insert_barrier.cpp index d7d8c9e5..01923a9d 100644 --- a/src/visitor/insert_barrier.cpp +++ b/src/visitor/insert_barrier.cpp @@ -2,18 +2,17 @@ // SPDX-License-Identifier: BSD-3-Clause #include "visitor/insert_barrier.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" #include "visitor/alias_analysis.hpp" #include -#include #include #include #include -using clir::visit; - namespace tinytc { /* Data type nodes */ @@ -68,7 +67,7 @@ std::unordered_set insert_barrier::operator()(fuse_inst &) { retur std::unordered_set insert_barrier::operator()(load_inst &e) { auto rw = std::unordered_set{}; - auto t = dynamic_cast(e.operand()->ty().get()); + auto t = dyn_cast(e.operand()->ty().get()); if (t) { rw.emplace(visit(*this, *e.operand())); } diff --git a/src/visitor/lifetime_analysis.cpp b/src/visitor/lifetime_analysis.cpp index 24e9c376..dd1e6420 100644 --- a/src/visitor/lifetime_analysis.cpp +++ b/src/visitor/lifetime_analysis.cpp @@ -3,16 +3,13 @@ #include "visitor/lifetime_analysis.hpp" #include "node/value_node.hpp" +#include "support/visit.hpp" #include "visitor/alias_analysis.hpp" -#include - #include #include #include -using clir::visit; - namespace tinytc { find_alloca::find_alloca(bool recursive) : recursive_(recursive) {} diff --git a/src/visitor/metadata.cpp b/src/visitor/metadata.cpp index 30e38ac2..0b4d4ef3 100644 --- a/src/visitor/metadata.cpp +++ b/src/visitor/metadata.cpp @@ -4,15 +4,12 @@ #include "visitor/metadata.hpp" #include "node/function_node.hpp" #include "node/program_node.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include - #include #include -using clir::visit; - namespace tinytc { /* Function nodes */ diff --git a/src/visitor/opencl_ast.cpp b/src/visitor/opencl_ast.cpp index b0dd4f80..d33b411b 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/visitor/opencl_ast.cpp @@ -6,8 +6,10 @@ #include "error.hpp" #include "gemm_generator.hpp" #include "scalar_type.hpp" +#include "support/casting.hpp" +#include "support/util.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include "util.hpp" #include #include @@ -29,8 +31,6 @@ #include #include -using clir::visit; - namespace tinytc { std::string var_name(std::string name) { @@ -49,7 +49,7 @@ dope_vector dope_vector::from_value(value_node const &v, decl_fun_t declare) { dt = to_clir_ty(scalar_type::index); }, [&](group_data_type const &g) { - m = dynamic_cast(g.ty().get()); + m = dyn_cast(g.ty().get()); dt = clir::pointer_to( to_clir_ty(scalar_type::index, clir::address_space::global_t)); }, @@ -142,7 +142,7 @@ clir::var opencl_ast::declare(value_node const &v) { } auto opencl_ast::get_memref_type(value_node const &v) const -> const memref_data_type * { - auto t = dynamic_cast(v.ty().get()); + auto t = dyn_cast(v.ty().get()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); } @@ -166,12 +166,12 @@ clir::data_type opencl_ast::operator()(void_data_type const &) { } clir::data_type opencl_ast::operator()(group_data_type const &g) { auto ptr_ty = visit(*this, *g.ty()); - ptr_ty = visit(overloaded{[](clir::internal::pointer &t) { - return clir::pointer_to( - clir::pointer_to(t.ty(), clir::address_space::global_t)); - }, - [](auto &) { return clir::data_type{}; }}, - *ptr_ty); + ptr_ty = clir::visit(overloaded{[](clir::internal::pointer &t) { + return clir::pointer_to(clir::pointer_to( + t.ty(), clir::address_space::global_t)); + }, + [](auto &) { return clir::data_type{}; }}, + *ptr_ty); if (!ptr_ty) { throw compilation_error(g.loc(), status::internal_compiler_error, "Could not determine OpenCL type of group type"); @@ -211,7 +211,7 @@ std::vector opencl_ast::operator()(alloca_inst const &a) { "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); } auto result_var = declare(*a.result()); - auto t = dynamic_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty().get()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } @@ -1052,13 +1052,13 @@ std::vector opencl_ast::operator()(yield_inst const &in) { if (yielded_vars_.empty()) { throw compilation_error(in.loc(), status::ir_unexpected_yield); } - if (yielded_vars_.back().size() != in.vals().size()) { + if (static_cast(yielded_vars_.back().size()) != in.num_operands()) { throw compilation_error(in.loc(), status::ir_yield_mismatch); } std::vector clinst; - for (std::size_t i = 0; i < in.vals().size(); ++i) { + for (std::int64_t i = 0; i < in.num_operands(); ++i) { clinst.push_back(clir::expression_statement( - clir::assignment(yielded_vars_.back()[i], visit(*this, *in.vals()[i])))); + clir::assignment(yielded_vars_.back()[i], visit(*this, *in.op(i))))); } return clinst; } diff --git a/src/visitor/slot_tracker.cpp b/src/visitor/slot_tracker.cpp index 663e9162..439b60a5 100644 --- a/src/visitor/slot_tracker.cpp +++ b/src/visitor/slot_tracker.cpp @@ -2,15 +2,12 @@ // SPDX-License-Identifier: BSD-3-Clause #include "visitor/slot_tracker.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include - #include #include -using clir::visit; - namespace tinytc { void slot_tracker::set_slot(value_node const &v) { diff --git a/src/visitor/stack.cpp b/src/visitor/stack.cpp index aa646bc5..2b291b88 100644 --- a/src/visitor/stack.cpp +++ b/src/visitor/stack.cpp @@ -4,21 +4,19 @@ #include "visitor/stack.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include - #include -using clir::visit; - namespace tinytc { /* Inst nodes */ void stack_ptr::operator()(inst_node &) {} void stack_ptr::operator()(alloca_inst &a) { - auto t = dynamic_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty().get()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } diff --git a/src/visitor/work_group_size.cpp b/src/visitor/work_group_size.cpp index c67cfb57..3ad600cd 100644 --- a/src/visitor/work_group_size.cpp +++ b/src/visitor/work_group_size.cpp @@ -6,22 +6,20 @@ #include "error.hpp" #include "node/data_type_node.hpp" #include "node/value_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include - #include #include #include #include -using clir::visit; - namespace tinytc { auto get_memref_type(value_node &v) { - auto t = dynamic_cast(v.ty().get()); + auto t = dyn_cast(v.ty().get()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); } diff --git a/src/ze/kernel.cpp b/src/ze/kernel.cpp index 3c2c573e..23bd099c 100644 --- a/src/ze/kernel.cpp +++ b/src/ze/kernel.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "../compiler_options.hpp" -#include "../util.hpp" +#include "../support/util.hpp" #include "error.hpp" #include "opencl_cc.hpp" #include "tinytc/tinytc.h" From 77e1f1686c24ec7673e9f09f74cd26b65ef1cdc5 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 9 Sep 2024 16:39:27 +0200 Subject: [PATCH 013/297] Refactor results and child regions Signed-off-by: Carsten Uphoff --- src/inst.cpp | 7 +- src/node/data_type_node.cpp | 2 +- src/node/data_type_node.hpp | 24 +- src/node/function_node.hpp | 18 +- src/node/inst_node.cpp | 80 +++--- src/node/inst_node.hpp | 498 ++++++++++++++++++------------------ src/node/program_node.hpp | 14 +- src/node/region_node.hpp | 14 +- src/node/value_node.hpp | 22 +- src/parser/parser_impl.yy | 6 +- src/support/visit.hpp | 5 +- src/visitor/opencl_ast.cpp | 2 +- 12 files changed, 349 insertions(+), 343 deletions(-) diff --git a/src/inst.cpp b/src/inst.cpp index 220c06d9..8cb37e8c 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -500,13 +500,10 @@ tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *resu } auto const num = static_cast(num_results); if (*result_list_size > 0) { - auto results = instr->results(); - if (results.size() != num_results) { - throw internal_compiler_error(); - } + auto results = instr->result_begin(); auto const limit = std::min(num, *result_list_size); for (uint32_t i = 0; i < limit; ++i) { - result_list[i] = results[i].release(); + result_list[i] = value(results[i]).release(); } } *result_list_size = num; diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index 4fefc97d..fb7ea76a 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -12,7 +12,7 @@ namespace tinytc { memref_data_type::memref_data_type(scalar_type type, std::vector shape, std::vector stride, location const &lc) - : data_type_node(DTK_memref), element_ty_(std::move(type)), shape_(std::move(shape)), + : data_type_node(DTK::memref), element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)) { loc(lc); for (auto const &s : shape_) { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index cc0929e2..c4bd0c8e 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -19,23 +19,23 @@ #include namespace tinytc { +enum class DTK { group, memref, scalar, void_ }; using data_type_nodes = type_list; -} +} // namespace tinytc struct tinytc_data_type : tinytc::reference_counted { public: - enum data_type_kind { DTK_group, DTK_memref, DTK_scalar, DTK_void }; using leaves = tinytc::data_type_nodes; - inline tinytc_data_type(std::int64_t tid) : tid_(tid) {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline tinytc_data_type(tinytc::DTK tid) : tid_(tid) {} + inline auto type_id() const -> tinytc::DTK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: - std::int64_t tid_; + tinytc::DTK tid_; tinytc::location loc_; }; @@ -45,9 +45,9 @@ using data_type_node = ::tinytc_data_type; class group_data_type : public data_type_node { public: - inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_group; } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } inline group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}) - : data_type_node(DTK_group), ty_(std::move(ty)), offset_(offset) { + : data_type_node(DTK::group), ty_(std::move(ty)), offset_(offset) { loc(lc); } @@ -61,13 +61,13 @@ class group_data_type : public data_type_node { class void_data_type : public data_type_node { public: - inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_void; } - inline void_data_type() : data_type_node(DTK_void) {} + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } + inline void_data_type() : data_type_node(DTK::void_) {} }; class memref_data_type : public data_type_node { public: - inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_memref; } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } memref_data_type(scalar_type type, std::vector shape, std::vector stride = {}, location const &lc = {}); @@ -106,9 +106,9 @@ class memref_data_type : public data_type_node { class scalar_data_type : public data_type_node { public: - inline static bool classof(data_type_node const &d) { return d.type_id() == DTK_scalar; } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::scalar; } inline scalar_data_type(scalar_type type, location const &lc) - : data_type_node(DTK_scalar), ty_(type) { + : data_type_node(DTK::scalar), ty_(type) { loc(lc); } diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index aa6cc567..aed5717a 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -16,17 +16,17 @@ #include namespace tinytc { +enum class FK { function, prototype }; using function_nodes = type_list; -} +} // namespace tinytc struct tinytc_func : tinytc::reference_counted { public: - enum function_kind { FK_function, FK_prototype }; using leaves = tinytc::function_nodes; - inline tinytc_func(std::int64_t tid) : tid_(tid) {} + inline tinytc_func(tinytc::FK tid) : tid_(tid) {} inline virtual ~tinytc_func() {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline auto type_id() const -> tinytc::FK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -34,7 +34,7 @@ struct tinytc_func : tinytc::reference_counted { virtual auto name() const -> std::string_view = 0; private: - std::int64_t tid_; + tinytc::FK tid_; tinytc::location loc_; }; @@ -44,9 +44,9 @@ using function_node = ::tinytc_func; class prototype : public function_node { public: - inline static bool classof(function_node const &f) { return f.type_id() == FK_prototype; } + inline static bool classof(function_node const &f) { return f.type_id() == FK::prototype; } inline prototype(std::string name, std::vector args = {}, location const &lc = {}) - : function_node(FK_prototype), name_(std::move(name)), args_(std::move(args)) { + : function_node(FK::prototype), name_(std::move(name)), args_(std::move(args)) { loc(lc); } @@ -60,9 +60,9 @@ class prototype : public function_node { class function : public function_node { public: - inline static bool classof(function_node const &f) { return f.type_id() == FK_function; } + inline static bool classof(function_node const &f) { return f.type_id() == FK::function; } inline function(func prototype, region body, location const &lc = {}) - : function_node(FK_function), prototype_(std::move(prototype)), body_(std::move(body)), + : function_node(FK::function), prototype_(std::move(prototype)), body_(std::move(body)), work_group_size_{0, 0}, subgroup_size_{0} { loc(lc); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index b20928be..39b1e4f3 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -35,7 +35,7 @@ memref_data_type *get_memref_type(location const &loc, value const &v) { return m; } -blas_a2_inst::blas_a2_inst(std::int64_t tid, value alpha, value A, value beta, value B, bool atomic) +blas_a2_inst::blas_a2_inst(IK tid, value alpha, value A, value beta, value B, bool atomic) : standard_inst{tid}, atomic_(atomic) { op(op_alpha) = std::move(alpha); op(op_A) = std::move(A); @@ -43,8 +43,7 @@ blas_a2_inst::blas_a2_inst(std::int64_t tid, value alpha, value A, value beta, v op(op_B) = std::move(B); } -blas_a3_inst::blas_a3_inst(std::int64_t tid, value alpha, value A, value B, value beta, value C, - bool atomic) +blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, value C, bool atomic) : standard_inst{tid}, atomic_(atomic) { op(op_alpha) = std::move(alpha); op(op_A) = std::move(A); @@ -53,13 +52,14 @@ blas_a3_inst::blas_a3_inst(std::int64_t tid, value alpha, value A, value B, valu op(op_C) = std::move(C); } -loop_inst::loop_inst(std::int64_t tid, value loop_var0, value from0, value to0, value step0, - region body, location const &lc) - : standard_inst{tid}, body_(std::move(body)) { +loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step0, region body, + location const &lc) + : standard_inst{tid} { op(op_loop_var) = std::move(loop_var0); op(op_from) = std::move(from0); op(op_to) = std::move(to0); op(op_step) = std::move(step0); + child_region(0) = std::move(body); loc(lc); auto lvt = get_scalar_type(loc(), loop_var()); @@ -77,9 +77,11 @@ loop_inst::loop_inst(std::int64_t tid, value loop_var0, value from0, value to0, } alloca_inst::alloca_inst(data_type ty, location const &lc) - : standard_inst{IK_alloca}, result_{make_value(std::move(ty))}, stack_ptr_{-1} { + : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); - auto memref = dyn_cast(result_->ty().get()); + + result(0) = make_value(std::move(ty)); + auto memref = dyn_cast(result(0)->ty().get()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } @@ -88,7 +90,7 @@ alloca_inst::alloca_inst(data_type ty, location const &lc) axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : blas_a2_inst(IK_axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + : blas_a2_inst(IK::axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { loc(lc); @@ -112,7 +114,7 @@ axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value } arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const &lc) - : standard_inst{IK_arith}, operation_(operation) { + : standard_inst{IK::arith}, operation_(operation) { op(op_a) = std::move(a0); op(op_b) = std::move(b0); loc(lc); @@ -143,11 +145,11 @@ arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result_ = make_value(at->ty()); + result(0) = make_value(at->ty()); } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, location const &lc) - : standard_inst{IK_arith_unary}, operation_(operation) { + : standard_inst{IK::arith_unary}, operation_(operation) { op(op_a) = std::move(a0); loc(lc); @@ -164,17 +166,18 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, locatio if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result_ = make_value(at->ty()); + result(0) = make_value(at->ty()); } -cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) - : standard_inst{IK_cast}, result_{make_value(to_ty)} { +cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) : standard_inst{IK::cast} { op(op_a) = std::move(a); loc(lc); + + result(0) = make_value(std::move(to_ty)); } compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location const &lc) - : standard_inst{IK_compare}, cond_(cond), result_{make_value(scalar_type::i1)} { + : standard_inst{IK::compare}, cond_(cond) { op(op_a) = std::move(a0); op(op_b) = std::move(b0); loc(lc); @@ -185,11 +188,13 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons if (at->ty() != bt->ty()) { throw compilation_error(loc(), status::ir_scalar_mismatch); } + + result(0) = make_value(scalar_type::i1); } expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const &expand_shape0, location const &lc) - : standard_inst{IK_expand, static_cast(1 + expand_shape0.size())}, mode_(mode) { + : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, mode_(mode) { op(0) = std::move(op0); for (std::size_t i = 0; i < expand_shape0.size(); ++i) { op(1 + i) = expand_shape0[i]; @@ -279,11 +284,11 @@ expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const auto r = std::make_unique(m->element_ty(), shape, stride); r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); + result(0) = make_value(data_type(r.release())); } fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location const &lc) - : standard_inst{IK_fuse}, from_(from), to_(to) { + : standard_inst{IK::fuse}, from_(from), to_(to) { op(0) = std::move(op0); loc(lc); auto m = get_memref_type(loc(), operand()); @@ -317,11 +322,11 @@ fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location con auto r = std::make_unique(m->element_ty(), shape, stride); r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); + result(0) = make_value(data_type(r.release())); } load_inst::load_inst(value op0, std::vector const &index_list0, location const &lc) - : standard_inst{IK_load, static_cast(1 + index_list0.size())} { + : standard_inst{IK::load, static_cast(1 + index_list0.size())} { op(0) = std::move(op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { op(1 + i) = index_list0[i]; @@ -333,13 +338,13 @@ load_inst::load_inst(value op0, std::vector const &index_list0, location if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result_ = make_value(g.ty()); + result(0) = make_value(g.ty()); }, [&](memref_data_type &m) { if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result_ = make_value(m.element_ty()); + result(0) = make_value(m.element_ty()); }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, *operand()->ty()); @@ -347,7 +352,7 @@ load_inst::load_inst(value op0, std::vector const &index_list0, location gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : blas_a3_inst(IK_gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + : blas_a3_inst(IK::gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic), tA_(tA), tB_(tB) { loc(lc); @@ -377,7 +382,7 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : blas_a3_inst(IK_gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + : blas_a3_inst(IK::gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic), tA_(tA) { loc(lc); @@ -405,7 +410,7 @@ gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0 ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : blas_a3_inst(IK_ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + : blas_a3_inst(IK::ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic) { loc(lc); auto a = get_memref_type(loc(), A()); @@ -431,7 +436,7 @@ ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) - : blas_a3_inst(IK_hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + : blas_a3_inst(IK::hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic) { loc(lc); auto a = get_memref_type(loc(), A()); @@ -456,17 +461,18 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu if_inst::if_inst(value condition, region then, region otherwise, std::vector const &return_types, location const &lc) - : standard_inst{IK_if, 1, static_cast(return_types.size())}, then_(std::move(then)), - otherwise_(std::move(otherwise)) { + : standard_inst{IK::if_, 1, static_cast(return_types.size())} { op(0) = std::move(condition); + child_region(child_region_then) = std::move(then); + child_region(child_region_otherwise) = std::move(otherwise); loc(lc); - for (auto &ty : return_types) { - results_.push_back(make_value(ty)); + for (std::size_t i = 0; i < return_types.size(); ++i) { + result(i) = make_value(return_types[i]); } } size_inst::size_inst(value op0, std::int64_t mode, location const &lc) - : standard_inst{IK_size}, mode_(mode) { + : standard_inst{IK::size}, mode_(mode) { op(0) = std::move(op0); loc(lc); auto m = get_memref_type(loc(), operand()); @@ -475,12 +481,12 @@ size_inst::size_inst(value op0, std::int64_t mode, location const &lc) throw compilation_error(loc(), status::ir_out_of_bounds); } - result_ = make_value(scalar_type::index); + result(0) = make_value(scalar_type::index); } subview_inst::subview_inst(value op0, std::vector const &offset_list0, std::vector const &size_list0, location const &lc) - : standard_inst{IK_subview, + : standard_inst{IK::subview, static_cast(1 + offset_list0.size() + size_list0.size())} { op(0) = std::move(op0); { @@ -546,12 +552,12 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, auto r = std::make_unique(m->element_ty(), shape, stride); r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); + result(0) = make_value(data_type(r.release())); } store_inst::store_inst(value val0, value op0, std::vector const &index_list0, location const &lc) - : standard_inst{IK_store, static_cast(2 + index_list0.size())} { + : standard_inst{IK::store, static_cast(2 + index_list0.size())} { op(op_val) = std::move(val0); op(op_operand) = std::move(op0); { @@ -576,7 +582,7 @@ store_inst::store_inst(value val0, value op0, std::vector const &index_li sum_inst::sum_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, location const &lc) - : blas_a2_inst(IK_sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + : blas_a2_inst(IK::sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { loc(lc); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index f1de7416..94840760 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -27,7 +27,47 @@ enum class inst_execution_kind { spmd ///< SPMD instruction on varying data }; - +enum class IK { + alloca, + arith, + arith_unary, + barrier, + cast, + compare, + expand, + fuse, + load, + group_id, + group_size, + lifetime_stop, + if_, + num_subgroups, + parallel, + size, + subgroup_id, + subgroup_local_id, + subgroup_size, + subview, + store, + yield, + // blas a2 + blas_a2, + axpby_blas_a2, + sum_blas_a2, + last_blas_a2, + // blas a3 + blas_a3, + gemm_blas_a3, + gemv_blas_a3, + ger_blas_a3, + hadamard_blas_a3, + last_blas_a3, + // loop inst + loop, + for_loop, + foreach_loop, + last_loop +}; using inst_nodes = type_list; -using op_range = iterator_range_wrapper; -using const_op_range = iterator_range_wrapper; -using result_range = iterator_range_wrapper; -using const_result_range = iterator_range_wrapper; +using value_range = iterator_range_wrapper; +using const_value_range = iterator_range_wrapper; +using region_range = iterator_range_wrapper; +using const_region_range = iterator_range_wrapper; } // namespace tinytc struct tinytc_inst : tinytc::reference_counted { public: - enum inst_kind { - IK_alloca, - IK_arith, - IK_arith_unary, - IK_barrier, - IK_cast, - IK_compare, - IK_expand, - IK_fuse, - IK_load, - IK_group_id, - IK_group_size, - IK_lifetime_stop, - IK_if, - IK_num_subgroups, - IK_parallel, - IK_size, - IK_subgroup_id, - IK_subgroup_local_id, - IK_subgroup_size, - IK_subview, - IK_store, - IK_yield, - // blas a2 - IK_blas_a2, - IK_axpby_blas_a2, - IK_sum_blas_a2, - IK_last_blas_a2, - // blas a3 - IK_blas_a3, - IK_gemm_blas_a3, - IK_gemv_blas_a3, - IK_ger_blas_a3, - IK_hadamard_blas_a3, - IK_last_blas_a3, - // loop inst - IK_loop, - IK_for_loop, - IK_foreach_loop, - IK_last_loop - }; using leaves = tinytc::inst_nodes; - inline tinytc_inst(std::int64_t tid) : tid_(tid), op_begin_(nullptr), op_end_(nullptr) {} - inline virtual ~tinytc_inst() {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline tinytc_inst(tinytc::IK tid) : tid_(tid) {} + inline auto type_id() const -> tinytc::IK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -100,93 +98,192 @@ struct tinytc_inst : tinytc::reference_counted { // Iterator over operands inline auto op_begin() -> tinytc::value * { return op_begin_; } inline auto op_end() -> tinytc::value * { return op_end_; } - inline auto operands() -> tinytc::op_range { return tinytc::op_range{op_begin_, op_end_}; } + inline auto operands() -> tinytc::value_range { + return tinytc::value_range{op_begin_, op_end_}; + } inline auto op_begin() const -> tinytc::value const * { return op_begin_; } inline auto op_end() const -> tinytc::value const * { return op_end_; } - inline auto operands() const -> tinytc::const_op_range { - return tinytc::const_op_range{op_begin_, op_end_}; + inline auto operands() const -> tinytc::const_value_range { + return tinytc::const_value_range{op_begin_, op_end_}; } inline auto op(std::size_t pos) -> tinytc::value & { return op_begin_[pos]; } inline auto op(std::size_t pos) const -> tinytc::value const & { return op_begin_[pos]; } inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } - virtual tinytc::value result() const = 0; - inline virtual auto results() const -> std::vector { - if (auto r = result(); r) { - return {std::move(r)}; - } - return {}; + // Iterator over results + inline auto result_begin() -> tinytc::value * { return result_begin_; } + inline auto result_end() -> tinytc::value * { return result_end_; } + inline auto results() -> tinytc::value_range { + return tinytc::value_range{result_begin_, result_end_}; + } + inline auto result_begin() const -> tinytc::value const * { return result_begin_; } + inline auto result_end() const -> tinytc::value const * { return result_end_; } + inline auto results() const -> tinytc::const_value_range { + return tinytc::const_value_range{result_begin_, result_end_}; + } + inline auto result() const -> tinytc::value { + return num_results() > 0 ? result_begin_[0] : tinytc::value{}; + } + inline auto result(std::size_t pos) -> tinytc::value & { return result_begin_[pos]; } + inline auto result(std::size_t pos) const -> tinytc::value const & { + return result_begin_[pos]; + } + inline auto num_results() const -> std::int64_t { return result_end_ - result_begin_; } + + // Iterator over regions + inline auto child_regions_begin() -> tinytc::region * { return child_regions_begin_; } + inline auto child_regions_end() -> tinytc::region * { return child_regions_end_; } + inline auto child_regions() -> tinytc::region_range { + return tinytc::region_range{child_regions_begin_, child_regions_end_}; + } + inline auto child_regions_begin() const -> tinytc::region const * { + return child_regions_begin_; + } + inline auto child_regions_end() const -> tinytc::region const * { return child_regions_end_; } + inline auto child_regions() const -> tinytc::const_region_range { + return tinytc::const_region_range{child_regions_begin_, child_regions_end_}; + } + inline auto child_region(std::size_t pos) -> tinytc::region & { + return child_regions_begin_[pos]; + } + inline auto child_region(std::size_t pos) const -> tinytc::region const & { + return child_regions_begin_[pos]; + } + inline auto num_child_regions() const -> std::int64_t { + return child_regions_end_ - child_regions_begin_; + } + + inline constexpr auto kind() const -> tinytc::inst_execution_kind { + switch (type_id()) { + case tinytc::IK::alloca: + case tinytc::IK::barrier: + case tinytc::IK::lifetime_stop: + case tinytc::IK::foreach_loop: + case tinytc::IK::parallel: + case tinytc::IK::blas_a2: + case tinytc::IK::axpby_blas_a2: + case tinytc::IK::sum_blas_a2: + case tinytc::IK::last_blas_a2: + case tinytc::IK::blas_a3: + case tinytc::IK::gemm_blas_a3: + case tinytc::IK::gemv_blas_a3: + case tinytc::IK::ger_blas_a3: + case tinytc::IK::hadamard_blas_a3: + case tinytc::IK::last_blas_a3: + return tinytc::inst_execution_kind::collective; + case tinytc::IK::arith: + case tinytc::IK::arith_unary: + case tinytc::IK::cast: + case tinytc::IK::compare: + case tinytc::IK::expand: + case tinytc::IK::fuse: + case tinytc::IK::load: + case tinytc::IK::group_id: + case tinytc::IK::group_size: + case tinytc::IK::if_: + case tinytc::IK::num_subgroups: + case tinytc::IK::size: + case tinytc::IK::subgroup_size: + case tinytc::IK::subview: + case tinytc::IK::store: + case tinytc::IK::yield: + case tinytc::IK::loop: + case tinytc::IK::for_loop: + case tinytc::IK::last_loop: + return tinytc::inst_execution_kind::mixed; + case tinytc::IK::subgroup_id: + case tinytc::IK::subgroup_local_id: + return tinytc::inst_execution_kind::spmd; + }; + throw tinytc::internal_compiler_error(); } - inline virtual auto num_results() const -> std::size_t { return result() ? 1u : 0u; } - virtual tinytc::inst_execution_kind kind() const = 0; protected: inline auto op_range(tinytc::value *begin, tinytc::value *end) { op_begin_ = begin; op_end_ = end; } + inline auto result_range(tinytc::value *begin, tinytc::value *end) { + result_begin_ = begin; + result_end_ = end; + } + inline auto child_regions_range(tinytc::region *begin, tinytc::region *end) { + child_regions_begin_ = begin; + child_regions_end_ = end; + } private: - std::int64_t tid_; + tinytc::IK tid_; tinytc::location loc_; - tinytc::value *op_begin_, *op_end_; + tinytc::value *op_begin_ = nullptr, *op_end_ = nullptr, *result_begin_ = nullptr, + *result_end_ = nullptr; + tinytc::region *child_regions_begin_ = nullptr, *child_regions_end_ = nullptr; }; namespace tinytc { using inst_node = ::tinytc_inst; -template class value_container { +template class object_container { public: - value_container(std::int64_t num_values) { - if (num_values != NumValues) { + object_container(std::int64_t num_objects) { + if (num_objects != NumObjects) { throw internal_compiler_error(); } } - inline auto get() -> tinytc::value * { - if constexpr (NumValues == 0) { + inline auto get() -> T * { + if constexpr (NumObjects == 0) { return nullptr; } - return ops_.data(); + return objs_.data(); } private: - std::array ops_; + std::array objs_; }; -template <> class value_container { +template class object_container { public: - value_container(std::int64_t num_values) : ops_{std::make_unique(num_values)} {} + object_container(std::int64_t num_objects) : objs_{std::make_unique(num_objects)} {} - auto get() -> tinytc::value * { return ops_.get(); } + auto get() -> T * { return objs_.get(); } private: - std::unique_ptr ops_; + std::unique_ptr objs_; }; -template +template class standard_inst : public inst_node { public: - standard_inst(std::int64_t tid, std::int64_t num_operands = NumOperands, - std::int64_t num_results = NumResults) - : inst_node{tid}, ops_{num_operands}, results_{num_results} { + standard_inst(IK tid, std::int64_t num_operands = NumOperands, + std::int64_t num_results = NumResults, + std::int64_t num_child_regions = NumChildRegions) + : inst_node{tid}, ops_{num_operands}, results_{num_results}, + child_regions_{num_child_regions} { if (num_operands > 0) { op_range(ops_.get(), ops_.get() + num_operands); } + if (num_results > 0) { + result_range(results_.get(), results_.get() + num_results); + } + if (num_child_regions > 0) { + child_regions_range(child_regions_.get(), child_regions_.get() + num_child_regions); + } } private: - value_container ops_; - value_container results_; + object_container ops_; + object_container results_; + object_container child_regions_; }; -class blas_a2_inst : public standard_inst<4, 1> { +class blas_a2_inst : public standard_inst<4, 0> { public: inline static bool classof(inst_node const &i) { - return i.type_id() >= IK_blas_a2 && i.type_id() <= IK_last_blas_a2; + return i.type_id() >= IK::blas_a2 && i.type_id() <= IK::last_blas_a2; } enum op_number { op_alpha = 0, op_A = 1, op_beta = 2, op_B = 3 }; - blas_a2_inst(std::int64_t tid, value alpha, value A, value beta, value B, bool atomic); + blas_a2_inst(IK tid, value alpha, value A, value beta, value B, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } @@ -194,20 +291,18 @@ class blas_a2_inst : public standard_inst<4, 1> { inline auto A() const -> value const & { return op(op_A); } inline auto beta() const -> value const & { return op(op_beta); } inline auto B() const -> value const & { return op(op_B); } - inline value result() const override { return value{}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } protected: bool atomic_; }; -class blas_a3_inst : public standard_inst<5, 1> { +class blas_a3_inst : public standard_inst<5, 0> { public: inline static bool classof(inst_node const &i) { - return i.type_id() >= IK_blas_a3 && i.type_id() <= IK_last_blas_a3; + return i.type_id() >= IK::blas_a3 && i.type_id() <= IK::last_blas_a3; } enum op_number { op_alpha = 0, op_A = 1, op_B = 2, op_beta = 3, op_C = 4 }; - blas_a3_inst(std::int64_t tid, value alpha, value A, value B, value beta, value C, bool atomic); + blas_a3_inst(IK tid, value alpha, value A, value B, value beta, value C, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } @@ -216,50 +311,41 @@ class blas_a3_inst : public standard_inst<5, 1> { inline auto B() const -> value const & { return op(op_B); } inline auto beta() const -> value const & { return op(op_beta); } inline auto C() const -> value const & { return op(op_C); } - inline value result() const override { return value{}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } protected: bool atomic_; }; -class loop_inst : public standard_inst<4, 1> { +class loop_inst : public standard_inst<4, 0, 1> { public: inline static bool classof(inst_node const &i) { - return i.type_id() >= IK_loop && i.type_id() <= IK_last_loop; + return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; } enum op_number { op_loop_var = 0, op_from = 1, op_to = 2, op_step = 3 }; - loop_inst(std::int64_t tid, value loop_var, value from, value to, value step, region body, + loop_inst(IK tid, value loop_var, value from, value to, value step, region body, location const &loc = {}); inline auto loop_var() const -> value const & { return op(op_loop_var); } inline auto from() const -> value const & { return op(op_from); } inline auto to() const -> value const & { return op(op_to); } inline auto step() const -> value const & { return op(op_step); } - inline auto body() const -> region const & { return body_; } - inline value result() const override { return value{}; } - - private: - region body_; + inline auto body() const -> region const & { return child_region(0); } }; class alloca_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_alloca; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::alloca; } alloca_inst(data_type ty, location const &loc = {}); - inline value result() const override { return result_; } inline std::int64_t stack_ptr() const { return stack_ptr_; } inline void stack_ptr(std::int64_t ptr) { stack_ptr_ = ptr; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } private: - value result_; std::int64_t stack_ptr_; }; class axpby_inst : public blas_a2_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_axpby_blas_a2; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::axpby_blas_a2; } axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, location const &lc = {}); @@ -271,79 +357,62 @@ class axpby_inst : public blas_a2_inst { class arith_inst : public standard_inst<2, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_arith; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith; } enum op_number { op_a = 0, op_b = 1 }; arith_inst(arithmetic op, value a, value b, location const &lc = {}); inline arithmetic operation() const { return operation_; } inline auto a() const -> value const & { return op(op_a); } inline auto b() const -> value const & { return op(op_b); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: arithmetic operation_; - value result_; }; class arith_unary_inst : public standard_inst<1, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_arith_unary; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith_unary; } enum op_number { op_a = 0 }; arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } inline auto a() const -> value const & { return op(op_a); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: arithmetic_unary operation_; - value result_; }; class barrier_inst : public standard_inst<0, 0> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_barrier; } - inline barrier_inst() : standard_inst{IK_barrier} {} - - inline value result() const override { return value{}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::barrier; } + inline barrier_inst() : standard_inst{IK::barrier} {} }; class cast_inst : public standard_inst<1, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_cast; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } enum op_number { op_a = 0 }; cast_inst(value a, scalar_type to_ty, location const &lc = {}); inline auto a() const -> value const & { return op(op_a); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; class compare_inst : public standard_inst<2, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_compare; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::compare; } enum op_number { op_a = 0, op_b = 1 }; compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } inline auto a() const -> value const & { return op(op_a); } inline auto b() const -> value const & { return op(op_b); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: cmp_condition cond_; - value result_; }; class expand_inst : public standard_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_expand; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } expand_inst(value op, std::int64_t mode, std::vector const &expand_shape, location const &lc = {}); @@ -352,86 +421,63 @@ class expand_inst : public standard_inst { inline auto expand_shape() { return operands() | std::views::drop(1); } inline auto expand_shape() const { return operands() | std::views::drop(1); } inline auto expand_shape(std::int64_t i) const -> value const & { return op(i + 1); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: - value result_; std::int64_t mode_; }; class fuse_inst : public standard_inst<1, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_fuse; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: - value result_; std::int64_t from_, to_; }; class load_inst : public standard_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_load; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } load_inst(value op, std::vector const &index_list, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline auto index_list() const { return operands() | std::views::drop(1); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; class group_id_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_group_id; } - inline group_id_inst(location const &lc = {}) - : standard_inst{IK_group_id}, result_{make_value(scalar_type::index)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_id; } + inline group_id_inst(location const &lc = {}) : standard_inst{IK::group_id} { loc(lc); + result(0) = make_value(scalar_type::index); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; class group_size_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_group_size; } - inline group_size_inst(location const &lc = {}) - : standard_inst{IK_group_size}, result_{make_value(scalar_type::index)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_size; } + inline group_size_inst(location const &lc = {}) : standard_inst{IK::group_size} { loc(lc); + result(0) = make_value(scalar_type::index); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; -class lifetime_stop_inst : public standard_inst<1, 1> { +class lifetime_stop_inst : public standard_inst<1, 0> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_lifetime_stop; } - inline lifetime_stop_inst(value obj) : standard_inst{IK_lifetime_stop} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } + inline lifetime_stop_inst(value obj) : standard_inst{IK::lifetime_stop} { op(0) = std::move(obj); } inline auto object() const -> value const & { return op(0); } - inline value result() const override { return value{}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } }; class gemm_inst : public blas_a3_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_gemm_blas_a3; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemm_blas_a3; } gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); @@ -444,7 +490,7 @@ class gemm_inst : public blas_a3_inst { class gemv_inst : public blas_a3_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_gemv_blas_a3; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemv_blas_a3; } gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); @@ -456,162 +502,125 @@ class gemv_inst : public blas_a3_inst { class ger_inst : public blas_a3_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_ger_blas_a3; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::ger_blas_a3; } ger_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); }; class for_inst : public loop_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_for_loop; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } inline for_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : loop_inst{ - IK_for_loop, std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), - loc} {} + : loop_inst{IK::for_loop, + std::move(loop_var), + std::move(from), + std::move(to), + {}, + std::move(body), + loc} {} inline for_inst(value loop_var, value from, value to, value step, region body, location const &loc = {}) - : loop_inst{IK_for_loop, + : loop_inst{IK::for_loop, std::move(loop_var), std::move(from), std::move(to), std::move(step), std::move(body), loc} {} - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; class foreach_inst : public loop_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_foreach_loop; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } inline foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : loop_inst{IK_foreach_loop, + : loop_inst{IK::foreach_loop, std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), loc} {} - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } }; class hadamard_inst : public blas_a3_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_hadamard_blas_a3; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::hadamard_blas_a3; } hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, location const &lc = {}); }; -class if_inst : public standard_inst<1, dynamic> { +class if_inst : public standard_inst<1, dynamic, 2> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_if; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::if_; } + enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; if_inst(value condition, region then, region otherwise = {}, std::vector const &return_types = {}, location const &lc = {}); inline auto condition() const -> value const & { return op(0); } - inline auto then() const -> region const & { return then_; } - inline auto otherwise() const -> region const & { return otherwise_; } - inline value result() const override { - return results_.size() > 0 ? results_.front() : value{}; - } - inline auto results() const -> std::vector override { return results_; } - inline auto num_results() const -> std::size_t override { return results_.size(); } - inline auto results_ref() -> std::vector & { return results_; } - inline auto results_ref() const -> std::vector const & { return results_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - region then_, otherwise_; - std::vector results_; + inline auto then() const -> region const & { return child_region(child_region_then); } + inline auto otherwise() const -> region const & { return child_region(child_region_otherwise); } }; class num_subgroups_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_num_subgroups; } - inline num_subgroups_inst(location const &lc = {}) - : standard_inst{IK_num_subgroups}, result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::num_subgroups; } + inline num_subgroups_inst(location const &lc = {}) : standard_inst{IK::num_subgroups} { loc(lc); + result(0) = make_value(scalar_type::i32); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; -class parallel_inst : public standard_inst<0, 0> { +class parallel_inst : public standard_inst<0, 0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_parallel; } - inline parallel_inst(region body, location const &lc = {}) - : standard_inst{IK_parallel}, body_(std::move(body)) { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } + inline parallel_inst(region body, location const &lc = {}) : standard_inst{IK::parallel} { + child_region(0) = std::move(body); loc(lc); } - inline auto body() const -> region const & { return body_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::collective; } - inline value result() const override { return value{}; } - - private: - region body_; + inline auto body() const -> region const & { return child_region(0); } }; class size_inst : public standard_inst<1, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_size; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } size_inst(value op, std::int64_t mode, location const &lc = {}); inline auto operand() const -> value const & { return op(0); } inline std::int64_t mode() const { return mode_; } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } private: - value result_; std::int64_t mode_; }; class subgroup_id_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_id; } - inline subgroup_id_inst(location const &lc = {}) - : standard_inst{IK_subgroup_id}, result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_id; } + inline subgroup_id_inst(location const &lc = {}) : standard_inst{IK::subgroup_id} { loc(lc); + result(0) = make_value(scalar_type::i32); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::spmd; } - - private: - value result_; }; class subgroup_local_id_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_local_id; } - inline subgroup_local_id_inst(location const &lc = {}) - : standard_inst{IK_subgroup_local_id}, result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_local_id; } + inline subgroup_local_id_inst(location const &lc = {}) : standard_inst{IK::subgroup_local_id} { loc(lc); + result(0) = make_value(scalar_type::i32); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::spmd; } - - private: - value result_; }; class subgroup_size_inst : public standard_inst<0, 1> { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_subgroup_size; } - inline subgroup_size_inst(location const &lc = {}) - : standard_inst{IK_subgroup_size}, result_{make_value(scalar_type::i32)} { + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_size; } + inline subgroup_size_inst(location const &lc = {}) : standard_inst{IK::subgroup_size} { loc(lc); + result(0) = make_value(scalar_type::i32); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; class subview_inst : public standard_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_subview; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } subview_inst(value op, std::vector const &offset_list, std::vector const &size_list, location const &lc = {}); @@ -622,29 +631,22 @@ class subview_inst : public standard_inst { return operands() | std::views::drop(1) | std::views::take(num_indices()); } inline auto size_list() const { return operands() | std::views::drop(1 + num_indices()); } - inline value result() const override { return result_; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } - - private: - value result_; }; class store_inst : public standard_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_store; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::store; } enum op_number { op_val = 0, op_operand = 1 }; store_inst(value val, value op, std::vector const &index_list, location const &lc = {}); inline auto val() const -> value const & { return op(op_val); } inline auto operand() const -> value const & { return op(op_operand); } inline auto index_list() const { return operands() | std::views::drop(2); } - inline value result() const override { return {}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; class sum_inst : public blas_a2_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_sum_blas_a2; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::sum_blas_a2; } sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, location const &lc = {}); @@ -656,16 +658,14 @@ class sum_inst : public blas_a2_inst { class yield_inst : public standard_inst { public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK_yield; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::yield; } inline yield_inst(std::vector const &vals, location const &lc = {}) - : standard_inst{IK_yield, static_cast(vals.size())} { + : standard_inst{IK::yield, static_cast(vals.size())} { loc(lc); for (std::size_t i = 0; i < vals.size(); ++i) { op(i) = vals[i]; } } - inline value result() const override { return value{}; } - inline inst_execution_kind kind() const override { return inst_execution_kind::mixed; } }; } // namespace tinytc diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 7136f872..9050d1d5 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -14,22 +14,22 @@ #include namespace tinytc { +enum class PK { prog }; using program_nodes = type_list; -} +} // namespace tinytc struct tinytc_prog : tinytc::reference_counted { public: - enum prog_kind { PK_prog }; using leaves = tinytc::program_nodes; - inline tinytc_prog(std::int64_t tid) : tid_(tid) {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline tinytc_prog(tinytc::PK tid) : tid_(tid) {} + inline auto type_id() const -> tinytc::PK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: - std::int64_t tid_; + tinytc::PK tid_; tinytc::location loc_; }; @@ -39,9 +39,9 @@ using program_node = ::tinytc_prog; class program : public program_node { public: - inline static bool classof(program_node const &p) { return p.type_id() == PK_prog; } + inline static bool classof(program_node const &p) { return p.type_id() == PK::prog; } inline program(std::vector decls, location const &lc = {}) - : program_node(PK_prog), decls_(std::move(decls)) { + : program_node(PK::prog), decls_(std::move(decls)) { loc(lc); } inline auto declarations() -> std::vector & { return decls_; } diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index c100b54c..a9dca35d 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -13,22 +13,22 @@ #include namespace tinytc { +enum class RK { rgn }; using region_nodes = type_list; -} +} // namespace tinytc struct tinytc_region : tinytc::reference_counted { public: - enum region_kind { RK_rgn }; using leaves = tinytc::region_nodes; - inline tinytc_region(std::int64_t tid) : tid_(tid) {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline tinytc_region(tinytc::RK tid) : tid_(tid) {} + inline auto type_id() const -> tinytc::RK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } private: - std::int64_t tid_; + tinytc::RK tid_; tinytc::location loc_; }; @@ -38,9 +38,9 @@ using region_node = ::tinytc_region; class rgn : public region_node { public: - inline static bool classof(region_node const &r) { return r.type_id() == RK_rgn; } + inline static bool classof(region_node const &r) { return r.type_id() == RK::rgn; } inline rgn(std::vector insts = {}, location const &lc = {}) - : region_node(RK_rgn), insts_(std::move(insts)) { + : region_node(RK::rgn), insts_(std::move(insts)) { loc(lc); } diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index de342b35..6e3cf561 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -13,17 +13,17 @@ #include namespace tinytc { +enum class VK { float_, int_, val }; using value_nodes = type_list; -} +} // namespace tinytc struct tinytc_value : tinytc::reference_counted { public: - enum value_kind { VK_float, VK_int, VK_val }; using leaves = tinytc::value_nodes; - inline tinytc_value(std::int64_t tid) : tid_(tid) {} + inline tinytc_value(tinytc::VK tid) : tid_(tid) {} inline virtual ~tinytc_value() {} - inline auto type_id() const -> std::int64_t { return tid_; } + inline auto type_id() const -> tinytc::VK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -35,7 +35,7 @@ struct tinytc_value : tinytc::reference_counted { virtual auto has_name() const -> bool = 0; private: - std::int64_t tid_; + tinytc::VK tid_; tinytc::location loc_; }; @@ -45,9 +45,9 @@ using value_node = ::tinytc_value; class float_imm : public value_node { public: - inline static bool classof(value_node const &v) { return v.type_id() == VK_float; } + inline static bool classof(value_node const &v) { return v.type_id() == VK::float_; } inline float_imm(double v, scalar_type ty = scalar_type::f64, location const &lc = {}) - : value_node(VK_float), ty_{make_scalar(ty)}, value_(v) { + : value_node(VK::float_), ty_{make_scalar(ty)}, value_(v) { loc(lc); } @@ -66,9 +66,9 @@ class float_imm : public value_node { class int_imm : public value_node { public: - inline static bool classof(value_node const &v) { return v.type_id() == VK_int; } + inline static bool classof(value_node const &v) { return v.type_id() == VK::int_; } inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64, location const &lc = {}) - : value_node(VK_int), ty_{make_scalar(ty)}, value_(v) { + : value_node(VK::int_), ty_{make_scalar(ty)}, value_(v) { loc(lc); } @@ -87,8 +87,8 @@ class int_imm : public value_node { class val : public value_node { public: - inline static bool classof(value_node const &v) { return v.type_id() == VK_val; } - inline val(data_type ty, location const &lc = {}) : value_node(VK_val), ty_(std::move(ty)) { + inline static bool classof(value_node const &v) { return v.type_id() == VK::val; } + inline val(data_type ty, location const &lc = {}) : value_node(VK::val), ty_(std::move(ty)) { loc(lc); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 41ac60c7..28f761d4 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -595,13 +595,13 @@ var_definition: $$->result()->name($identifier_list[0]); ctx.val($identifier_list[0], $$->result(), @identifier_list); } else { - auto results = $$->results(); - if (results.size() != $identifier_list.size()) { + auto results = $$->result_begin(); + if ($$->num_results() != static_cast($identifier_list.size())) { throw syntax_error( @identifier_list, "Number of identifiers does not equal number of returned values"); } - for (std::size_t i = 0; i < results.size(); ++i) { + for (std::int64_t i = 0; i < $$->num_results(); ++i) { results[i]->name($identifier_list[i]); ctx.val($identifier_list[i], results[i], @identifier_list); } diff --git a/src/support/visit.hpp b/src/support/visit.hpp index 2c4c9231..a49b2042 100644 --- a/src/support/visit.hpp +++ b/src/support/visit.hpp @@ -42,10 +42,13 @@ constexpr auto unflatten(std::index_sequence) { } } // namespace detail +template +concept type_id_return_type = std::is_integral_v || std::is_enum_v; + template concept visitable = requires(T ty) { typename T::leaves; - { ty.type_id() } -> std::integral; + { ty.type_id() } -> type_id_return_type; }; /** diff --git a/src/visitor/opencl_ast.cpp b/src/visitor/opencl_ast.cpp index d33b411b..aebf3568 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/visitor/opencl_ast.cpp @@ -841,7 +841,7 @@ std::vector opencl_ast::operator()(hadamard_inst const &g) { std::vector opencl_ast::operator()(if_inst const &in) { auto clinst = std::vector{}; yielded_vars_.push_back(std::vector{}); - for (auto const &r : in.results_ref()) { + for (auto const &r : in.results()) { auto v = declare(*r); clinst.emplace_back(clir::declaration(visit(*this, *r->ty()), v)); yielded_vars_.back().emplace_back(std::move(v)); From d7dd527e3cdbd20d6072dbd9f04808dd6f45a48e Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 11 Sep 2024 19:10:25 +0200 Subject: [PATCH 014/297] Big ongoing refactoring effort of compiler passes Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.yaml | 1 - docs/api/builder_cxxapi.yaml | 1 - docs/api/core_capi.yaml | 2 + docs/api/core_cxxapi.yaml | 2 + include/tinytc/tinytc.h | 52 ++- include/tinytc/tinytc.hpp | 51 ++- src/CMakeLists.txt | 27 +- src/codegen_tools.cpp | 162 +++++++- src/codegen_tools.hpp | 25 ++ src/compiler.cpp | 58 ++- src/func.cpp | 35 +- src/node/data_type_node.hpp | 16 +- src/node/function_node.hpp | 80 ++-- src/node/program_node.hpp | 47 +-- src/node/region_node.hpp | 52 ++- src/parser/parse_context.cpp | 12 +- src/parser/parse_context.hpp | 6 +- src/parser/parser_impl.yy | 6 +- src/{visitor => pass}/aa_results.cpp | 2 +- src/{visitor => pass}/aa_results.hpp | 0 src/{visitor => pass}/alias_analysis.cpp | 2 +- src/{visitor => pass}/alias_analysis.hpp | 2 +- src/pass/check_ir.cpp | 39 ++ src/pass/check_ir.hpp | 24 ++ src/pass/constant_propagation.cpp | 172 +++++++++ src/pass/constant_propagation.hpp | 40 ++ src/pass/dump_ir.cpp | 388 +++++++++++++++++++ src/{visitor => pass}/dump_ir.hpp | 21 +- src/{visitor => pass}/equal.cpp | 2 +- src/{visitor => pass}/equal.hpp | 0 src/{visitor => pass}/insert_barrier.cpp | 8 +- src/{visitor => pass}/insert_barrier.hpp | 2 +- src/{visitor => pass}/lifetime_analysis.cpp | 8 +- src/{visitor => pass}/lifetime_analysis.hpp | 2 +- src/pass/lower_linalg.cpp | 194 ++++++++++ src/pass/lower_linalg.hpp | 63 +++ src/{visitor => pass}/metadata.cpp | 8 +- src/{visitor => pass}/metadata.hpp | 1 - src/{visitor => pass}/opencl_ast.cpp | 10 +- src/{visitor => pass}/opencl_ast.hpp | 0 src/pass/slot_tracker.cpp | 37 ++ src/{visitor => pass}/slot_tracker.hpp | 16 +- src/{visitor => pass}/stack.cpp | 6 +- src/{visitor => pass}/stack.hpp | 0 src/{visitor => pass}/work_group_size.cpp | 6 +- src/{visitor => pass}/work_group_size.hpp | 0 src/passes.cpp | 47 ++- src/passes.def | 5 + src/passes.hpp | 19 +- src/support/walk.cpp | 26 ++ src/support/walk.hpp | 65 ++++ src/visitor/check_ir.cpp | 63 --- src/visitor/check_ir.hpp | 39 -- src/visitor/dump_ir.cpp | 403 -------------------- src/visitor/slot_tracker.cpp | 71 ---- test/CMakeLists.txt | 2 +- test/lit.cfg.py | 1 + test/lit.site.cfg.py.in | 1 + test/{codegen => opt/check-ir}/nesting0.ir | 2 +- test/{codegen => opt/check-ir}/nesting1.ir | 2 +- test/{codegen => opt/check-ir}/nesting2.ir | 2 +- test/{codegen => opt/check-ir}/nesting3.ir | 2 +- tools/CMakeLists.txt | 1 + tools/opt/CMakeLists.txt | 16 + tools/opt/args.cpp | 104 +++++ tools/opt/args.hpp | 26 ++ tools/opt/main.cpp | 55 +++ 67 files changed, 1765 insertions(+), 875 deletions(-) rename src/{visitor => pass}/aa_results.cpp (96%) rename src/{visitor => pass}/aa_results.hpp (100%) rename src/{visitor => pass}/alias_analysis.cpp (98%) rename src/{visitor => pass}/alias_analysis.hpp (97%) create mode 100644 src/pass/check_ir.cpp create mode 100644 src/pass/check_ir.hpp create mode 100644 src/pass/constant_propagation.cpp create mode 100644 src/pass/constant_propagation.hpp create mode 100644 src/pass/dump_ir.cpp rename src/{visitor => pass}/dump_ir.hpp (87%) rename src/{visitor => pass}/equal.cpp (96%) rename src/{visitor => pass}/equal.hpp (100%) rename src/{visitor => pass}/insert_barrier.cpp (96%) rename src/{visitor => pass}/insert_barrier.hpp (98%) rename src/{visitor => pass}/lifetime_analysis.cpp (97%) rename src/{visitor => pass}/lifetime_analysis.hpp (98%) create mode 100644 src/pass/lower_linalg.cpp create mode 100644 src/pass/lower_linalg.hpp rename src/{visitor => pass}/metadata.cpp (79%) rename src/{visitor => pass}/metadata.hpp (94%) rename src/{visitor => pass}/opencl_ast.cpp (99%) rename src/{visitor => pass}/opencl_ast.hpp (100%) create mode 100644 src/pass/slot_tracker.cpp rename src/{visitor => pass}/slot_tracker.hpp (61%) rename src/{visitor => pass}/stack.cpp (95%) rename src/{visitor => pass}/stack.hpp (100%) rename src/{visitor => pass}/work_group_size.cpp (97%) rename src/{visitor => pass}/work_group_size.hpp (100%) create mode 100644 src/passes.def create mode 100644 src/support/walk.cpp create mode 100644 src/support/walk.hpp delete mode 100644 src/visitor/check_ir.cpp delete mode 100644 src/visitor/check_ir.hpp delete mode 100644 src/visitor/dump_ir.cpp delete mode 100644 src/visitor/slot_tracker.cpp rename test/{codegen => opt/check-ir}/nesting0.ir (86%) rename test/{codegen => opt/check-ir}/nesting1.ir (80%) rename test/{codegen => opt/check-ir}/nesting2.ir (77%) rename test/{codegen => opt/check-ir}/nesting3.ir (80%) create mode 100644 tools/opt/CMakeLists.txt create mode 100644 tools/opt/args.cpp create mode 100644 tools/opt/args.hpp create mode 100644 tools/opt/main.cpp diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 9447ba9b..71b07c27 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -44,7 +44,6 @@ Builder C-API: Function: function: - tinytc_function_create - - tinytc_function_prototype_create - tinytc_function_set_subgroup_size - tinytc_function_set_work_group_size - tinytc_func_release diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 7b5f1eeb..5cbe0308 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -37,7 +37,6 @@ Builder C++-API: Function: function: - tinytc::make_function - - tinytc::make_function_prototype - tinytc::set_work_group_size - tinytc::set_subgroup_size class: diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 0a2e97c6..7f8a4f9d 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -40,6 +40,8 @@ Core C-API: enum: - tinytc_bundle_format_t function: + - tinytc_run_function_pass + - tinytc_list_function_passes - tinytc_prog_compile_to_opencl Device Info: enum: diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 990f9cf6..9c0caa41 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -21,6 +21,8 @@ Core C++-API: - tinytc::binary Compiler: function: + - tinytc::run_function_pass + - tinytc::list_function_passes - tinytc::compile_to_opencl Device Info: enum: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 1d9200c1..f38b3b8b 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -839,36 +839,21 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_retain(tinytc_region_t reg); //////////////////////////// /** - * @brief Create function prototype + * @brief Create function * * @param fun [out] pointer to the func object created * @param name [in] function name * @param arg_list_size [in] length of argument array - * @param arg_list [in][range(0, arg_list_size)] argument array; can be nullptr if arg_list_size is - * 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, - tinytc_value_t *arg_list, - const tinytc_location_t *loc); - -/** - * @brief Create function - * - * @param fun [out] pointer to the func object created - * @param prototype [in] function prototype + * @param arg_list [in][range(0,arg_list_size)] argument array; can be nullptr if arg_list_size is 0 * @param body [in] function body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, +TINYTC_EXPORT tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, + uint32_t arg_list_size, + tinytc_value_t *arg_list, tinytc_region_t body, const tinytc_location_t *loc); - /** * @brief Set work-group size * @@ -1229,6 +1214,33 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_context_retain(tinytc_source_context ///////// Compiler ///////// //////////////////////////// +/** + * @brief Run a function pass on every function of a program + * + * @param pass_name [in] name of function pass; cf. tinytc_list_function_passes + * @param prg [inout] tensor program; modified as compiler pass is run + * @param info [in][optional] core info object; might be nullptr if core info is not required for + * pass + * @param ctx [inout][optional] source context object to save extended error messages that are + * enhanced with source code context; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info, + tinytc_source_context_t ctx); + +/** + * @brief List function passes + * + * @param names_size [out] pointer to number of function pass names + * @param names [out][range(0,names_size)] pointer to array of C-strings; array owned by tinytc + * + * @return + */ +TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, + char const *const **names); + /** * @brief Compile tensor language to OpenCL-C * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 6fc3f516..6bc03307 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1284,16 +1284,17 @@ constexpr bool func_reinterpret_allowed = } // namespace internal /** - * @brief Make function prototype + * @brief Make function * * @param name Function name * @param arg_list Argument list + * @param body Function body * @param loc Source code location * * @return Function */ -inline func make_function_prototype(char const *name, std::vector &arg_list, - location const &loc = {}) { +inline func make_function(char const *name, std::vector &arg_list, region const &body, + location const &loc = {}) { static_assert(internal::value_reinterpret_allowed); tinytc_func_t fun; auto len = arg_list.size(); @@ -1301,22 +1302,7 @@ inline func make_function_prototype(char const *name, std::vector &arg_li throw std::out_of_range("argument list too long"); } tinytc_value_t *al = reinterpret_cast(arg_list.data()); - CHECK_STATUS_LOC(tinytc_function_prototype_create(&fun, name, len, al, &loc), loc); - return func(fun); -} - -/** - * @brief Make function - * - * @param prototype Function prototype - * @param body Function body - * @param loc Source code location - * - * @return Function - */ -inline func make_function(func const &prototype, region const &body, location const &loc = {}) { - tinytc_func_t fun; - CHECK_STATUS_LOC(tinytc_function_create(&fun, prototype.get(), body.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_function_create(&fun, name, len, al, body.get(), &loc), loc); return func(fun); } @@ -1599,8 +1585,7 @@ class function_builder { * @return Function */ inline func get_product(location const &loc = {}) { - auto proto = make_function_prototype(name_.c_str(), arguments_, loc); - auto fun = make_function(proto, body_); + auto fun = make_function(name_.c_str(), arguments_, body_, loc); if (x_ > 0 && y_ > 0) { set_work_group_size(fun, x_, y_); } @@ -2049,6 +2034,30 @@ inline auto make_binary(bundle_format format, std::size_t data_size, std::uint8_ return binary{bin}; } +/** + * @brief Run a function pass on every function of a program + * + * @param pass_name name of function pass; cf. list_function_passes + * @param prg tensor program; modified as compiler pass is run + * @param info core info object; might be nullptr if core info is not required for pass + * @param ctx source context object to save extended error messages that are + * enhanced with source code context + */ +inline void run_function_pass(char const *pass_name, prog prg, core_info info = {}, + source_context ctx = {}) { + CHECK_STATUS(tinytc_run_function_pass(pass_name, prg.get(), info.get(), ctx.get())); +} + +/** + * @brief Get function pass names + * + * @param names_size Number of function pass names + * @param names Array of function pass names + */ +inline void list_function_passes(std::uint32_t &names_size, char const *const *&names) { + CHECK_STATUS(tinytc_list_function_passes(&names_size, &names)); +} + /** * @brief Compile program to OpenCL-C * diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7243d976..7e2ba6e9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -42,18 +42,21 @@ set(SOURCES source.cpp tiling.cpp value.cpp - visitor/aa_results.cpp - visitor/alias_analysis.cpp - visitor/check_ir.cpp - visitor/dump_ir.cpp - visitor/equal.cpp - visitor/insert_barrier.cpp - visitor/lifetime_analysis.cpp - visitor/metadata.cpp - visitor/opencl_ast.cpp - visitor/slot_tracker.cpp - visitor/stack.cpp - visitor/work_group_size.cpp + #pass/aa_results.cpp + #pass/alias_analysis.cpp + pass/check_ir.cpp + #pass/constant_propagation.cpp + pass/dump_ir.cpp + pass/equal.cpp + #pass/insert_barrier.cpp + #pass/lifetime_analysis.cpp + #pass/lower_linalg.cpp + #pass/metadata.cpp + #pass/opencl_ast.cpp + pass/slot_tracker.cpp + #pass/stack.cpp + #pass/work_group_size.cpp + support/walk.cpp ) set(RE2C_SOURCES parser/lexer.re diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index cdad4a86..b79b130d 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -3,8 +3,10 @@ #include "codegen_tools.hpp" #include "error.hpp" +#include "node/value_node.hpp" #include "scalar_type.hpp" -#include "util.hpp" +#include "support/util.hpp" +#include "support/visit.hpp" #include #include @@ -428,4 +430,162 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, } } +void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, + value const &sg_id, sgs_loop_body_builder_new const &body) { + visit(overloaded{ + [&](int_imm &c) { + tile_loop_by_sgs_new_constant(bb, c.value(), sgs, num_tiles, std::move(sg_id), + body); + }, + [&](auto &) { + tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, + std::move(sg_id), body); + }, + }, + *loop_trip_count); +} + +void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, + int num_tiles, value const &sg_id, + sgs_loop_body_builder_new const &body) { + std::int64_t blocks = loop_trip_count / sgs; + std::int64_t rem = loop_trip_count % sgs; + + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + if (blocks > 0) { + auto block_start = bb.add(make_arith(arithmetic::mul, make_index(sgs), sg_id_index)); + auto block_end = make_index(sgs * blocks); + auto step = make_index(sgs * num_tiles); + bb.for_loop( + scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), + [&](region_builder &bb, value const &block) { + body(bb, block, false, make_index(sgs)); + }, + "block"); + } + + if (rem > 0) { + auto condition = + bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_index(num_tiles - 1))); + bb.if_condition(condition, [&](region_builder &bb) { + body(bb, make_index(blocks * sgs), true, make_index(rem)); + }); + } +} + +void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, + int num_tiles, value const &sg_id, + sgs_loop_body_builder_new const &body) { + auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, make_index(sgs))); + auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, make_index(sgs))); + + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto block_start = bb.add(make_arith(arithmetic::mul, make_index(sgs), sg_id_index)); + auto block_end = bb.add(make_arith(arithmetic::mul, make_index(sgs), blocks)); + auto step = make_index(sgs * num_tiles); + bb.for_loop( + scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), + [&](region_builder &bb, value const &block) { body(bb, block, false, make_index(sgs)); }, + "block"); + + auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, make_index(0))); + bb.if_condition(condition0, [&](region_builder &bb) { + auto condition1 = + bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_index(num_tiles - 1))); + bb.if_condition(condition1, [&](region_builder &bb) { + auto block = bb.add(make_arith(arithmetic::mul, blocks, make_index(sgs))); + body(bb, block, true, rem); + }); + }); +} + +void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, int block_size, + int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body) { + visit( + overloaded{ + //[&](int_imm &c) { + // tile_loop_uniformly_new_constant(bb, c.value(), block_size, num_tiles, + // std::move(sg_id), body); + //}, + [&](auto &) { + tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, + num_tiles, std::move(sg_id), body); + }, + }, + *loop_trip_count); +} + +void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, + int block_size, int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body) { + // Find minimum number of blocks such that the block sizes are smaller or equal block_size + std::int64_t blocks = 1 + (loop_trip_count - 1) / block_size; + // Increase the number of blocks if such that the number of blocks is a multiple + // of the number of tiles + blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; + std::int64_t bs = loop_trip_count / blocks; + std::int64_t bs_1 = bs + 1; + std::int64_t rem = loop_trip_count % blocks; + + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + if (rem > 0) { + auto block_start = bb.add(make_arith(arithmetic::mul, make_index(bs_1), sg_id_index)); + auto block_end = make_index(bs_1 * rem); + auto step = make_index(bs_1 * num_tiles); + bb.for_loop( + scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), + [&](region_builder &bb, value const &block) { body(bb, block, make_index(bs_1)); }, + "block"); + } + + auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, make_index(rem % num_tiles))); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, make_index(num_tiles))); + auto tmp2 = bb.add(make_arith(arithmetic::mul, make_index(bs), sg_id_1)); + auto block_start = bb.add(make_arith(arithmetic::add, make_index(bs_1 * rem), tmp2)); + auto block_end = make_index(loop_trip_count); + auto step = make_index(bs * num_tiles); + bb.for_loop( + scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), + [&](region_builder &bb, value const &block) { body(bb, block, make_index(bs)); }, "block"); +} + +void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, + int block_size, int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body) { + auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, make_index(1))); + auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, make_index(block_size))); + auto blocks2 = bb.add(make_arith(arithmetic::add, make_index(1), blocks1)); + auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, make_index(1))); + auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, make_index(num_tiles))); + auto blocks5 = bb.add(make_arith(arithmetic::add, make_index(1), blocks4)); + auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, make_index(num_tiles))); + blocks->name("blocks"); + auto bs = bb.add(make_arith(arithmetic::div, loop_trip_count, blocks)); + bs->name("bs"); + auto bs_1 = bb.add(make_arith(arithmetic::add, bs, make_index(1))); + bs_1->name("bs_1"); + auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); + rem->name("rem"); + + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); + auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); + auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, make_index(num_tiles))); + bb.for_loop( + scalar_type::index, std::move(block_start_1), std::move(block_end_1), std::move(step_1), + [&](region_builder &bb, value const &block) { body(bb, block, bs_1); }, "block"); + + auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, make_index(num_tiles))); + auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, make_index(num_tiles))); + auto tmp2 = bb.add(make_arith(arithmetic::mul, bs, sg_id_1)); + auto tmp3 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); + auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); + auto step = bb.add(make_arith(arithmetic::mul, bs, make_index(num_tiles))); + bb.for_loop( + scalar_type::index, std::move(block_start), loop_trip_count, std::move(step), + [&](region_builder &bb, value const &block) { body(bb, block, bs); }, "block"); +} + } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index f27ee98b..a7037bc9 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -5,6 +5,7 @@ #define CODEGEN_TOOLS_20240229_HPP #include "device_info.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include @@ -121,6 +122,30 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, clir::expr beta, core_config const &core_cfg); +using sgs_loop_body_builder_new = + std::function; +using uniform_loop_body_builder_new = + std::function; + +void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, + value const &sg_id, sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, + int num_tiles, value const &sg_id, + sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, + int num_tiles, value const &sg_id, + sgs_loop_body_builder_new const &body); + +void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, int block_size, + int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body); +void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, + int block_size, int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body); +void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, + int block_size, int num_tiles, value const &sg_id, + uniform_loop_body_builder_new const &body); + } // namespace tinytc #endif // CODEGEN_TOOLS_20240229_HPP diff --git a/src/compiler.cpp b/src/compiler.cpp index 8493d738..74bcec12 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -5,6 +5,8 @@ #include "error.hpp" #include "node/program_node.hpp" #include "parser.hpp" +#include "pass/check_ir.hpp" +#include "pass/dump_ir.hpp" #include "passes.hpp" #include "reference_counted.hpp" #include "required_extensions.hpp" @@ -15,15 +17,51 @@ #include #include +#include #include #include #include #include +#include + using namespace tinytc; extern "C" { +tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info, + tinytc_source_context_t ctx) { + if (prg == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { +#define FUNCTION_PASS(NAME, CREATE_PASS) \ + if (strcmp(NAME, pass_name) == 0) { \ + return run_function_pass(CREATE_PASS, *prg); \ + } +#include "passes.def" +#undef FUNCTION_PASS + }, + ctx); +} + +tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *const **names) { + if (names_size == nullptr || names == nullptr) { + return tinytc_status_invalid_arguments; + } +#define FUNCTION_PASS(NAME, CREATE_PASS) NAME, + static char const *const pass_names[] = { +#include "passes.def" + }; +#undef FUNCTION_PASS + *names_size = sizeof(pass_names) / sizeof(char const *); + *names = pass_names; + + return tinytc_status_success; +} + tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, const_tinytc_core_info_t info, tinytc_source_context_t ctx) { @@ -33,13 +71,17 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ return exception_to_status_code( [&] { // passes - check_ir(*prg); - insert_lifetime_stop_inst(*prg); - set_stack_ptrs(*prg); - insert_barriers(*prg); - set_work_group_size(*prg, *info); - // opencl - auto ast = generate_opencl_ast(*prg, *info); + run_function_pass(check_ir_pass{}, *prg); + // insert_lifetime_stop_inst(*prg); + // set_stack_ptrs(*prg); + // insert_barriers(*prg); + // set_work_group_size(*prg, *info); + // lower_linalg(*prg, *info); + run_function_pass(dump_ir_pass{std::cout}, *prg); + // propagate_constants(*prg); + // dump_ir(std::cout, *prg); + // opencl + /*auto ast = generate_opencl_ast(*prg, *info); clir::make_names_unique(ast); auto oss = std::ostringstream{}; @@ -52,7 +94,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ *src = std::make_unique<::tinytc_source>(oss.str(), prg->loc(), std::move(ext), info->core_features()) - .release(); + .release();*/ }, ctx); } diff --git a/src/func.cpp b/src/func.cpp index 5cb2ea46..fcb12ac2 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -19,10 +19,10 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, tinytc_value_t *arg_list, - const tinytc_location_t *loc) { - if (fun == nullptr || (arg_list_size > 0 && arg_list == nullptr)) { +tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, uint32_t arg_list_size, + tinytc_value_t *arg_list, tinytc_region_t body, + const tinytc_location_t *loc) { + if (fun == nullptr || (arg_list_size > 0 && arg_list == nullptr) || body == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { @@ -31,37 +31,18 @@ tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const for (uint32_t i = 0; i < arg_list_size; ++i) { arg_vec.emplace_back(value(arg_list[i], true)); } - *fun = std::make_unique(std::string(name), std::move(arg_vec), get_optional(loc)) + *fun = std::make_unique(std::string(name), std::move(arg_vec), region{body, true}, + get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, const tinytc_location_t *loc) { - if (fun == nullptr || prototype == nullptr || body == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *fun = - std::make_unique(func{prototype, true}, region{body, true}, get_optional(loc)) - .release(); - }); -} - tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { - function *f = dyn_cast(fun); - if (f == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { f->work_group_size({x, y}); }); + return exception_to_status_code([&] { fun->work_group_size({x, y}); }); } tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { - function *f = dyn_cast(fun); - if (f == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { f->subgroup_size(sgs); }); + return exception_to_status_code([&] { fun->subgroup_size(sgs); }); } tinytc_status_t tinytc_func_release(tinytc_func_t obj) { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index c4bd0c8e..74d7d83e 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -20,8 +20,8 @@ namespace tinytc { enum class DTK { group, memref, scalar, void_ }; -using data_type_nodes = type_list; +using data_type_nodes = type_list; } // namespace tinytc struct tinytc_data_type : tinytc::reference_counted { @@ -59,12 +59,6 @@ class group_data_type : public data_type_node { std::int64_t offset_; }; -class void_data_type : public data_type_node { - public: - inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } - inline void_data_type() : data_type_node(DTK::void_) {} -}; - class memref_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } @@ -119,6 +113,12 @@ class scalar_data_type : public data_type_node { scalar_type ty_; }; +class void_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } + inline void_data_type() : data_type_node(DTK::void_) {} +}; + } // namespace tinytc #endif // DATA_TYPE_NODE_20230309_HPP diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index aed5717a..99be945c 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -6,7 +6,7 @@ #include "location.hpp" #include "reference_counted.hpp" -#include "support/type_list.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.hpp" #include @@ -16,63 +16,43 @@ #include namespace tinytc { -enum class FK { function, prototype }; -using function_nodes = type_list; +using value_range = iterator_range_wrapper; +using const_value_range = iterator_range_wrapper; } // namespace tinytc struct tinytc_func : tinytc::reference_counted { public: - using leaves = tinytc::function_nodes; - - inline tinytc_func(tinytc::FK tid) : tid_(tid) {} - inline virtual ~tinytc_func() {} - inline auto type_id() const -> tinytc::FK { return tid_; } + inline tinytc_func(std::string name, std::vector args, tinytc::region body, + tinytc::location const &lc = {}) + : name_(std::move(name)), args_(std::move(args)), body_(std::move(body)), + work_group_size_{0, 0}, subgroup_size_{0} { + loc(lc); + } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - virtual auto name() const -> std::string_view = 0; - - private: - tinytc::FK tid_; - tinytc::location loc_; -}; - -namespace tinytc { - -using function_node = ::tinytc_func; - -class prototype : public function_node { - public: - inline static bool classof(function_node const &f) { return f.type_id() == FK::prototype; } - inline prototype(std::string name, std::vector args = {}, location const &lc = {}) - : function_node(FK::prototype), name_(std::move(name)), args_(std::move(args)) { - loc(lc); + inline auto arg_begin() -> tinytc::value * { return args_.size() > 0 ? args_.data() : nullptr; } + inline auto arg_end() -> tinytc::value * { + return args_.size() > 0 ? args_.data() + args_.size() : nullptr; } - - inline auto name() const -> std::string_view override { return name_; } - inline auto args() const -> std::vector const & { return args_; } - - private: - std::string name_; - std::vector args_; -}; - -class function : public function_node { - public: - inline static bool classof(function_node const &f) { return f.type_id() == FK::function; } - inline function(func prototype, region body, location const &lc = {}) - : function_node(FK::function), prototype_(std::move(prototype)), body_(std::move(body)), - work_group_size_{0, 0}, subgroup_size_{0} { - loc(lc); + inline auto args() -> tinytc::value_range { + return tinytc::value_range{arg_begin(), arg_end()}; + } + inline auto arg_begin() const -> tinytc::value const * { + return args_.size() > 0 ? args_.data() : nullptr; + } + inline auto arg_end() const -> tinytc::value const * { + return args_.size() > 0 ? args_.data() + args_.size() : nullptr; + } + inline auto args() const -> tinytc::const_value_range { + return tinytc::const_value_range{arg_begin(), arg_end()}; } - inline auto name() const -> std::string_view override { return prototype_->name(); } + inline auto name() const -> std::string_view { return name_; } + inline auto body() const -> tinytc::region const & { return body_; } - inline auto prototype() const -> func const & { return prototype_; } - inline auto body() const -> region const & { return body_; } inline auto work_group_size() const -> std::array { return work_group_size_; } - inline void work_group_size(std::array const &work_group_size) { work_group_size_ = work_group_size; } @@ -80,12 +60,18 @@ class function : public function_node { inline void subgroup_size(std::int32_t subgroup_size) { subgroup_size_ = subgroup_size; } private: - func prototype_; - region body_; + std::string name_; + std::vector args_; + tinytc::region body_; std::array work_group_size_; std::int32_t subgroup_size_; + tinytc::location loc_; }; +namespace tinytc { + +using function = ::tinytc_func; + } // namespace tinytc #endif // FUNCTION_NODE_20230310_HPP diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 9050d1d5..aa3a3dc1 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -6,7 +6,7 @@ #include "location.hpp" #include "reference_counted.hpp" -#include "support/type_list.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.hpp" #include @@ -14,42 +14,43 @@ #include namespace tinytc { -enum class PK { prog }; -using program_nodes = type_list; +using func_range = iterator_range_wrapper; +using const_func_range = iterator_range_wrapper; } // namespace tinytc struct tinytc_prog : tinytc::reference_counted { public: - using leaves = tinytc::program_nodes; - - inline tinytc_prog(tinytc::PK tid) : tid_(tid) {} - inline auto type_id() const -> tinytc::PK { return tid_; } + inline tinytc_prog(std::vector funcs, tinytc::location const &lc = {}) + : funcs_(std::move(funcs)) { + loc(lc); + } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + inline auto begin() -> tinytc::func * { return funcs_.size() > 0 ? funcs_.data() : nullptr; } + inline auto end() -> tinytc::func * { + return funcs_.size() > 0 ? funcs_.data() + funcs_.size() : nullptr; + } + inline auto functions() -> tinytc::func_range { return tinytc::func_range{begin(), end()}; } + inline auto begin() const -> tinytc::func const * { + return funcs_.size() > 0 ? funcs_.data() : nullptr; + } + inline auto end() const -> tinytc::func const * { + return funcs_.size() > 0 ? funcs_.data() + funcs_.size() : nullptr; + } + inline auto functions() const -> tinytc::const_func_range { + return tinytc::const_func_range{begin(), end()}; + } + private: - tinytc::PK tid_; + std::vector funcs_; tinytc::location loc_; }; namespace tinytc { -using program_node = ::tinytc_prog; - -class program : public program_node { - public: - inline static bool classof(program_node const &p) { return p.type_id() == PK::prog; } - inline program(std::vector decls, location const &lc = {}) - : program_node(PK::prog), decls_(std::move(decls)) { - loc(lc); - } - inline auto declarations() -> std::vector & { return decls_; } - inline auto declarations() const -> std::vector const & { return decls_; } - - private: - std::vector decls_; -}; +using program = ::tinytc_prog; } // namespace tinytc diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index a9dca35d..e6ce7945 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -5,7 +5,7 @@ #define REGION_NODE_20230908_HPP #include "reference_counted.hpp" -#include "support/type_list.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.hpp" #include @@ -13,45 +13,43 @@ #include namespace tinytc { -enum class RK { rgn }; -using region_nodes = type_list; +using inst_range = iterator_range_wrapper; +using const_inst_range = iterator_range_wrapper; } // namespace tinytc struct tinytc_region : tinytc::reference_counted { public: - using leaves = tinytc::region_nodes; - - inline tinytc_region(tinytc::RK tid) : tid_(tid) {} - inline auto type_id() const -> tinytc::RK { return tid_; } + inline tinytc_region(std::vector insts = {}, tinytc::location const &lc = {}) + : insts_(std::move(insts)) { + loc(lc); + } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - private: - tinytc::RK tid_; - tinytc::location loc_; -}; - -namespace tinytc { - -using region_node = ::tinytc_region; - -class rgn : public region_node { - public: - inline static bool classof(region_node const &r) { return r.type_id() == RK::rgn; } - inline rgn(std::vector insts = {}, location const &lc = {}) - : region_node(RK::rgn), insts_(std::move(insts)) { - loc(lc); + inline auto begin() -> tinytc::inst * { return insts_.size() > 0 ? insts_.data() : nullptr; } + inline auto end() -> tinytc::inst * { + return insts_.size() > 0 ? insts_.data() + insts_.size() : nullptr; } - - inline auto insts() -> std::vector & { return insts_; } - inline auto insts() const -> std::vector const & { return insts_; } - inline void insts(std::vector insts) { insts_ = std::move(insts); } + inline auto insts() -> tinytc::inst_range { return tinytc::inst_range{begin(), end()}; } + inline auto begin() const -> tinytc::inst const * { + return insts_.size() > 0 ? insts_.data() : nullptr; + } + inline auto end() const -> tinytc::inst const * { + return insts_.size() > 0 ? insts_.data() + insts_.size() : nullptr; + } + inline auto insts() const -> tinytc::const_inst_range { + return tinytc::const_inst_range{begin(), end()}; + } + inline void insts(std::vector insts) { insts_ = std::move(insts); } private: - std::vector insts_; + std::vector insts_; + tinytc::location loc_; }; +namespace tinytc { +using rgn = ::tinytc_region; } // namespace tinytc #endif // REGION_NODE_20230908_HPP diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 336c0226..5824dbf7 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -36,17 +36,17 @@ value parse_context::val(std::string const &id, location const &l) { throw parser::syntax_error(l, "Undefined identifier %" + id); } -void parse_context::prototype(std::string const &id, func p) { - if (auto other = prototype_map_.find(id); other != prototype_map_.end()) { +void parse_context::add_function(std::string const &id, func fn) { + if (auto other = function_map_.find(id); other != function_map_.end()) { auto oss = std::ostringstream{}; oss << "Identifier @" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(p->loc(), oss.str()); + throw parser::syntax_error(fn->loc(), oss.str()); } - prototype_map_[id] = std::move(p); + function_map_[id] = std::move(fn); } -func parse_context::prototype(std::string const &id, location const &l) { - if (auto j = prototype_map_.find(id); j != prototype_map_.end()) { +func parse_context::get_function(std::string const &id, location const &l) { + if (auto j = function_map_.find(id); j != function_map_.end()) { return j->second; } throw parser::syntax_error(l, "Undefined identifier @" + id); diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 832cc3a8..1a213aaf 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -25,8 +25,8 @@ class parse_context { void val(std::string const &id, value val, location const &l); value val(std::string const &id, location const &l); - void prototype(std::string const &id, func p); - func prototype(std::string const &id, location const &l); + void add_function(std::string const &id, func fn); + func get_function(std::string const &id, location const &l); void add_error(location const &loc, std::string const &what); @@ -36,7 +36,7 @@ class parse_context { private: std::vector> id_map_; - std::unordered_map prototype_map_; + std::unordered_map function_map_; prog program_; std::vector> errors_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 28f761d4..2db5bd6b 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -237,15 +237,13 @@ func: } GLOBAL_IDENTIFIER LPAREN arguments RPAREN attributes region { auto loc = @FUNC; loc.end = @RPAREN.end; - auto proto = func{ - std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), loc).release()}; - ctx.prototype($GLOBAL_IDENTIFIER, proto); auto func_node = - std::make_unique(std::move(proto), std::move($region), @func).release(); + std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), std::move($region), loc).release(); for (auto &attr : $attributes) { attr(*func_node); } $func = func{func_node}; + ctx.add_function($GLOBAL_IDENTIFIER, $func); ctx.pop_scope(); } ; diff --git a/src/visitor/aa_results.cpp b/src/pass/aa_results.cpp similarity index 96% rename from src/visitor/aa_results.cpp rename to src/pass/aa_results.cpp index dbfafba3..fbf8aa44 100644 --- a/src/visitor/aa_results.cpp +++ b/src/pass/aa_results.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/aa_results.hpp" +#include "pass/aa_results.hpp" #include "node/value_node.hpp" #include diff --git a/src/visitor/aa_results.hpp b/src/pass/aa_results.hpp similarity index 100% rename from src/visitor/aa_results.hpp rename to src/pass/aa_results.hpp diff --git a/src/visitor/alias_analysis.cpp b/src/pass/alias_analysis.cpp similarity index 98% rename from src/visitor/alias_analysis.cpp rename to src/pass/alias_analysis.cpp index 891647dd..b0d45023 100644 --- a/src/visitor/alias_analysis.cpp +++ b/src/pass/alias_analysis.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/alias_analysis.hpp" +#include "pass/alias_analysis.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "support/casting.hpp" diff --git a/src/visitor/alias_analysis.hpp b/src/pass/alias_analysis.hpp similarity index 97% rename from src/visitor/alias_analysis.hpp rename to src/pass/alias_analysis.hpp index 183b5c7b..9ae02e8d 100644 --- a/src/visitor/alias_analysis.hpp +++ b/src/pass/alias_analysis.hpp @@ -8,7 +8,7 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" -#include "visitor/aa_results.hpp" +#include "pass/aa_results.hpp" #include diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp new file mode 100644 index 00000000..6c2a65b1 --- /dev/null +++ b/src/pass/check_ir.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/check_ir.hpp" +#include "error.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +void check_ir_pass::run_on_function(function &fn) { + walk(fn, [this](inst_node const &i, walk_stage const &stage) { + const bool child_region_is_spmd_region = isa(i) || isa(i); + + if (stage.is_before_all_regions()) { + if (i.kind() == inst_execution_kind::collective && inside_spmd_region_) { + throw compilation_error(i.loc(), status::ir_collective_called_from_spmd); + } else if (i.kind() == inst_execution_kind::spmd && !inside_spmd_region_) { + throw compilation_error(i.loc(), status::ir_spmd_called_from_collective); + } + + if (child_region_is_spmd_region) { + inside_spmd_region_ = true; + } + } + + if (child_region_is_spmd_region && stage.is_after_all_regions()) { + inside_spmd_region_ = false; + } + }); +} + +} // namespace tinytc diff --git a/src/pass/check_ir.hpp b/src/pass/check_ir.hpp new file mode 100644 index 00000000..b6aad450 --- /dev/null +++ b/src/pass/check_ir.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CHECK_IR_20240222_HPP +#define CHECK_IR_20240222_HPP + +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/program_node.hpp" +#include "node/region_node.hpp" + +namespace tinytc { + +class check_ir_pass { + public: + void run_on_function(function &fn); + + private: + bool inside_spmd_region_ = false; +}; + +} // namespace tinytc + +#endif // CHECK_IR_20240222_HPP diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp new file mode 100644 index 00000000..68acd373 --- /dev/null +++ b/src/pass/constant_propagation.cpp @@ -0,0 +1,172 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_propagation.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/value_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" + +#include + +namespace tinytc { + +/* Inst nodes */ +void constant_propagation::operator()(inst_node &in) { + for (auto &op : in.operands()) { + if (op) { + uintptr_t u = std::bit_cast(op.get()); + if (auto kc = known_constants_.find(u); kc != known_constants_.end()) { + op = kc->second; + } + } + } +} + +void constant_propagation::operator()(arith_inst &arith) { + this->operator()(static_cast(arith)); + + auto const &a = arith.a(); + auto const &b = arith.b(); + + auto at = dyn_cast(a->ty().get()); + if (at == nullptr) { + throw compilation_error(a->loc(), status::ir_expected_scalar); + } + + if (is_floating_type(at->ty())) { + auto av = dyn_cast(a.get()); + auto bv = dyn_cast(b.get()); + if (av != nullptr && bv != nullptr) { + auto const compute = [&arith](auto a, auto b) { + switch (arith.operation()) { + case arithmetic::add: + return a + b; + case arithmetic::sub: + return a - b; + case arithmetic::mul: + return a * b; + case arithmetic::div: + return a / b; + case arithmetic::rem: + return std::fmod(a, b); + default: + break; + } + throw compilation_error(arith.loc(), status::ir_fp_unsupported); + }; + + auto constant_val = value{}; + switch (at->ty()) { + case scalar_type::f32: + constant_val = make_imm( + compute(static_cast(av->value()), static_cast(bv->value())), + scalar_type::f32, arith.loc()); + break; + case scalar_type::f64: + constant_val = + make_imm(compute(av->value(), bv->value()), scalar_type::f64, arith.loc()); + break; + default: + break; + }; + if (constant_val) { + uintptr_t u = std::bit_cast(arith.result().get()); + known_constants_[u] = std::move(constant_val); + } + } + } else { + auto av = dyn_cast(a.get()); + auto bv = dyn_cast(b.get()); + if (av != nullptr && bv != nullptr) { + auto const compute = [&arith](auto a, auto b) { + switch (arith.operation()) { + case arithmetic::add: + return a + b; + case arithmetic::sub: + return a - b; + case arithmetic::mul: + return a * b; + case arithmetic::div: + return a / b; + case arithmetic::rem: + return a % b; + case arithmetic::shl: + return a << b; + case arithmetic::shr: + return a >> b; + case arithmetic::and_: + return a & b; + case arithmetic::or_: + return a | b; + case arithmetic::xor_: + return a ^ b; + } + throw compilation_error(arith.loc(), status::runtime_error); + }; + + auto constant_val = value{}; + switch (at->ty()) { + case scalar_type::i1: { + bool const val = + compute(static_cast(av->value()), static_cast(bv->value())); + constant_val = + make_imm(static_cast(val), scalar_type::i1, arith.loc()); + break; + } + case scalar_type::i8: + constant_val = make_imm(compute(static_cast(av->value()), + static_cast(bv->value())), + arith.loc()); + break; + case scalar_type::i16: + constant_val = make_imm(compute(static_cast(av->value()), + static_cast(bv->value())), + arith.loc()); + break; + case scalar_type::i32: + constant_val = make_imm(compute(static_cast(av->value()), + static_cast(bv->value())), + arith.loc()); + break; + case scalar_type::i64: + constant_val = + make_imm(compute(av->value(), bv->value()), scalar_type::i64, arith.loc()); + break; + case scalar_type::index: + constant_val = + make_imm(compute(av->value(), bv->value()), scalar_type::index, arith.loc()); + break; + default: + break; + }; + if (constant_val) { + uintptr_t u = std::bit_cast(arith.result().get()); + known_constants_[u] = std::move(constant_val); + } + } + } +} + +void constant_propagation::operator()(parallel_inst &p) { visit(*this, *p.body()); } + +/* Region nodes */ +void constant_propagation::operator()(rgn &b) { + for (auto &s : b.insts()) { + visit(*this, *s); + } +} + +/* Function nodes */ +void constant_propagation::operator()(function &fn) { visit(*this, *fn.body()); } + +/* Program nodes */ +void constant_propagation::operator()(program &p) { + for (auto &fn : p.functions()) { + visit(*this, *fn); + } +} + +} // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp new file mode 100644 index 00000000..737e6929 --- /dev/null +++ b/src/pass/constant_propagation.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_PROPAGATION_20240807_HPP +#define CONSTANT_PROPAGATION_20240807_HPP + +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/program_node.hpp" +#include "node/region_node.hpp" + +#include +#include +#include + +namespace tinytc { + +class constant_propagation { + public: + /* Inst nodes */ + void operator()(inst_node &); + void operator()(arith_inst &arith); + void operator()(parallel_inst &p); + + /* Region nodes */ + void operator()(rgn &b); + + /* Func nodes */ + void operator()(function &fn); + + /* Program nodes */ + void operator()(program &p); + + private: + std::unordered_map known_constants_; +}; + +} // namespace tinytc + +#endif // CONSTANT_PROPAGATION_20240807_HPP diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp new file mode 100644 index 00000000..4cd250ad --- /dev/null +++ b/src/pass/dump_ir.cpp @@ -0,0 +1,388 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_ir.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +dump_ir_pass::dump_ir_pass(std::ostream &os) : os_(&os) {} + +/* Data type nodes */ +void dump_ir_pass::operator()(void_data_type const &) { *os_ << "void"; } +void dump_ir_pass::operator()(group_data_type const &g) { + *os_ << "group<"; + visit(*this, *g.ty()); + *os_ << ">"; +} +void dump_ir_pass::operator()(memref_data_type const &d) { + auto const val = [&](std::int64_t v) -> std::ostream & { + if (is_dynamic_value(v)) { + return *os_ << "?"; + } + return *os_ << v; + }; + *os_ << "memref<" << to_string(d.element_ty()); + for (auto const &s : d.shape()) { + *os_ << "x"; + val(s); + } + if (!d.is_canonical_stride()) { + *os_ << ",strided<"; + do_with_infix(d.stride().begin(), d.stride().end(), [&](auto const &a) { val(a); }); + *os_ << ">"; + } + *os_ << ">"; +} +void dump_ir_pass::operator()(scalar_data_type const &s) { *os_ << to_string(s.ty()); } + +/* Value nodes */ +void dump_ir_pass::operator()(float_imm const &v) { + auto flags = os_->flags(); + *os_ << std::hexfloat << v.value(); + os_->flags(flags); +} +void dump_ir_pass::operator()(int_imm const &v) { + if (is_dynamic_value(v.value())) { + *os_ << "?"; + } else { + *os_ << v.value(); + } +} +void dump_ir_pass::operator()(val const &v) { + *os_ << "%" << v.name(); + auto const slot = tracker_.get_slot(v); + if (slot >= 0) { + *os_ << slot; + } +} + +/* Inst nodes */ +void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { + visit(*this, *g.alpha()); + *os_ << ", "; + visit(*this, *g.A()); + *os_ << ", "; + visit(*this, *g.beta()); + *os_ << ", "; + visit(*this, *g.B()); + *os_ << " : "; + visit(*this, *g.alpha()->ty()); + *os_ << ", "; + visit(*this, *g.A()->ty()); + *os_ << ", "; + visit(*this, *g.beta()->ty()); + *os_ << ", "; + visit(*this, *g.B()->ty()); +} + +void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { + visit(*this, *g.alpha()); + *os_ << ", "; + visit(*this, *g.A()); + *os_ << ", "; + visit(*this, *g.B()); + *os_ << ", "; + visit(*this, *g.beta()); + *os_ << ", "; + visit(*this, *g.C()); + *os_ << " : "; + visit(*this, *g.alpha()->ty()); + *os_ << ", "; + visit(*this, *g.A()->ty()); + *os_ << ", "; + visit(*this, *g.B()->ty()); + *os_ << ", "; + visit(*this, *g.beta()->ty()); + *os_ << ", "; + visit(*this, *g.C()->ty()); +} + +void dump_ir_pass::operator()(alloca_inst const &a) { + visit(*this, *a.result()); + *os_ << " = alloca -> "; + visit(*this, *a.result()->ty()); +} + +void dump_ir_pass::operator()(axpby_inst const &a) { + *os_ << "axpby"; + *os_ << "." << to_string(a.tA()) << " "; + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(arith_inst const &a) { + visit(*this, *a.result()); + *os_ << " = arith." << to_string(a.operation()) << " "; + visit(*this, *a.a()); + *os_ << ", "; + visit(*this, *a.b()); + *os_ << " : "; + visit(*this, *a.a()->ty()); +} + +void dump_ir_pass::operator()(arith_unary_inst const &a) { + visit(*this, *a.result()); + *os_ << " = arith." << to_string(a.operation()) << " "; + visit(*this, *a.a()); + *os_ << " : "; + visit(*this, *a.a()->ty()); +} + +void dump_ir_pass::operator()(barrier_inst const &) { *os_ << "barrier"; } + +void dump_ir_pass::operator()(cast_inst const &c) { + visit(*this, *c.result()); + *os_ << " = cast "; + visit(*this, *c.a()); + *os_ << " : "; + visit(*this, *c.a()->ty()); + *os_ << " -> "; + visit(*this, *c.result()->ty()); +} + +void dump_ir_pass::operator()(compare_inst const &a) { + visit(*this, *a.result()); + *os_ << " = cmp." << to_string(a.cond()) << " "; + visit(*this, *a.a()); + *os_ << ", "; + visit(*this, *a.b()); + *os_ << " : "; + visit(*this, *a.a()->ty()); +} + +void dump_ir_pass::operator()(expand_inst const &e) { + visit(*this, *e.result()); + *os_ << " = expand "; + visit(*this, *e.operand()); + *os_ << "[" << e.mode() << "->"; + do_with_infix( + e.expand_shape().begin(), e.expand_shape().end(), + [this](auto const &i) { visit(*this, *i); }, "x"); + *os_ << "] : "; + visit(*this, *e.operand()->ty()); +} + +void dump_ir_pass::operator()(fuse_inst const &f) { + visit(*this, *f.result()); + *os_ << " = fuse "; + visit(*this, *f.operand()); + *os_ << "[" << f.from() << "," << f.to() << "]"; + *os_ << " : "; + visit(*this, *f.operand()->ty()); +} + +void dump_ir_pass::operator()(load_inst const &e) { + visit(*this, *e.result()); + *os_ << " = load "; + visit(*this, *e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { visit(*this, *i); }); + *os_ << "] : "; + visit(*this, *e.operand()->ty()); +} + +void dump_ir_pass::operator()(group_id_inst const &g) { + visit(*this, *g.result()); + *os_ << " = group_id"; +} + +void dump_ir_pass::operator()(group_size_inst const &g) { + visit(*this, *g.result()); + *os_ << " = group_size"; +} + +void dump_ir_pass::operator()(lifetime_stop_inst const &l) { + *os_ << "lifetime_stop "; + visit(*this, *l.object()); +} + +void dump_ir_pass::operator()(gemm_inst const &g) { + *os_ << "gemm"; + *os_ << "." << to_string(g.tA()); + *os_ << "." << to_string(g.tB()) << " "; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(gemv_inst const &g) { + *os_ << "gemv"; + *os_ << "." << to_string(g.tA()) << " "; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(ger_inst const &g) { + *os_ << "ger "; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(for_inst const &p) { + *os_ << "for "; + visit(*this, *p.loop_var()); + *os_ << "="; + visit(*this, *p.from()); + *os_ << ","; + visit(*this, *p.to()); + if (p.step()) { + *os_ << ","; + visit(*this, *p.step()); + } + *os_ << " : "; + visit(*this, *p.loop_var()->ty()); + *os_ << " "; + dump_region(*p.body()); +} + +void dump_ir_pass::operator()(foreach_inst const &p) { + *os_ << "foreach "; + visit(*this, *p.loop_var()); + *os_ << "="; + visit(*this, *p.from()); + *os_ << ","; + visit(*this, *p.to()); + *os_ << " : "; + visit(*this, *p.loop_var()->ty()); + *os_ << " "; + dump_region(*p.body()); +} + +void dump_ir_pass::operator()(hadamard_inst const &g) { + *os_ << "hadamard "; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(if_inst const &in) { + *os_ << "if "; + visit(*this, *in.condition()); + *os_ << " "; + dump_region(*in.then()); + if (in.otherwise()) { + *os_ << " else "; + dump_region(*in.otherwise()); + } +} + +void dump_ir_pass::operator()(num_subgroups_inst const &sg) { + visit(*this, *sg.result()); + *os_ << " = num_subgroups"; +} + +void dump_ir_pass::operator()(parallel_inst const &p) { + *os_ << "parallel "; + dump_region(*p.body()); +} + +void dump_ir_pass::operator()(size_inst const &s) { + visit(*this, *s.result()); + *os_ << " = size "; + visit(*this, *s.operand()); + *os_ << "[" << s.mode() << "]"; + *os_ << " : "; + visit(*this, *s.operand()->ty()); +} + +void dump_ir_pass::operator()(subgroup_id_inst const &sg) { + visit(*this, *sg.result()); + *os_ << " = subgroup_id"; +} + +void dump_ir_pass::operator()(subgroup_local_id_inst const &sg) { + visit(*this, *sg.result()); + *os_ << " = subgroup_local_id"; +} + +void dump_ir_pass::operator()(subgroup_size_inst const &sg) { + visit(*this, *sg.result()); + *os_ << " = subgroup_size"; +} + +void dump_ir_pass::operator()(subview_inst const &s) { + visit(*this, *s.result()); + *os_ << " = subview "; + visit(*this, *s.operand()); + *os_ << "["; + auto irange = std::ranges::iota_view{std::size_t{0}, s.offset_list().size()}; + do_with_infix(irange.begin(), irange.end(), [&](auto const &i) { + visit(*this, *s.offset_list()[i]); + auto &size = s.size_list()[i]; + if (size) { + *os_ << ":"; + visit(*this, *size); + } + }); + *os_ << "]"; + *os_ << " : "; + visit(*this, *s.operand()->ty()); + *os_ << " ; -> "; + visit(*this, *s.result()->ty()); +} + +void dump_ir_pass::operator()(store_inst const &e) { + *os_ << "store "; + visit(*this, *e.val()); + *os_ << ", "; + visit(*this, *e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { visit(*this, *i); }); + *os_ << "] : "; + visit(*this, *e.operand()->ty()); +} + +void dump_ir_pass::operator()(sum_inst const &a) { + *os_ << "sum"; + *os_ << "." << to_string(a.tA()) << " "; + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(yield_inst const &y) { + *os_ << "yield "; + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }); + *os_ << " : "; + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }); +} + +void dump_ir_pass::dump_region(rgn const ®) { + *os_ << "{" << std::endl; + ++lvl_; + auto ind = indent(); + for (auto const &i : reg) { + *os_ << ind; + visit(*this, *i); + *os_ << std::endl; + } + --lvl_; + *os_ << indent() << "}"; +} + +void dump_ir_pass::run_on_function(function const &fn) { + *os_ << "func @" << fn.name() << "("; + std::string infix = ",\n "; + infix += std::string(fn.name().size(), ' '); + do_with_infix( + fn.args().begin(), fn.args().end(), + [this](auto const &a) { + visit(*this, *a); + *os_ << ": "; + visit(*this, *a->ty()); + }, + infix); + *os_ << ") "; + auto const sgs = fn.subgroup_size(); + auto const wgs = fn.work_group_size(); + if (sgs != 0) { + *os_ << "subgroup_size(" << sgs << ") "; + } + if (wgs[0] != 0 && wgs[1] != 0) { + *os_ << "work_group_size(" << wgs[0] << "," << wgs[1] << ") "; + } + dump_region(*fn.body()); + *os_ << std::endl; +} + +} // namespace tinytc diff --git a/src/visitor/dump_ir.hpp b/src/pass/dump_ir.hpp similarity index 87% rename from src/visitor/dump_ir.hpp rename to src/pass/dump_ir.hpp index 6912fde3..f55f93b9 100644 --- a/src/visitor/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -10,16 +10,16 @@ #include "node/program_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" -#include "visitor/slot_tracker.hpp" +#include "pass/slot_tracker.hpp" #include #include namespace tinytc { -class ir_dumper { +class dump_ir_pass { public: - ir_dumper(std::ostream &os); + dump_ir_pass(std::ostream &os); /* Data type nodes */ void operator()(void_data_type const &); @@ -64,17 +64,10 @@ class ir_dumper { void operator()(sum_inst const &s); void operator()(yield_inst const &y); - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &p); - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); + void run_on_function(function const &fn); private: + void dump_region(rgn const ®); void dump_blas_a2(blas_a2_inst const &g); void dump_blas_a3(blas_a3_inst const &g); @@ -82,13 +75,13 @@ class ir_dumper { void do_with_infix(Iterator begin, Iterator end, Action a, std::string const &infix = ",") { for (auto it = begin; it != end; ++it) { if (it != begin) { - os_ << infix; + *os_ << infix; } a(*it); } } inline auto indent() { return std::string(2 * lvl_, ' '); } - std::ostream &os_; + std::ostream *os_; int lvl_ = 0; slot_tracker tracker_; diff --git a/src/visitor/equal.cpp b/src/pass/equal.cpp similarity index 96% rename from src/visitor/equal.cpp rename to src/pass/equal.cpp index 403331cf..010968d0 100644 --- a/src/visitor/equal.cpp +++ b/src/pass/equal.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/equal.hpp" +#include "pass/equal.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" diff --git a/src/visitor/equal.hpp b/src/pass/equal.hpp similarity index 100% rename from src/visitor/equal.hpp rename to src/pass/equal.hpp diff --git a/src/visitor/insert_barrier.cpp b/src/pass/insert_barrier.cpp similarity index 96% rename from src/visitor/insert_barrier.cpp rename to src/pass/insert_barrier.cpp index 01923a9d..8fa2bab8 100644 --- a/src/visitor/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -1,11 +1,11 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/insert_barrier.hpp" +#include "pass/insert_barrier.hpp" +#include "pass/alias_analysis.hpp" #include "support/casting.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include "visitor/alias_analysis.hpp" #include @@ -148,8 +148,8 @@ void insert_barrier::operator()(function &fn) { /* Program nodes */ void insert_barrier::operator()(program &p) { - for (auto &decl : p.declarations()) { - visit(*this, *decl); + for (auto &fn : p.functions()) { + visit(*this, *fn); } } diff --git a/src/visitor/insert_barrier.hpp b/src/pass/insert_barrier.hpp similarity index 98% rename from src/visitor/insert_barrier.hpp rename to src/pass/insert_barrier.hpp index d7524b84..9d9ac4b8 100644 --- a/src/visitor/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -10,7 +10,7 @@ #include "node/program_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" -#include "visitor/aa_results.hpp" +#include "pass/aa_results.hpp" #include diff --git a/src/visitor/lifetime_analysis.cpp b/src/pass/lifetime_analysis.cpp similarity index 97% rename from src/visitor/lifetime_analysis.cpp rename to src/pass/lifetime_analysis.cpp index dd1e6420..4fb5f365 100644 --- a/src/visitor/lifetime_analysis.cpp +++ b/src/pass/lifetime_analysis.cpp @@ -1,10 +1,10 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/lifetime_analysis.hpp" +#include "pass/lifetime_analysis.hpp" #include "node/value_node.hpp" +#include "pass/alias_analysis.hpp" #include "support/visit.hpp" -#include "visitor/alias_analysis.hpp" #include #include @@ -163,8 +163,8 @@ void lifetime_inserter::operator()(function &fn) { } void lifetime_inserter::operator()(program &p) { - for (auto &decl : p.declarations()) { - visit(*this, *decl); + for (auto &fn : p.functions()) { + visit(*this, *fn); } } diff --git a/src/visitor/lifetime_analysis.hpp b/src/pass/lifetime_analysis.hpp similarity index 98% rename from src/visitor/lifetime_analysis.hpp rename to src/pass/lifetime_analysis.hpp index 1cbb26a9..9d901a7e 100644 --- a/src/visitor/lifetime_analysis.hpp +++ b/src/pass/lifetime_analysis.hpp @@ -8,9 +8,9 @@ #include "node/inst_node.hpp" #include "node/program_node.hpp" #include "node/region_node.hpp" +#include "pass/aa_results.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "visitor/aa_results.hpp" #include #include diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp new file mode 100644 index 00000000..05b44707 --- /dev/null +++ b/src/pass/lower_linalg.cpp @@ -0,0 +1,194 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_linalg.hpp" +#include "codegen_tools.hpp" +#include "error.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" + +namespace tinytc { + +auto lower_linalg_pass::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty().get()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +/* Data type nodes */ +// bool lower_linalg_pass::operator()(void_data_type &) { return false; } +// bool lower_linalg_pass::operator()(group_data_type &b) { return visit(*this, *b.ty()); } +// bool lower_linalg_pass::operator()(memref_data_type &m) { +// return m.addrspace() == clir::address_space::local_t; +//} +// bool lower_linalg_pass::operator()(scalar_data_type &) { return false; } + +//[> Value nodes <] +// value_node *lower_linalg_pass::operator()(float_imm &) { return nullptr; } +// value_node *lower_linalg_pass::operator()(int_imm &) { return nullptr; } +// value_node *lower_linalg_pass::operator()(val &v) { +// if (visit(*this, *v.ty())) { +// return &v; +//} +// return nullptr; +//} + +/* Inst nodes */ +inst lower_linalg_pass::operator()(inst_node &) { return inst{nullptr}; } + +inst lower_linalg_pass::operator()(loop_inst &p) { + visit(*this, *p.body()); + return inst{nullptr}; +} + +inst lower_linalg_pass::operator()(if_inst &in) { + visit(*this, *in.then()); + if (in.otherwise()) { + visit(*this, *in.otherwise()); + } + return inst{nullptr}; +} + +inst lower_linalg_pass::operator()(parallel_inst &p) { + visit(*this, *p.body()); + return inst{nullptr}; +} + +inst lower_linalg_pass::operator()(ger_inst &g) { + // auto at = get_memref_type(*g.A()); + // auto bt = get_memref_type(*g.B()); + auto ct = get_memref_type(*g.C()); + + auto bb = region_builder{}; + auto sgid = bb.add(make_subgroup_id(g.loc())); + auto m_tiles_imm = make_imm(tiling_.m_tiles(), g.loc()); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, m_tiles_imm, g.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, m_tiles_imm, g.loc())); + auto m = bb.add(make_subgroup_local_id(g.loc())); + auto m_index = bb.add(make_cast(m, scalar_type::index, g.loc())); + + auto c_shape1 = is_dynamic_value(ct->shape(1)) ? bb.add(make_size(g.C(), 1, g.loc())) + : make_index(ct->shape(1), g.loc()); + auto c_shape0 = is_dynamic_value(ct->shape(0)) ? bb.add(make_size(g.C(), 0, g.loc())) + : make_index(ct->shape(0), g.loc()); + tile_loop_uniformly_new( + bb, std::move(c_shape1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), + [&](region_builder &bb, value block, value trip_count) { + bb.for_loop(scalar_type::index, make_index(0, g.loc()), trip_count, + [&](region_builder &bb, value const &n) { + auto nn = bb.add(make_arith(arithmetic::add, block, n, g.loc())); + auto b = bb.add(make_load(g.B(), {nn}, g.loc())); + b->name("b"); + tile_loop_by_sgs_new( + bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value const &block, bool is_remainder, + value const &inner_trip_count) { + auto mm = bb.add( + make_arith(arithmetic::add, block, m_index, g.loc())); + auto a = bb.add(make_load(g.A(), {mm}, g.loc())); + a->name("a"); + auto ab = bb.add(make_arith(arithmetic::mul, a, b, g.loc())); + bb.add(make_store(ab, g.C(), {mm, nn}, g.loc())); + }); + }); + }); + return make_parallel(bb.get_product(), g.loc()); + + /*auto alpha = visit(*this, *g.alpha()); + auto beta = visit(*this, *g.beta()); + auto alpha_ty = get_scalar_type(*g.alpha()->ty()); + auto beta_ty = get_scalar_type(*g.beta()->ty()); + + auto A = visit(*this, *g.A()); + auto B = visit(*this, *g.B()); + auto C = visit(*this, *g.C()); + + auto bb = clir::block_builder{}; + auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", + clir::get_sub_group_id() / tiling_.m_tiles()); + auto sg_m = bb.declare_assign(clir::generic_uint(), "sg_m", + clir::get_sub_group_id() % tiling_.m_tiles()); + tile_loop_uniformly( + bb, cdv.shape(1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), + [&](clir::block_builder &bb, clir::expr block, clir::expr trip_count) { + auto n = clir::var("n"); + bb.add(clir::for_loop_builder(clir::declaration_assignment(clir::generic_int(), n, + 0), n < std::move(trip_count), ++n) .body([&](clir::block_builder &bb) { auto b = + bb.declare_assign(to_clir_ty(bt->element_ty()), "b", B + (block + n) * bdv.stride(0)); auto + Cb = bb.declare_assign(this->operator()(*ct), "Cb", C + (block + n) * cdv.stride(1)); auto m + = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); + tile_loop_by_sgs( + bb, cdv.shape(0), core_cfg_.subgroup_size, tiling_.m_tiles(), + sg_m, + [&](clir::block_builder &bb, clir::expr block, bool is_remainder, + clir::expr inner_trip_count) { + auto const inner_loop = [&](clir::block_builder &bb) { + auto a = A[(block + m) * adv.stride(0)]; + auto c = bb.declare_assign((*this)(*ct), "c", + Cb + (block + m) * + cdv.stride(0)); auto ab = bb.declare_assign( to_clir_ty(ct->element_ty()), "ab", + multiply(at->element_ty(), bt->element_ty(), + std::move(a), b)); + const auto ab_scaled = multiply(alpha_ty, + ct->element_ty(), alpha, std::move(ab)); store_helper(bb, g.atomic(), c, ct->element_ty(), + ct->addrspace(), std::move(ab_scaled), + beta_ty, beta); + }; + if (is_remainder) { + bb.add(clir::if_selection_builder( + m < std::move(inner_trip_count)) + .then(inner_loop) + .get_product()); + } else { + inner_loop(bb); + } + }); + }) + .get_product()); + });*/ +} + +/* Region nodes */ +void lower_linalg_pass::operator()(rgn &b) { + for (auto &s : b.insts()) { + if (auto lowered_inst = visit(*this, *s); lowered_inst) { + s = lowered_inst; + } + } +} + +/* Function nodes */ +void lower_linalg_pass::operator()(prototype &) {} + +void lower_linalg_pass::operator()(function &fn) { + auto const subgroup_size = fn.subgroup_size(); + try { + core_cfg_ = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + auto const work_group_size = fn.work_group_size(); + tiling_[0] = work_group_size[0] / subgroup_size; + tiling_[1] = work_group_size[1]; + + visit(*this, *fn.prototype()); + visit(*this, *fn.body()); +} + +/* Program nodes */ +void lower_linalg_pass::operator()(program &p) { + for (auto &fn : p.functions()) { + visit(*this, *fn); + } +} + +} // namespace tinytc diff --git a/src/pass/lower_linalg.hpp b/src/pass/lower_linalg.hpp new file mode 100644 index 00000000..ca8157b8 --- /dev/null +++ b/src/pass/lower_linalg.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_LINALG_20240801_HPP +#define LOWER_LINALG_20240801_HPP + +#include "device_info.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/program_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "tiling.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class lower_linalg_pass { + public: + lower_linalg_pass(::tinytc_core_info const *info); + + /* Data type nodes */ + // bool operator()(void_data_type &); + // bool operator()(group_data_type &b); + // bool operator()(memref_data_type &m); + // bool operator()(scalar_data_type &s); + + //[> Value nodes <] + // value_node *operator()(int_imm &v); + // value_node *operator()(float_imm &v); + // value_node *operator()(val &v); + + /* Stmt nodes */ + inst operator()(inst_node &); + inst operator()(loop_inst &p); + inst operator()(ger_inst &g); + inst operator()(if_inst &in); + inst operator()(parallel_inst &p); + + /* Region nodes */ + void operator()(rgn &b); + + /* Func nodes */ + void operator()(prototype &p); + void operator()(function &fn); + + /* Program nodes */ + void operator()(program &p); + + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; + + ::tinytc_core_info const *info_; + local_tiling tiling_ = {}; + core_config core_cfg_ = {}; +}; + +} // namespace tinytc + +#endif // LOWER_LINALG_20240801_HPP diff --git a/src/visitor/metadata.cpp b/src/pass/metadata.cpp similarity index 79% rename from src/visitor/metadata.cpp rename to src/pass/metadata.cpp index 0b4d4ef3..d4d8c76a 100644 --- a/src/visitor/metadata.cpp +++ b/src/pass/metadata.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/metadata.hpp" +#include "pass/metadata.hpp" #include "node/function_node.hpp" #include "node/program_node.hpp" #include "support/visit.hpp" @@ -13,8 +13,6 @@ namespace tinytc { /* Function nodes */ -void metadata::operator()(prototype const &) {} - void metadata::operator()(function const &fn) { auto m = kernel_metadata{}; m.subgroup_size = fn.subgroup_size(); @@ -24,8 +22,8 @@ void metadata::operator()(function const &fn) { /* Program nodes */ void metadata::operator()(program const &p) { - for (auto &decl : p.declarations()) { - visit(*this, *decl); + for (auto &fn : p.functions()) { + visit(*this, *fn); } } diff --git a/src/visitor/metadata.hpp b/src/pass/metadata.hpp similarity index 94% rename from src/visitor/metadata.hpp rename to src/pass/metadata.hpp index da746daf..67bc6c93 100644 --- a/src/visitor/metadata.hpp +++ b/src/pass/metadata.hpp @@ -16,7 +16,6 @@ namespace tinytc { class metadata { public: /* Func nodes */ - void operator()(prototype const &p); void operator()(function const &fn); /* Program nodes */ diff --git a/src/visitor/opencl_ast.cpp b/src/pass/opencl_ast.cpp similarity index 99% rename from src/visitor/opencl_ast.cpp rename to src/pass/opencl_ast.cpp index aebf3568..cdfbeec5 100644 --- a/src/visitor/opencl_ast.cpp +++ b/src/pass/opencl_ast.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/opencl_ast.hpp" +#include "pass/opencl_ast.hpp" #include "codegen_tools.hpp" #include "error.hpp" #include "gemm_generator.hpp" @@ -1143,14 +1143,14 @@ clir::prog opencl_ast::operator()(program const &p) { auto operator()(prototype const &p) -> std::string_view { return p.name(); } }; reserved_names_.clear(); - for (auto const &decl : p.declarations()) { - reserved_names_.insert(std::string(visit(name_visitor{}, *decl))); + for (auto const &fn : p.functions()) { + reserved_names_.insert(std::string(visit(name_visitor{}, *fn))); } prog_builder_ = clir::program_builder{}; - for (auto const &decl : p.declarations()) { + for (auto const &fn : p.functions()) { stack_high_water_mark_ = 0; - prog_builder_.add(visit(*this, *decl)); + prog_builder_.add(visit(*this, *fn)); } return prog_builder_.get_product(); } diff --git a/src/visitor/opencl_ast.hpp b/src/pass/opencl_ast.hpp similarity index 100% rename from src/visitor/opencl_ast.hpp rename to src/pass/opencl_ast.hpp diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp new file mode 100644 index 00000000..e4a75854 --- /dev/null +++ b/src/pass/slot_tracker.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/slot_tracker.hpp" +#include "support/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc { + +void slot_tracker::set_slot(value_node const &v) { + if (!v.has_name()) { + slot_map_[&v] = slot_++; + } +} + +void slot_tracker::run_on_function(function &fn) { + slot_ = 0; + for (auto const &arg : fn.args()) { + set_slot(*arg); + } + walk(fn, [this](inst_node const &i) { + for (auto const &result : i.results()) { + set_slot(*result); + } + }); +} + +auto slot_tracker::get_slot(value_node const &v) -> std::int64_t { + auto it = slot_map_.find(&v); + return it != slot_map_.end() ? it->second : -1; +} + +} // namespace tinytc diff --git a/src/visitor/slot_tracker.hpp b/src/pass/slot_tracker.hpp similarity index 61% rename from src/visitor/slot_tracker.hpp rename to src/pass/slot_tracker.hpp index 3c93e683..9a52dfd4 100644 --- a/src/visitor/slot_tracker.hpp +++ b/src/pass/slot_tracker.hpp @@ -17,21 +17,7 @@ namespace tinytc { class slot_tracker { public: - /* Stmt nodes */ - void operator()(inst_node const &in); - void operator()(loop_inst const &p); - void operator()(if_inst const &in); - void operator()(parallel_inst const &p); - - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &); - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); + void run_on_function(function &fn); auto get_slot(value_node const &v) -> std::int64_t; diff --git a/src/visitor/stack.cpp b/src/pass/stack.cpp similarity index 95% rename from src/visitor/stack.cpp rename to src/pass/stack.cpp index 2b291b88..329c1ad7 100644 --- a/src/visitor/stack.cpp +++ b/src/pass/stack.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/stack.hpp" +#include "pass/stack.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "support/casting.hpp" @@ -73,9 +73,9 @@ void stack_ptr::operator()(function &fn) { /* Program nodes */ void stack_ptr::operator()(program &p) { - for (auto &decl : p.declarations()) { + for (auto &fn : p.functions()) { allocs_.clear(); - visit(*this, *decl); + visit(*this, *fn); } } diff --git a/src/visitor/stack.hpp b/src/pass/stack.hpp similarity index 100% rename from src/visitor/stack.hpp rename to src/pass/stack.hpp diff --git a/src/visitor/work_group_size.cpp b/src/pass/work_group_size.cpp similarity index 97% rename from src/visitor/work_group_size.cpp rename to src/pass/work_group_size.cpp index 3ad600cd..6fedb3c8 100644 --- a/src/visitor/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/work_group_size.hpp" +#include "pass/work_group_size.hpp" #include "device_info.hpp" #include "error.hpp" #include "node/data_type_node.hpp" @@ -121,8 +121,8 @@ void work_group_size::operator()(function &fn) { /* Program nodes */ void work_group_size::operator()(program &p) { - for (auto &decl : p.declarations()) { - visit(*this, *decl); + for (auto &fn : p.functions()) { + visit(*this, *fn); } } diff --git a/src/visitor/work_group_size.hpp b/src/pass/work_group_size.hpp similarity index 100% rename from src/visitor/work_group_size.hpp rename to src/pass/work_group_size.hpp diff --git a/src/passes.cpp b/src/passes.cpp index ef27bad0..d9c3cd53 100644 --- a/src/passes.cpp +++ b/src/passes.cpp @@ -5,30 +5,29 @@ #include "device_info.hpp" #include "kernel_metadata.hpp" #include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/program_node.hpp" -#include "visitor/check_ir.hpp" -#include "visitor/dump_ir.hpp" -#include "visitor/equal.hpp" -#include "visitor/insert_barrier.hpp" -#include "visitor/lifetime_analysis.hpp" -#include "visitor/metadata.hpp" -#include "visitor/opencl_ast.hpp" -#include "visitor/stack.hpp" -#include "visitor/work_group_size.hpp" - -#include - -using clir::visit; +// #include "node/function_node.hpp" +// #include "node/program_node.hpp" +#include "pass/check_ir.hpp" +// #include "pass/constant_propagation.hpp" +#include "pass/dump_ir.hpp" +#include "pass/equal.hpp" +// #include "pass/insert_barrier.hpp" +// #include "pass/lifetime_analysis.hpp" +// #include "pass/lower_linalg.hpp" +// #include "pass/metadata.hpp" +// #include "pass/opencl_ast.hpp" +// #include "pass/stack.hpp" +// #include "pass/work_group_size.hpp" +#include "support/visit.hpp" namespace tinytc { -void check_ir(tinytc_prog const &p) { return visit(ir_checker{}, p); } +// void check_ir(tinytc_prog const &p) { return visit(ir_checker{}, p); } -void dump_ir(std::ostream &os, tinytc_func const &f) { visit(ir_dumper{os}, f); } -void dump_ir(std::ostream &os, tinytc_prog const &p) { visit(ir_dumper{os}, p); } +// void dump_ir(std::ostream &os, tinytc_func const &f) { visit(ir_dumper{os}, f); } +void dump_ir(std::ostream &os, tinytc_prog const &p) { run_function_pass(dump_ir_pass{os}, p); } -clir::prog generate_opencl_ast(tinytc_prog const &p, ::tinytc_core_info const &info) { +/*clir::prog generate_opencl_ast(tinytc_prog const &p, ::tinytc_core_info const &info) { return visit(opencl_ast{&info}, p); } @@ -42,10 +41,16 @@ void insert_barriers(tinytc_func &f) { visit(insert_barrier{}, f); } void insert_barriers(tinytc_prog &p) { visit(insert_barrier{}, p); } void insert_lifetime_stop_inst(tinytc_func &f) { visit(lifetime_inserter{}, f); } -void insert_lifetime_stop_inst(tinytc_prog &p) { visit(lifetime_inserter{}, p); } +void insert_lifetime_stop_inst(tinytc_prog &p) { visit(lifetime_inserter{}, p); }*/ bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } +/*void lower_linalg(tinytc_prog &p, ::tinytc_core_info const &info) { + visit(lower_linalg_pass{&info}, p); +} + +void propagate_constants(tinytc_prog &p) { visit(constant_propagation{}, p); } + void set_stack_ptrs(tinytc_func &f) { visit(stack_ptr{}, f); } void set_stack_ptrs(tinytc_prog &p) { visit(stack_ptr{}, p); } @@ -54,7 +59,7 @@ void set_work_group_size(tinytc_func &f, ::tinytc_core_info const &info) { } void set_work_group_size(tinytc_prog &p, ::tinytc_core_info const &info) { visit(work_group_size{&info}, p); -} +}*/ } // namespace tinytc diff --git a/src/passes.def b/src/passes.def new file mode 100644 index 00000000..12a5aeac --- /dev/null +++ b/src/passes.def @@ -0,0 +1,5 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +FUNCTION_PASS("check-ir", check_ir_pass{}) +FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) diff --git a/src/passes.hpp b/src/passes.hpp index 19515e18..004c0324 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -5,6 +5,7 @@ #define PASSES_20240314_HPP #include "kernel_metadata.hpp" +#include "node/program_node.hpp" #include "tinytc/types.h" #include @@ -15,13 +16,13 @@ namespace tinytc { //! Check whether some IR rules are respected -void check_ir(tinytc_prog const &p); +// void check_ir(tinytc_prog const &p); //! Dump IR to ostream void dump_ir(std::ostream &os, tinytc_func const &f); //! Dump IR to ostream void dump_ir(std::ostream &os, tinytc_prog const &p); //! Generate OpenCL AST -clir::prog generate_opencl_ast(tinytc_prog const &p, tinytc_core_info const &info); +/*clir::prog generate_opencl_ast(tinytc_prog const &p, tinytc_core_info const &info); //! Get kernel metadata auto get_metadata(tinytc_prog const &p) -> std::unordered_map; //! Insert barriers where necessary @@ -31,9 +32,13 @@ void insert_barriers(tinytc_prog &p); //! Insert lifetime stop instructions for set_stack_ptrs pass void insert_lifetime_stop_inst(tinytc_func &f); //! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_prog &p); +void insert_lifetime_stop_inst(tinytc_prog &p);*/ //! Check whether data types a and b are equal bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); +//! Implement linear algebra instructions +/*void lower_linalg(tinytc_prog &p, tinytc_core_info const &info); +//! Constant propagation +void propagate_constants(tinytc_prog &p); //! Manage temporary memory requested by alloca void set_stack_ptrs(tinytc_func &f); //! Manage temporary memory requested by alloca @@ -41,7 +46,13 @@ void set_stack_ptrs(tinytc_prog &p); //! Choose work group and subgroup size heuristically if not given explicitly void set_work_group_size(tinytc_func &f, tinytc_core_info const &info); //! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_prog &p, tinytc_core_info const &info); +void set_work_group_size(tinytc_prog &p, tinytc_core_info const &info);*/ + +template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { + for (auto const &func : p.functions()) { + pass.run_on_function(*func); + } +} } // namespace tinytc diff --git a/src/support/walk.cpp b/src/support/walk.cpp new file mode 100644 index 00000000..b6a13755 --- /dev/null +++ b/src/support/walk.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "support/walk.hpp" + +namespace tinytc { + +walk_stage::walk_stage(inst_node &i) : num_regions_(i.num_child_regions()) {} + +void walk(inst_node &i, std::function callback) { + auto stage = walk_stage(i); + + for (auto ® : i.child_regions()) { + callback(i, stage); + stage.advance(); + + if (reg) { + for (auto &j : *reg) { + walk(*j, callback); + } + } + } + callback(i, stage); +} + +} // namespace tinytc diff --git a/src/support/walk.hpp b/src/support/walk.hpp new file mode 100644 index 00000000..c64f59f7 --- /dev/null +++ b/src/support/walk.hpp @@ -0,0 +1,65 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef WALK_20240911_HPP +#define WALK_20240911_HPP + +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" + +#include +#include + +namespace tinytc { + +enum class walk_order { pre_order, post_order }; + +class walk_stage { + public: + walk_stage(inst_node &i); + + inline bool is_before_all_regions() const { return next_region_ == 0; } + inline bool is_after_all_regions() const { return next_region_ == num_regions_; } + + void advance() { ++next_region_; } + + private: + int const num_regions_; + int next_region_ = 0; +}; + +template void walk(inst_node &i, std::function callback) { + if constexpr (Order == walk_order::pre_order) { + callback(i); + } + for (auto ® : i.child_regions()) { + if (reg) { + for (auto &j : *reg) { + walk(*j, callback); + } + } + } + if constexpr (Order == walk_order::post_order) { + callback(i); + } +} + +void walk(inst_node &i, std::function callback); + +template void walk(function &fn, std::function callback) { + for (auto &i : *fn.body()) { + walk(*i, callback); + } +} + +inline void walk(function &fn, + std::function callback) { + for (auto &i : *fn.body()) { + walk(*i, callback); + } +} + +} // namespace tinytc + +#endif // WALK_20240911_HPP diff --git a/src/visitor/check_ir.cpp b/src/visitor/check_ir.cpp deleted file mode 100644 index e3eea4aa..00000000 --- a/src/visitor/check_ir.cpp +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "visitor/check_ir.hpp" -#include "error.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include - -namespace tinytc { - -/* Stmt nodes */ -void ir_checker::operator()(inst_node const &in) { - if (in.kind() == inst_execution_kind::collective && inside_spmd_region_) { - throw compilation_error(in.loc(), status::ir_collective_called_from_spmd); - } else if (in.kind() == inst_execution_kind::spmd && !inside_spmd_region_) { - throw compilation_error(in.loc(), status::ir_spmd_called_from_collective); - } -} -void ir_checker::operator()(for_inst const &p) { return visit(*this, *p.body()); } -void ir_checker::operator()(foreach_inst const &p) { - this->operator()(static_cast(p)); - inside_spmd_region_ = true; - visit(*this, *p.body()); - inside_spmd_region_ = false; -} -void ir_checker::operator()(if_inst const &in) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } -} -void ir_checker::operator()(parallel_inst const &p) { - this->operator()(static_cast(p)); - inside_spmd_region_ = true; - visit(*this, *p.body()); - inside_spmd_region_ = false; -} - -/* Region nodes */ -void ir_checker::operator()(rgn const &b) { - for (auto const &s : b.insts()) { - visit(*this, *s); - } -} - -/* Function nodes */ -void ir_checker::operator()(prototype const &) {} -void ir_checker::operator()(function const &fn) { - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} - -/* Program nodes */ -void ir_checker::operator()(program const &p) { - for (auto const &s : p.declarations()) { - visit(*this, *s); - } -} - -} // namespace tinytc diff --git a/src/visitor/check_ir.hpp b/src/visitor/check_ir.hpp deleted file mode 100644 index 0b09eaa3..00000000 --- a/src/visitor/check_ir.hpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef CHECK_IR_20240222_HPP -#define CHECK_IR_20240222_HPP - -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" - -namespace tinytc { - -class ir_checker { - public: - /* Stmt nodes */ - void operator()(inst_node const &in); - void operator()(for_inst const &p); - void operator()(foreach_inst const &p); - void operator()(if_inst const &in); - void operator()(parallel_inst const &p); - - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &); - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); - - private: - bool inside_spmd_region_ = false; -}; - -} // namespace tinytc - -#endif // CHECK_IR_20240222_HPP diff --git a/src/visitor/dump_ir.cpp b/src/visitor/dump_ir.cpp deleted file mode 100644 index d02d1ec8..00000000 --- a/src/visitor/dump_ir.cpp +++ /dev/null @@ -1,403 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "visitor/dump_ir.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include -#include -#include - -namespace tinytc { - -ir_dumper::ir_dumper(std::ostream &os) : os_(os) {} - -/* Data type nodes */ -void ir_dumper::operator()(void_data_type const &) { os_ << "void"; } -void ir_dumper::operator()(group_data_type const &g) { - os_ << "group<"; - visit(*this, *g.ty()); - os_ << ">"; -} -void ir_dumper::operator()(memref_data_type const &d) { - auto const val = [&](std::int64_t v) -> std::ostream & { - if (is_dynamic_value(v)) { - return os_ << "?"; - } - return os_ << v; - }; - os_ << "memref<" << to_string(d.element_ty()); - for (auto const &s : d.shape()) { - os_ << "x"; - val(s); - } - if (!d.is_canonical_stride()) { - os_ << ",strided<"; - do_with_infix(d.stride().begin(), d.stride().end(), [&](auto const &a) { val(a); }); - os_ << ">"; - } - os_ << ">"; -} -void ir_dumper::operator()(scalar_data_type const &s) { os_ << to_string(s.ty()); } - -/* Value nodes */ -void ir_dumper::operator()(float_imm const &v) { - auto flags = os_.flags(); - os_ << std::hexfloat << v.value(); - os_.flags(flags); -} -void ir_dumper::operator()(int_imm const &v) { - if (is_dynamic_value(v.value())) { - os_ << "?"; - } else { - os_ << v.value(); - } -} -void ir_dumper::operator()(val const &v) { - os_ << "%" << v.name(); - auto const slot = tracker_.get_slot(v); - if (slot >= 0) { - os_ << slot; - } -} - -/* Inst nodes */ -void ir_dumper::dump_blas_a2(blas_a2_inst const &g) { - visit(*this, *g.alpha()); - os_ << ", "; - visit(*this, *g.A()); - os_ << ", "; - visit(*this, *g.beta()); - os_ << ", "; - visit(*this, *g.B()); - os_ << " : "; - visit(*this, *g.alpha()->ty()); - os_ << ", "; - visit(*this, *g.A()->ty()); - os_ << ", "; - visit(*this, *g.beta()->ty()); - os_ << ", "; - visit(*this, *g.B()->ty()); -} - -void ir_dumper::dump_blas_a3(blas_a3_inst const &g) { - visit(*this, *g.alpha()); - os_ << ", "; - visit(*this, *g.A()); - os_ << ", "; - visit(*this, *g.B()); - os_ << ", "; - visit(*this, *g.beta()); - os_ << ", "; - visit(*this, *g.C()); - os_ << " : "; - visit(*this, *g.alpha()->ty()); - os_ << ", "; - visit(*this, *g.A()->ty()); - os_ << ", "; - visit(*this, *g.B()->ty()); - os_ << ", "; - visit(*this, *g.beta()->ty()); - os_ << ", "; - visit(*this, *g.C()->ty()); -} - -void ir_dumper::operator()(alloca_inst const &a) { - visit(*this, *a.result()); - os_ << " = alloca -> "; - visit(*this, *a.result()->ty()); -} - -void ir_dumper::operator()(axpby_inst const &a) { - os_ << "axpby"; - os_ << "." << to_string(a.tA()) << " "; - dump_blas_a2(static_cast(a)); -} - -void ir_dumper::operator()(arith_inst const &a) { - visit(*this, *a.result()); - os_ << " = arith." << to_string(a.operation()) << " "; - visit(*this, *a.a()); - os_ << ", "; - visit(*this, *a.b()); - os_ << " : "; - visit(*this, *a.a()->ty()); -} - -void ir_dumper::operator()(arith_unary_inst const &a) { - visit(*this, *a.result()); - os_ << " = arith." << to_string(a.operation()) << " "; - visit(*this, *a.a()); - os_ << " : "; - visit(*this, *a.a()->ty()); -} - -void ir_dumper::operator()(barrier_inst const &) { os_ << "barrier"; } - -void ir_dumper::operator()(cast_inst const &c) { - visit(*this, *c.result()); - os_ << " = cast "; - visit(*this, *c.a()); - os_ << " : "; - visit(*this, *c.a()->ty()); - os_ << " -> "; - visit(*this, *c.result()->ty()); -} - -void ir_dumper::operator()(compare_inst const &a) { - visit(*this, *a.result()); - os_ << " = cmp." << to_string(a.cond()) << " "; - visit(*this, *a.a()); - os_ << ", "; - visit(*this, *a.b()); - os_ << " : "; - visit(*this, *a.a()->ty()); -} - -void ir_dumper::operator()(expand_inst const &e) { - visit(*this, *e.result()); - os_ << " = expand "; - visit(*this, *e.operand()); - os_ << "[" << e.mode() << "->"; - do_with_infix( - e.expand_shape().begin(), e.expand_shape().end(), - [this](auto const &i) { visit(*this, *i); }, "x"); - os_ << "] : "; - visit(*this, *e.operand()->ty()); -} - -void ir_dumper::operator()(fuse_inst const &f) { - visit(*this, *f.result()); - os_ << " = fuse "; - visit(*this, *f.operand()); - os_ << "[" << f.from() << "," << f.to() << "]"; - os_ << " : "; - visit(*this, *f.operand()->ty()); -} - -void ir_dumper::operator()(load_inst const &e) { - visit(*this, *e.result()); - os_ << " = load "; - visit(*this, *e.operand()); - os_ << "["; - do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { visit(*this, *i); }); - os_ << "] : "; - visit(*this, *e.operand()->ty()); -} - -void ir_dumper::operator()(group_id_inst const &g) { - visit(*this, *g.result()); - os_ << " = group_id"; -} - -void ir_dumper::operator()(group_size_inst const &g) { - visit(*this, *g.result()); - os_ << " = group_size"; -} - -void ir_dumper::operator()(lifetime_stop_inst const &l) { - os_ << "lifetime_stop "; - visit(*this, *l.object()); -} - -void ir_dumper::operator()(gemm_inst const &g) { - os_ << "gemm"; - os_ << "." << to_string(g.tA()); - os_ << "." << to_string(g.tB()) << " "; - dump_blas_a3(static_cast(g)); -} - -void ir_dumper::operator()(gemv_inst const &g) { - os_ << "gemv"; - os_ << "." << to_string(g.tA()) << " "; - dump_blas_a3(static_cast(g)); -} - -void ir_dumper::operator()(ger_inst const &g) { - os_ << "ger "; - dump_blas_a3(static_cast(g)); -} - -void ir_dumper::operator()(for_inst const &p) { - os_ << "for "; - visit(*this, *p.loop_var()); - os_ << "="; - visit(*this, *p.from()); - os_ << ","; - visit(*this, *p.to()); - if (p.step()) { - os_ << ","; - visit(*this, *p.step()); - } - os_ << " : "; - visit(*this, *p.loop_var()->ty()); - os_ << " "; - visit(*this, *p.body()); -} - -void ir_dumper::operator()(foreach_inst const &p) { - os_ << "foreach "; - visit(*this, *p.loop_var()); - os_ << "="; - visit(*this, *p.from()); - os_ << ","; - visit(*this, *p.to()); - os_ << " : "; - visit(*this, *p.loop_var()->ty()); - os_ << " "; - visit(*this, *p.body()); -} - -void ir_dumper::operator()(hadamard_inst const &g) { - os_ << "hadamard "; - dump_blas_a3(static_cast(g)); -} - -void ir_dumper::operator()(if_inst const &in) { - os_ << "if "; - visit(*this, *in.condition()); - os_ << " "; - visit(*this, *in.then()); - if (in.otherwise()) { - os_ << " else "; - visit(*this, *in.otherwise()); - } -} - -void ir_dumper::operator()(num_subgroups_inst const &sg) { - visit(*this, *sg.result()); - os_ << " = num_subgroups"; -} - -void ir_dumper::operator()(parallel_inst const &p) { - os_ << "parallel "; - visit(*this, *p.body()); -} - -void ir_dumper::operator()(size_inst const &s) { - visit(*this, *s.result()); - os_ << " = size "; - visit(*this, *s.operand()); - os_ << "[" << s.mode() << "]"; - os_ << " : "; - visit(*this, *s.operand()->ty()); -} - -void ir_dumper::operator()(subgroup_id_inst const &sg) { - visit(*this, *sg.result()); - os_ << " = subgroup_id"; -} - -void ir_dumper::operator()(subgroup_local_id_inst const &sg) { - visit(*this, *sg.result()); - os_ << " = subgroup_local_id"; -} - -void ir_dumper::operator()(subgroup_size_inst const &sg) { - visit(*this, *sg.result()); - os_ << " = subgroup_size"; -} - -void ir_dumper::operator()(subview_inst const &s) { - visit(*this, *s.result()); - os_ << " = subview "; - visit(*this, *s.operand()); - os_ << "["; - auto irange = std::ranges::iota_view{std::size_t{0}, s.offset_list().size()}; - do_with_infix(irange.begin(), irange.end(), [&](auto const &i) { - visit(*this, *s.offset_list()[i]); - auto &size = s.size_list()[i]; - if (size) { - os_ << ":"; - visit(*this, *size); - } - }); - os_ << "]"; - os_ << " : "; - visit(*this, *s.operand()->ty()); - os_ << " ; -> "; - visit(*this, *s.result()->ty()); -} - -void ir_dumper::operator()(store_inst const &e) { - os_ << "store "; - visit(*this, *e.val()); - os_ << ", "; - visit(*this, *e.operand()); - os_ << "["; - do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { visit(*this, *i); }); - os_ << "] : "; - visit(*this, *e.operand()->ty()); -} - -void ir_dumper::operator()(sum_inst const &a) { - os_ << "sum"; - os_ << "." << to_string(a.tA()) << " "; - dump_blas_a2(static_cast(a)); -} - -void ir_dumper::operator()(yield_inst const &y) { - os_ << "yield "; - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }); - os_ << " : "; - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }); -} - -/* Region nodes */ -void ir_dumper::operator()(rgn const &b) { - os_ << "{" << std::endl; - ++lvl_; - auto ind = indent(); - for (auto const &s : b.insts()) { - os_ << ind; - visit(*this, *s); - os_ << std::endl; - } - --lvl_; - os_ << indent() << "}"; -} - -/* Function nodes */ -void ir_dumper::operator()(prototype const &p) { - os_ << "func @" << p.name() << "("; - std::string infix = ",\n "; - infix += std::string(p.name().size(), ' '); - do_with_infix( - p.args().begin(), p.args().end(), - [this](auto const &a) { - visit(*this, *a); - os_ << ": "; - visit(*this, *a->ty()); - }, - infix); - os_ << ")"; -} - -void ir_dumper::operator()(function const &fn) { - visit(*this, *fn.prototype()); - os_ << " "; - auto const sgs = fn.subgroup_size(); - auto const wgs = fn.work_group_size(); - if (sgs != 0) { - os_ << "subgroup_size(" << sgs << ") "; - } - if (wgs[0] != 0 && wgs[1] != 0) { - os_ << "work_group_size(" << wgs[0] << "," << wgs[1] << ") "; - } - visit(*this, *fn.body()); - os_ << std::endl; -} - -/* Program nodes */ -void ir_dumper::operator()(program const &p) { - visit(tracker_, p); - for (auto const &decl : p.declarations()) { - visit(*this, *decl); - } -} - -} // namespace tinytc diff --git a/src/visitor/slot_tracker.cpp b/src/visitor/slot_tracker.cpp deleted file mode 100644 index 439b60a5..00000000 --- a/src/visitor/slot_tracker.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "visitor/slot_tracker.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include - -namespace tinytc { - -void slot_tracker::set_slot(value_node const &v) { - if (!v.has_name()) { - slot_map_[&v] = slot_++; - } -} - -/* Stmt nodes */ -void slot_tracker::operator()(inst_node const &in) { - for (auto const &result : in.results()) { - set_slot(*result); - } -} -void slot_tracker::operator()(loop_inst const &p) { - set_slot(*p.loop_var()); - return visit(*this, *p.body()); -} - -void slot_tracker::operator()(if_inst const &in) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } -} - -void slot_tracker::operator()(parallel_inst const &p) { return visit(*this, *p.body()); } - -/* Region nodes */ -void slot_tracker::operator()(rgn const &b) { - for (auto const &s : b.insts()) { - visit(*this, *s); - } -} - -/* Function nodes */ -void slot_tracker::operator()(prototype const &p) { - for (auto const &arg : p.args()) { - set_slot(*arg); - } -} - -void slot_tracker::operator()(function const &fn) { - slot_ = 0; - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} - -/* Program nodes */ -void slot_tracker::operator()(program const &p) { - for (auto const &s : p.declarations()) { - visit(*this, *s); - } -} - -auto slot_tracker::get_slot(value_node const &v) -> std::int64_t { - auto it = slot_map_.find(&v); - return it != slot_map_.end() ? it->second : -1; -} - -} // namespace tinytc diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e4040322..ca72a24f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,7 +38,7 @@ set(LIT_COMMAND lit "${CMAKE_CURRENT_BINARY_DIR}" -v) add_test(lit-check ${LIT_COMMAND}) set_tests_properties(lit-check PROPERTIES LABELS "lit") add_custom_target(lit-check COMMAND ${LIT_COMMAND}) -add_dependencies(lit-check tinytc-oc) +add_dependencies(lit-check tinytc-oc tinytc-opt) if(BUILD_OPENCL) diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 6c264818..9a575ea0 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -12,3 +12,4 @@ config.test_exec_root = config.my_exec_root config.substitutions.append(('%tinytc-oc', config.tinytc_oc_path)) +config.substitutions.append(('%tinytc-opt', config.tinytc_opt_path)) diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index e6d354c8..4e8d641a 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -6,5 +6,6 @@ import os config.my_src_root = r'@CMAKE_CURRENT_SOURCE_DIR@' config.my_exec_root = r'@CMAKE_CURRENT_BINARY_DIR@' config.tinytc_oc_path = r'$' +config.tinytc_opt_path = r'$' lit_config.load_config(config, os.path.join(config.my_src_root, 'lit.cfg.py')) diff --git a/test/codegen/nesting0.ir b/test/opt/check-ir/nesting0.ir similarity index 86% rename from test/codegen/nesting0.ir rename to test/opt/check-ir/nesting0.ir index 9a03a443..3683b9a3 100644 --- a/test/codegen/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting(%A: memref, %B: memref, %C: memref) { foreach %i=1,16 { gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref diff --git a/test/codegen/nesting1.ir b/test/opt/check-ir/nesting1.ir similarity index 80% rename from test/codegen/nesting1.ir rename to test/opt/check-ir/nesting1.ir index 9bc20f52..b6ec7dd9 100644 --- a/test/codegen/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { foreach %i=1,16 { foreach %j=1,16 { diff --git a/test/codegen/nesting2.ir b/test/opt/check-ir/nesting2.ir similarity index 77% rename from test/codegen/nesting2.ir rename to test/opt/check-ir/nesting2.ir index e336529d..35f302e5 100644 --- a/test/codegen/nesting2.ir +++ b/test/opt/check-ir/nesting2.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { %0 = subgroup_id ; CHECK: 6.10-20: SPMD instruction must not be called from collective region diff --git a/test/codegen/nesting3.ir b/test/opt/check-ir/nesting3.ir similarity index 80% rename from test/codegen/nesting3.ir rename to test/opt/check-ir/nesting3.ir index 9e3151b2..c36375a1 100644 --- a/test/codegen/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { parallel { foreach %j=1,16 { diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 1cf91cce..cb910352 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -2,3 +2,4 @@ # SPDX-License-Identifier: BSD-3-Clause add_subdirectory(offline_compiler) +add_subdirectory(opt) diff --git a/tools/opt/CMakeLists.txt b/tools/opt/CMakeLists.txt new file mode 100644 index 00000000..6368f8bd --- /dev/null +++ b/tools/opt/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +include(CommonOptions) +include(GNUInstallDirs) + +add_executable(tinytc-opt main.cpp args.cpp) +target_link_libraries(tinytc-opt PRIVATE tinytc) +set_cxx_common_options(tinytc-opt) + +set_target_properties(tinytc-opt PROPERTIES INSTALL_RPATH_USE_LINK_PATH True) +set_target_properties(tinytc-opt PROPERTIES INSTALL_RPATH "\$ORIGIN/../${CMAKE_INSTALL_LIBDIR}") + +install(TARGETS tinytc-opt + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) diff --git a/tools/opt/args.cpp b/tools/opt/args.cpp new file mode 100644 index 00000000..8bcd6e8f --- /dev/null +++ b/tools/opt/args.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "args.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include + +using tinytc::core_info; +using tinytc::intel_gpu_architecture; +using tinytc::list_function_passes; +using tinytc::make_core_info_intel_from_arch; + +auto make_core_info_from_string(char const *name) -> core_info { + if (std::strcmp(name, "pvc") == 0) { + return make_core_info_intel_from_arch(intel_gpu_architecture::pvc); + } else if (std::strcmp(name, "tgl") == 0) { + return make_core_info_intel_from_arch(intel_gpu_architecture::tgl); + } + return core_info{}; +} + +args arg_parser::parse_args(int argc, char **argv) { + args a = {}; + a.filename = nullptr; + + std::uint32_t names_size = 0; + char const *const *names = nullptr; + list_function_passes(names_size, names); + + auto const has_function_pass = [&names_size, names](char const *pass_name) -> bool { + for (std::uint32_t i = 0; i < names_size; ++i) { + if (std::strcmp(pass_name, names[i]) == 0) { + return true; + } + } + return false; + }; + + int npos = 0; + for (int i = 1; i < argc; ++i) { + if (argv[i][0] == '-') { + auto const fail = [&]() { + throw std::runtime_error( + (std::ostringstream{} << "==> Unrecognized argument: " << argv[i]).str()); + }; + if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { + a.help = true; + } else if (argv[i][1] == '-' && has_function_pass(argv[i] + 2)) { + a.pass_names.emplace_back(std::string(argv[i] + 2)); + } else if (i + 1 < argc) { + if (std::strcmp(argv[i], "-d") == 0 || std::strcmp(argv[i], "--device") == 0) { + a.info = make_core_info_from_string(argv[++i]); + if (!a.info) { + throw std::runtime_error( + (std::ostringstream{} << "==> Unknown device: " << argv[i]).str()); + } + } else { + fail(); + } + } else { + fail(); + } + } else { + if (npos == 0) { + a.filename = argv[i]; + ++npos; + } else { + throw std::runtime_error("==> At most a single positional argument is expected"); + } + } + } + if (!a.info) { + a.info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); + } + + if (a.pass_names.empty() || std::strcmp(a.pass_names.back().c_str(), "dump-ir") != 0) { + a.pass_names.emplace_back(std::string("dump-ir")); + } + + return a; +} + +void arg_parser::show_help(std::ostream &os) { + os << "usage: tinytc-opt [-d ] [file-name]" << std::endl + << R"HELP( +positional arguments: + file-name Path to source code; leave empty to read from stdin + +optional arguments: + -d, --device Device name (cf. intel_gpu_architecture enum), default is "pvc" + -h, --help Show help text and exit + +passes: +)HELP"; + std::uint32_t names_size = 0; + char const *const *names = nullptr; + list_function_passes(names_size, names); + for (std::uint32_t i = 0; i < names_size; ++i) { + os << " --" << names[i] << std::endl; + } +} diff --git a/tools/opt/args.hpp b/tools/opt/args.hpp new file mode 100644 index 00000000..8ddd7754 --- /dev/null +++ b/tools/opt/args.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ARGS_20240911_HPP +#define ARGS_20240911_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include +#include + +struct args { + std::vector pass_names; + char const *filename; + tinytc::core_info info; + bool help; +}; + +class arg_parser { + public: + static args parse_args(int argc, char **argv); + static void show_help(std::ostream &os); +}; + +#endif // ARGS_20240911_HPP diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp new file mode 100644 index 00000000..662714b9 --- /dev/null +++ b/tools/opt/main.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "args.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include + +using namespace tinytc; + +int main(int argc, char **argv) { + auto a = args{}; + try { + a = arg_parser::parse_args(argc, argv); + } catch (status const &st) { + std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; + return -1; + } catch (std::runtime_error const &e) { + std::cerr << e.what() << std::endl; + return -1; + } + if (a.help) { + arg_parser::show_help(std::cout); + return 0; + } + + auto ctx = source_context{}; + try { + ctx = make_source_context(); + auto p = prog{}; + if (!a.filename) { + p = parse_stdin(ctx); + } else { + p = parse_file(a.filename, ctx); + } + + for (auto const &pass_name : a.pass_names) { + run_function_pass(pass_name.c_str(), p, a.info, ctx); + } + } catch (status const &st) { + std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; + std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; + return 1; + } catch (std::exception const &e) { + std::cerr << e.what() << std::endl; + return 1; + } + + return 0; +} From 06e415e1e53a7911780f21fd29526ffc631391ad Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 12 Sep 2024 18:45:39 +0200 Subject: [PATCH 015/297] Further refactoring of compiler passes Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 28 ++--- src/{pass => analysis}/aa_results.cpp | 10 +- src/{pass => analysis}/aa_results.hpp | 12 +- src/analysis/alias.cpp | 78 ++++++++++++ src/analysis/alias.hpp | 19 +++ src/compiler.cpp | 16 ++- src/node/inst_node.cpp | 4 +- src/node/inst_node.hpp | 4 +- src/node/region_node.hpp | 34 ++--- src/pass/alias_analysis.cpp | 76 ------------ src/pass/alias_analysis.hpp | 45 ------- src/pass/insert_lifetime_stop.cpp | 70 +++++++++++ src/pass/insert_lifetime_stop.hpp | 28 +++++ src/pass/lifetime_analysis.cpp | 171 -------------------------- src/pass/lifetime_analysis.hpp | 76 ------------ src/pass/stack.cpp | 115 ++++++++--------- src/pass/stack.hpp | 33 +---- src/pass/work_group_size.cpp | 85 +++++-------- src/pass/work_group_size.hpp | 29 +---- src/passes.def | 3 + src/support/walk.cpp | 6 +- src/support/walk.hpp | 16 +++ test/opt/insert-lifetime-stop.ir | 61 +++++++++ test/opt/work-group-size.ir | 26 ++++ 24 files changed, 450 insertions(+), 595 deletions(-) rename src/{pass => analysis}/aa_results.cpp (81%) rename src/{pass => analysis}/aa_results.hpp (85%) create mode 100644 src/analysis/alias.cpp create mode 100644 src/analysis/alias.hpp delete mode 100644 src/pass/alias_analysis.cpp delete mode 100644 src/pass/alias_analysis.hpp create mode 100644 src/pass/insert_lifetime_stop.cpp create mode 100644 src/pass/insert_lifetime_stop.hpp delete mode 100644 src/pass/lifetime_analysis.cpp delete mode 100644 src/pass/lifetime_analysis.hpp create mode 100644 test/opt/insert-lifetime-stop.ir create mode 100644 test/opt/work-group-size.ir diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7e2ba6e9..ddcdc24d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,6 +17,8 @@ find_package(re2c REQUIRED) find_package(BISON 3.8.2 REQUIRED) set(SOURCES + analysis/aa_results.cpp + analysis/alias.cpp binary.cpp codegen_tools.cpp compiler.cpp @@ -31,6 +33,18 @@ set(SOURCES node/inst_node.cpp parser/parse_context.cpp parser.cpp + pass/check_ir.cpp + #pass/constant_propagation.cpp + pass/dump_ir.cpp + pass/equal.cpp + #pass/insert_barrier.cpp + pass/insert_lifetime_stop.cpp + #pass/lower_linalg.cpp + #pass/metadata.cpp + #pass/opencl_ast.cpp + pass/slot_tracker.cpp + pass/stack.cpp + pass/work_group_size.cpp passes.cpp prog.cpp recipe.cpp @@ -42,20 +56,6 @@ set(SOURCES source.cpp tiling.cpp value.cpp - #pass/aa_results.cpp - #pass/alias_analysis.cpp - pass/check_ir.cpp - #pass/constant_propagation.cpp - pass/dump_ir.cpp - pass/equal.cpp - #pass/insert_barrier.cpp - #pass/lifetime_analysis.cpp - #pass/lower_linalg.cpp - #pass/metadata.cpp - #pass/opencl_ast.cpp - pass/slot_tracker.cpp - #pass/stack.cpp - #pass/work_group_size.cpp support/walk.cpp ) set(RE2C_SOURCES diff --git a/src/pass/aa_results.cpp b/src/analysis/aa_results.cpp similarity index 81% rename from src/pass/aa_results.cpp rename to src/analysis/aa_results.cpp index fbf8aa44..391be83e 100644 --- a/src/pass/aa_results.cpp +++ b/src/analysis/aa_results.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "pass/aa_results.hpp" +#include "analysis/aa_results.hpp" #include "node/value_node.hpp" #include @@ -12,14 +12,14 @@ aa_results::aa_results(std::unordered_map allocs) : alias_(std::move(alias)), allocs_(std::move(allocs)) {} -auto aa_results::root(value_node const &a) -> value_node const * { +auto aa_results::root(value_node const &a) const -> value_node const * { auto root = &a; - if (alias_.find(root) != alias_.end()) { - root = alias_[root]; + if (auto it = alias_.find(root); it != alias_.end()) { + root = it->second; } return root; } -bool aa_results::alias(value_node const &a, value_node const &b) { +bool aa_results::alias(value_node const &a, value_node const &b) const { auto ra = root(a); auto rb = root(b); if (ra == rb) { diff --git a/src/pass/aa_results.hpp b/src/analysis/aa_results.hpp similarity index 85% rename from src/pass/aa_results.hpp rename to src/analysis/aa_results.hpp index e65f2613..3c553e64 100644 --- a/src/pass/aa_results.hpp +++ b/src/analysis/aa_results.hpp @@ -13,21 +13,19 @@ namespace tinytc { class aa_results { public: - aa_results() = default; - auto root(::tinytc_value const &a) -> ::tinytc_value const *; - bool alias(::tinytc_value const &a, ::tinytc_value const &b); - - private: struct allocation { std::int64_t start, stop; }; aa_results(std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias, std::unordered_map<::tinytc_value const *, allocation> allocs); + + auto root(::tinytc_value const &a) const -> ::tinytc_value const *; + bool alias(::tinytc_value const &a, ::tinytc_value const &b) const; + + private: std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias_; std::unordered_map<::tinytc_value const *, allocation> allocs_; - - friend class alias_analyser; }; } // namespace tinytc diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp new file mode 100644 index 00000000..e51add20 --- /dev/null +++ b/src/analysis/alias.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/alias.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +class alias_analysis_visitor { + public: + void operator()(inst_node const &); + void operator()(alloca_inst const &a); + void operator()(expand_inst const &e); + void operator()(fuse_inst const &f); + void operator()(subview_inst const &s); + + auto get_result() && -> aa_results { return aa_results(std::move(alias_), std::move(allocs_)); } + + private: + std::unordered_map allocs_; + std::unordered_map alias_; +}; + +void alias_analysis_visitor::operator()(inst_node const &) {} +void alias_analysis_visitor::operator()(alloca_inst const &a) { + if (a.stack_ptr() >= 0) { + auto t = dyn_cast(a.result()->ty().get()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + allocs_[a.result().get()] = + aa_results::allocation{a.stack_ptr(), a.stack_ptr() + t->size_in_bytes()}; + } +} +void alias_analysis_visitor::operator()(expand_inst const &e) { + value_node const *source = e.operand().get(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[e.result().get()] = source; +} +void alias_analysis_visitor::operator()(fuse_inst const &f) { + value_node const *source = f.operand().get(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[f.result().get()] = source; +} + +void alias_analysis_visitor::operator()(subview_inst const &s) { + value_node const *source = s.operand().get(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[s.result().get()] = source; +} + +auto alias_analysis::run_on_function(function &fn) -> aa_results { + auto visitor = alias_analysis_visitor{}; + + walk(fn, [&visitor](inst_node &i) { visit(visitor, i); }); + + return std::move(visitor).get_result(); +} + +} // namespace tinytc diff --git a/src/analysis/alias.hpp b/src/analysis/alias.hpp new file mode 100644 index 00000000..02314ed0 --- /dev/null +++ b/src/analysis/alias.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ALIAS_20240912_HPP +#define ALIAS_20240912_HPP + +#include "analysis/aa_results.hpp" +#include "node/function_node.hpp" + +namespace tinytc { + +class alias_analysis { + public: + auto run_on_function(function &fn) -> aa_results; +}; + +} // namespace tinytc + +#endif // ALIAS_20240912_HPP diff --git a/src/compiler.cpp b/src/compiler.cpp index 74bcec12..b14e253f 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -7,6 +7,9 @@ #include "parser.hpp" #include "pass/check_ir.hpp" #include "pass/dump_ir.hpp" +#include "pass/insert_lifetime_stop.hpp" +#include "pass/stack.hpp" +#include "pass/work_group_size.hpp" #include "passes.hpp" #include "reference_counted.hpp" #include "required_extensions.hpp" @@ -41,8 +44,13 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr if (strcmp(NAME, pass_name) == 0) { \ return run_function_pass(CREATE_PASS, *prg); \ } +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) \ + if (strcmp(NAME, pass_name) == 0) { \ + return run_function_pass(CREATE_PASS(info), *prg); \ + } #include "passes.def" #undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO }, ctx); } @@ -52,10 +60,12 @@ tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *co return tinytc_status_invalid_arguments; } #define FUNCTION_PASS(NAME, CREATE_PASS) NAME, +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) NAME, static char const *const pass_names[] = { #include "passes.def" }; #undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO *names_size = sizeof(pass_names) / sizeof(char const *); *names = pass_names; @@ -72,10 +82,10 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ [&] { // passes run_function_pass(check_ir_pass{}, *prg); - // insert_lifetime_stop_inst(*prg); - // set_stack_ptrs(*prg); + run_function_pass(insert_lifetime_stop_pass{}, *prg); + run_function_pass(set_stack_ptr_pass{}, *prg); // insert_barriers(*prg); - // set_work_group_size(*prg, *info); + run_function_pass(work_group_size_pass{info}, *prg); // lower_linalg(*prg, *info); run_function_pass(dump_ir_pass{std::cout}, *prg); // propagate_constants(*prg); diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 39b1e4f3..09aeaee7 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -54,7 +54,7 @@ blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, va loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step0, region body, location const &lc) - : standard_inst{tid} { + : standard_inst{tid, step0 ? 4 : 3} { op(op_loop_var) = std::move(loop_var0); op(op_from) = std::move(from0); op(op_to) = std::move(to0); @@ -461,7 +461,7 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu if_inst::if_inst(value condition, region then, region otherwise, std::vector const &return_types, location const &lc) - : standard_inst{IK::if_, 1, static_cast(return_types.size())} { + : standard_inst{IK::if_, 1, static_cast(return_types.size()), otherwise ? 2 : 1} { op(0) = std::move(condition); child_region(child_region_then) = std::move(then); child_region(child_region_otherwise) = std::move(otherwise); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 94840760..b4998981 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -227,7 +227,9 @@ using inst_node = ::tinytc_inst; template class object_container { public: object_container(std::int64_t num_objects) { - if (num_objects != NumObjects) { + // Check that num_objects is not larger than container size + // Smaller is ok too support optional arguments + if (num_objects > NumObjects) { throw internal_compiler_error(); } } diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index e6ce7945..c30a3c30 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -12,13 +12,11 @@ #include #include -namespace tinytc { -using inst_range = iterator_range_wrapper; -using const_inst_range = iterator_range_wrapper; -} // namespace tinytc - struct tinytc_region : tinytc::reference_counted { public: + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + inline tinytc_region(std::vector insts = {}, tinytc::location const &lc = {}) : insts_(std::move(insts)) { loc(lc); @@ -27,21 +25,23 @@ struct tinytc_region : tinytc::reference_counted { inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - inline auto begin() -> tinytc::inst * { return insts_.size() > 0 ? insts_.data() : nullptr; } - inline auto end() -> tinytc::inst * { - return insts_.size() > 0 ? insts_.data() + insts_.size() : nullptr; + inline auto begin() -> iterator { return insts_.begin(); } + inline auto end() -> iterator { return insts_.end(); } + inline auto insts() -> tinytc::iterator_range_wrapper { return {begin(), end()}; } + inline auto begin() const -> const_iterator { return insts_.cbegin(); } + inline auto end() const -> const_iterator { return insts_.cend(); } + inline auto insts() const -> tinytc::iterator_range_wrapper { + return {begin(), end()}; } - inline auto insts() -> tinytc::inst_range { return tinytc::inst_range{begin(), end()}; } - inline auto begin() const -> tinytc::inst const * { - return insts_.size() > 0 ? insts_.data() : nullptr; - } - inline auto end() const -> tinytc::inst const * { - return insts_.size() > 0 ? insts_.data() + insts_.size() : nullptr; + inline void insts(std::vector insts) { insts_ = std::move(insts); } + inline auto erase(iterator pos) -> iterator { return insts_.erase(pos); } + inline auto insert(iterator pos, tinytc::inst const &i) -> iterator { + return insts_.insert(pos, i); } - inline auto insts() const -> tinytc::const_inst_range { - return tinytc::const_inst_range{begin(), end()}; + inline auto insert(iterator pos, tinytc::inst &&i) -> iterator { + return insts_.insert(pos, std::move(i)); } - inline void insts(std::vector insts) { insts_ = std::move(insts); } + inline auto empty() const -> bool { return insts_.empty(); } private: std::vector insts_; diff --git a/src/pass/alias_analysis.cpp b/src/pass/alias_analysis.cpp deleted file mode 100644 index b0d45023..00000000 --- a/src/pass/alias_analysis.cpp +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "pass/alias_analysis.hpp" -#include "error.hpp" -#include "node/data_type_node.hpp" -#include "support/casting.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include - -namespace tinytc { - -/* Stmt nodes */ -void alias_analyser::operator()(inst_node const &) {} -void alias_analyser::operator()(alloca_inst const &a) { - auto t = dyn_cast(a.result()->ty().get()); - if (t == nullptr) { - throw compilation_error(a.loc(), status::ir_expected_memref); - } - allocs_[a.result().get()] = - aa_results::allocation{a.stack_ptr(), a.stack_ptr() + t->size_in_bytes()}; -} -void alias_analyser::operator()(loop_inst const &p) { visit(*this, *p.body()); } -void alias_analyser::operator()(expand_inst const &e) { - value_node const *source = e.operand().get(); - while (alias_.find(source) != alias_.end()) { - source = alias_[source]; - } - alias_[e.result().get()] = source; -} -void alias_analyser::operator()(fuse_inst const &f) { - value_node const *source = f.operand().get(); - while (alias_.find(source) != alias_.end()) { - source = alias_[source]; - } - alias_[f.result().get()] = source; -} - -void alias_analyser::operator()(if_inst const &in) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } -} - -void alias_analyser::operator()(parallel_inst const &p) { visit(*this, *p.body()); } - -void alias_analyser::operator()(subview_inst const &s) { - value_node const *source = s.operand().get(); - while (alias_.find(source) != alias_.end()) { - source = alias_[source]; - } - alias_[s.result().get()] = source; -} - -/* Region nodes */ -void alias_analyser::operator()(rgn const &b) { - for (auto &s : b.insts()) { - visit(*this, *s); - } -} - -/* Function nodes */ -void alias_analyser::operator()(prototype const &) {} -void alias_analyser::operator()(function const &fn) { - alias_.clear(); - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} - -aa_results alias_analyser::get_result() const { return aa_results(alias_, allocs_); } - -} // namespace tinytc diff --git a/src/pass/alias_analysis.hpp b/src/pass/alias_analysis.hpp deleted file mode 100644 index 9ae02e8d..00000000 --- a/src/pass/alias_analysis.hpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ALIAS_ANALYSIS_20230330_HPP -#define ALIAS_ANALYSIS_20230330_HPP - -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/region_node.hpp" -#include "node/value_node.hpp" -#include "pass/aa_results.hpp" - -#include - -namespace tinytc { - -class alias_analyser { - public: - /* Stmt nodes */ - void operator()(inst_node const &); - void operator()(alloca_inst const &a); - void operator()(loop_inst const &p); - void operator()(expand_inst const &e); - void operator()(fuse_inst const &f); - void operator()(if_inst const &in); - void operator()(parallel_inst const &p); - void operator()(subview_inst const &s); - - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &); - void operator()(function const &fn); - - aa_results get_result() const; - - private: - std::unordered_map allocs_; - std::unordered_map alias_; -}; - -} // namespace tinytc - -#endif // ALIAS_ANALYSIS_20230330_HPP diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp new file mode 100644 index 00000000..dd74d4ad --- /dev/null +++ b/src/pass/insert_lifetime_stop.cpp @@ -0,0 +1,70 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/insert_lifetime_stop.hpp" +#include "analysis/alias.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "support/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +auto insert_lifetime_stop_pass::run_on_region(rgn ®, aa_results const &aa) + -> std::unordered_set { + if (reg.empty()) { + return {}; + } + + auto allocas = std::vector{}; + for (auto &i : reg) { + if (auto alloca = dyn_cast(i.get()); alloca != nullptr) { + allocas.emplace_back(alloca->result(0)); + } + } + + auto rgn_ops = std::unordered_set{}; + auto prev_it = reg.end(); + for (; prev_it != reg.begin(); --prev_it) { + auto &i = *(prev_it - 1); + for (auto &subreg : i->child_regions()) { + if (subreg) { + rgn_ops.merge(run_on_region(*subreg, aa)); + } + } + for (auto &v : i->operands()) { + if (isa(*v->ty())) { + rgn_ops.insert(aa.root(*v)); + } + } + for (auto &v : i->results()) { + if (isa(*v->ty())) { + rgn_ops.insert(aa.root(*v)); + } + } + + auto alloca_it = allocas.begin(); + while (alloca_it != allocas.end()) { + if (rgn_ops.contains(alloca_it->get())) { + prev_it = reg.insert( + prev_it, inst{std::make_unique(*alloca_it).release()}); + alloca_it = allocas.erase(alloca_it); + } else { + ++alloca_it; + } + } + } + return rgn_ops; +} + +void insert_lifetime_stop_pass::run_on_function(function &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(*fn.body(), aa); +} + +} // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.hpp b/src/pass/insert_lifetime_stop.hpp new file mode 100644 index 00000000..2c91129e --- /dev/null +++ b/src/pass/insert_lifetime_stop.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INSERT_LIFETIME_STOP_20240912_HPP +#define INSERT_LIFETIME_STOP_20240912_HPP + +#include "analysis/aa_results.hpp" +#include "node/function_node.hpp" +#include "node/region_node.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class insert_lifetime_stop_pass { + public: + void run_on_function(function &fn); + + private: + auto run_on_region(rgn ®, aa_results const &aa) + -> std::unordered_set<::tinytc_value const *>; +}; + +} // namespace tinytc + +#endif // INSERT_LIFETIME_STOP_20240912_HPP diff --git a/src/pass/lifetime_analysis.cpp b/src/pass/lifetime_analysis.cpp deleted file mode 100644 index 4fb5f365..00000000 --- a/src/pass/lifetime_analysis.cpp +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "pass/lifetime_analysis.hpp" -#include "node/value_node.hpp" -#include "pass/alias_analysis.hpp" -#include "support/visit.hpp" - -#include -#include -#include - -namespace tinytc { - -find_alloca::find_alloca(bool recursive) : recursive_(recursive) {} - -/* Inst nodes */ -value find_alloca::operator()(inst_node &) { return value{}; } -value find_alloca::operator()(alloca_inst &a) { return a.result(); } -value find_alloca::operator()(for_inst &p) { - if (recursive_) { - visit(*this, *p.body()); - } - return value{}; -} -value find_alloca::operator()(if_inst &in) { - if (recursive_) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } - } - return value{}; -} - -/* Region nodes */ -value find_alloca::operator()(rgn &b) { - for (auto &s : b.insts()) { - alloca_.emplace_back(visit(*this, *s)); - } - return value{}; -} - -std::vector find_alloca::allocas() const { return alloca_; } - -/* Inst nodes */ -auto lifetime_inserter::operator()(inst_node &) -> std::unordered_set { - return {}; -} - -auto lifetime_inserter::operator()(blas_a2_inst &a) -> std::unordered_set { - return {a.A().get(), a.B().get()}; -} - -auto lifetime_inserter::operator()(blas_a3_inst &inst) -> std::unordered_set { - return {inst.A().get(), inst.B().get(), inst.C().get()}; -} - -auto lifetime_inserter::operator()(loop_inst &p) -> std::unordered_set { - return visit(*this, *p.body()); -} - -auto lifetime_inserter::operator()(alloca_inst &a) -> std::unordered_set { - return {a.result().get()}; -} - -auto lifetime_inserter::operator()(barrier_inst &) -> std::unordered_set { - return {}; -} - -auto lifetime_inserter::operator()(expand_inst &e) -> std::unordered_set { - return std::unordered_set{e.operand().get(), e.result().get()}; -} - -auto lifetime_inserter::operator()(fuse_inst &f) -> std::unordered_set { - return std::unordered_set{f.operand().get(), f.result().get()}; -} - -auto lifetime_inserter::operator()(load_inst &e) -> std::unordered_set { - return std::unordered_set{e.operand().get(), e.result().get()}; -} - -auto lifetime_inserter::operator()(if_inst &in) -> std::unordered_set { - auto s = visit(*this, *in.then()); - if (in.otherwise()) { - s.merge(visit(*this, *in.otherwise())); - } - return s; -} - -auto lifetime_inserter::operator()(lifetime_stop_inst &ls) - -> std::unordered_set { - return {ls.object().get()}; -} - -auto lifetime_inserter::operator()(parallel_inst &p) -> std::unordered_set { - return visit(*this, *p.body()); -} - -auto lifetime_inserter::operator()(size_inst &s) -> std::unordered_set { - return std::unordered_set{s.operand().get()}; -} - -auto lifetime_inserter::operator()(store_inst &s) -> std::unordered_set { - return std::unordered_set{s.operand().get()}; -} - -auto lifetime_inserter::operator()(subview_inst &s) -> std::unordered_set { - return {s.result().get(), s.operand().get()}; -} - -auto lifetime_inserter::operator()(yield_inst &) -> std::unordered_set { - return {}; -} - -/* Region nodes */ -auto lifetime_inserter::operator()(rgn &b) -> std::unordered_set { - auto const intersects = [](std::vector &a, - std::unordered_set const &b) { - for (auto aa = a.begin(); aa != a.end(); ++aa) { - if (b.find(aa->get()) != b.end()) { - return aa; - } - } - return a.end(); - }; - - auto fa = find_alloca{}; - fa(b); - auto allocas = fa.allocas(); - - auto rgn_ops = std::unordered_set{}; - - auto s = b.insts().end(); - while (s != b.insts().begin()) { - auto operands = visit(*this, **(s - 1)); - rgn_ops.insert(operands.begin(), operands.end()); - auto operands_root = decltype(operands){}; - std::transform(operands.begin(), operands.end(), - std::inserter(operands_root, operands_root.begin()), - [this](auto const &op) { return aa_.root(*op); }); - std::vector::iterator aa; - while ((aa = intersects(allocas, operands_root)) != allocas.end()) { - s = b.insts().insert(s, inst{std::make_unique(*aa).release()}); - allocas.erase(aa); - } - --s; - } - - return rgn_ops; -} - -/* Function nodes */ -void lifetime_inserter::operator()(prototype &) {} - -void lifetime_inserter::operator()(function &fn) { - auto aa = alias_analyser{}; - aa(fn); - aa_ = aa.get_result(); - - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} - -void lifetime_inserter::operator()(program &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } -} - -} // namespace tinytc diff --git a/src/pass/lifetime_analysis.hpp b/src/pass/lifetime_analysis.hpp deleted file mode 100644 index 9d901a7e..00000000 --- a/src/pass/lifetime_analysis.hpp +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef LIFETIME_ANALYSIS_20230329_HPP -#define LIFETIME_ANALYSIS_20230329_HPP - -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" -#include "pass/aa_results.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" - -#include -#include - -namespace tinytc { - -class find_alloca { - public: - find_alloca(bool recursive = false); - - /* Inst nodes */ - value operator()(inst_node &); - value operator()(alloca_inst &a); - value operator()(for_inst &p); - value operator()(if_inst &p); - - /* Region nodes */ - value operator()(rgn &); - - std::vector allocas() const; - - private: - bool recursive_; - std::vector alloca_; -}; - -class lifetime_inserter { - public: - /* Inst nodes */ - auto operator()(inst_node &inst) -> std::unordered_set<::tinytc_value const *>; - auto operator()(blas_a2_inst &inst) -> std::unordered_set<::tinytc_value const *>; - auto operator()(blas_a3_inst &inst) -> std::unordered_set<::tinytc_value const *>; - auto operator()(loop_inst &p) -> std::unordered_set<::tinytc_value const *>; - auto operator()(alloca_inst &a) -> std::unordered_set<::tinytc_value const *>; - auto operator()(barrier_inst &b) -> std::unordered_set<::tinytc_value const *>; - auto operator()(expand_inst &e) -> std::unordered_set<::tinytc_value const *>; - auto operator()(fuse_inst &f) -> std::unordered_set<::tinytc_value const *>; - auto operator()(load_inst &e) -> std::unordered_set<::tinytc_value const *>; - auto operator()(if_inst &in) -> std::unordered_set<::tinytc_value const *>; - auto operator()(lifetime_stop_inst &) -> std::unordered_set<::tinytc_value const *>; - auto operator()(parallel_inst &p) -> std::unordered_set<::tinytc_value const *>; - auto operator()(size_inst &s) -> std::unordered_set<::tinytc_value const *>; - auto operator()(store_inst &s) -> std::unordered_set<::tinytc_value const *>; - auto operator()(subview_inst &s) -> std::unordered_set<::tinytc_value const *>; - auto operator()(yield_inst &in) -> std::unordered_set<::tinytc_value const *>; - - /* Region nodes */ - auto operator()(rgn &b) -> std::unordered_set<::tinytc_value const *>; - - /* Func nodes */ - void operator()(prototype &p); - void operator()(function &fn); - - /* Program nodes */ - void operator()(program &p); - - private: - aa_results aa_; -}; - -} // namespace tinytc - -#endif // LIFETIME_ANALYSIS_20230329_HPP diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 329c1ad7..718591aa 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -4,79 +4,66 @@ #include "pass/stack.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" #include "support/casting.hpp" +#include "support/util.hpp" #include "support/visit.hpp" +#include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include +#include +#include namespace tinytc { -/* Inst nodes */ -void stack_ptr::operator()(inst_node &) {} -void stack_ptr::operator()(alloca_inst &a) { - auto t = dyn_cast(a.result()->ty().get()); - if (t == nullptr) { - throw compilation_error(a.loc(), status::ir_expected_memref); - } - auto size = t->size_in_bytes(); - std::int64_t stack_ptr = 0; - auto it = allocs_.begin(); - for (; it != allocs_.end(); ++it) { - if (it->start - stack_ptr >= size) { - break; - } - stack_ptr = it->stop; - } - allocs_.insert(it, allocation{a.result().get(), stack_ptr, stack_ptr + size}); - a.stack_ptr(stack_ptr); -} -void stack_ptr::operator()(lifetime_stop_inst &s) { - int num = 0; - auto v = s.object().get(); - for (auto it = allocs_.begin(); it != allocs_.end();) { - if (it->value == v) { - it = allocs_.erase(it); - ++num; - } else { - ++it; - } - } - if (num != 1) { - throw compilation_error(s.loc(), status::internal_compiler_error, - "Incorrect lifetime_stop: value not found in list of allocations"); - } -} -void stack_ptr::operator()(for_inst &p) { visit(*this, *p.body()); } - -void stack_ptr::operator()(if_inst &in) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } -} - -/* Region nodes */ -void stack_ptr::operator()(rgn &b) { - for (auto &s : b.insts()) { - visit(*this, *s); - } -} - -/* Function nodes */ -void stack_ptr::operator()(prototype &) {} -void stack_ptr::operator()(function &fn) { - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} +void set_stack_ptr_pass::run_on_function(function &fn) { + struct allocation { + value_node *value; + std::int64_t start, stop; + }; + std::list allocs; -/* Program nodes */ -void stack_ptr::operator()(program &p) { - for (auto &fn : p.functions()) { - allocs_.clear(); - visit(*this, *fn); - } + walk(fn, [&allocs](inst_node &i) { + visit(overloaded{ + [&allocs](alloca_inst &a) { + auto t = dyn_cast(a.result()->ty().get()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + auto size = t->size_in_bytes(); + std::int64_t stack_ptr = 0; + auto it = allocs.begin(); + for (; it != allocs.end(); ++it) { + if (it->start - stack_ptr >= size) { + break; + } + stack_ptr = it->stop; + } + allocs.insert(it, allocation{a.result().get(), stack_ptr, stack_ptr + size}); + a.stack_ptr(stack_ptr); + }, + [&allocs](lifetime_stop_inst &s) { + int num = 0; + auto v = s.object().get(); + for (auto it = allocs.begin(); it != allocs.end();) { + if (it->value == v) { + it = allocs.erase(it); + ++num; + } else { + ++it; + } + } + if (num != 1) { + throw compilation_error( + s.loc(), status::internal_compiler_error, + "Incorrect lifetime_stop: value not found in list of allocations"); + } + }, + [](inst_node &) {}}, + i); + }); } } // namespace tinytc diff --git a/src/pass/stack.hpp b/src/pass/stack.hpp index f515d87a..72b6f7ad 100644 --- a/src/pass/stack.hpp +++ b/src/pass/stack.hpp @@ -5,41 +5,12 @@ #define STACK_20230413_HPP #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" -#include "node/value_node.hpp" - -#include -#include namespace tinytc { -class stack_ptr { +class set_stack_ptr_pass { public: - /* Inst nodes */ - void operator()(inst_node &); - void operator()(alloca_inst &a); - void operator()(lifetime_stop_inst &s); - void operator()(for_inst &p); - void operator()(if_inst &in); - - /* Region nodes */ - void operator()(rgn &b); - - /* Func nodes */ - void operator()(prototype &p); - void operator()(function &fn); - - /* Program nodes */ - void operator()(program &p); - - private: - struct allocation { - value_node *value; - std::int64_t start, stop; - }; - std::list allocs_; + void run_on_function(function &fn); }; } // namespace tinytc diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp index 6fedb3c8..62792619 100644 --- a/src/pass/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -5,14 +5,18 @@ #include "device_info.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/inst_node.hpp" #include "node/value_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include #include +#include #include #include @@ -26,61 +30,43 @@ auto get_memref_type(value_node &v) { return t; } -work_group_size::work_group_size(::tinytc_core_info const *info) : info_(std::move(info)) { +work_group_size_pass::work_group_size_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { if (info_ == nullptr) { throw std::invalid_argument("info must not be nullptr"); } } -/* Stmt nodes */ -void work_group_size::operator()(inst_node &) {} - -void work_group_size::operator()(blas_a2_inst &in) { - auto b = get_memref_type(*in.B()); - if (b->dim() == 1) { - shapes_.insert({b->element_ty(), {b->shape(0), 0}}); - } else if (b->dim() >= 2) { - shapes_.insert({b->element_ty(), {b->shape(0), b->shape(1)}}); - } -} -void work_group_size::operator()(blas_a3_inst &in) { - auto c = get_memref_type(*in.C()); - if (c->dim() == 1) { - shapes_.insert({c->element_ty(), {c->shape(0), 0}}); - } else if (c->dim() >= 2) { - shapes_.insert({c->element_ty(), {c->shape(0), c->shape(1)}}); - } -} - -void work_group_size::operator()(if_inst &in) { - visit(*this, *in.then()); - if (in.otherwise()) { - visit(*this, *in.otherwise()); - } -} -void work_group_size::operator()(loop_inst &in) { visit(*this, *in.body()); } -void work_group_size::operator()(parallel_inst &p) { visit(*this, *p.body()); } - -/* Region nodes */ -void work_group_size::operator()(rgn &b) { - for (auto &i : b.insts()) { - visit(*this, *i); - } -} - -/* Function nodes */ -void work_group_size::operator()(prototype &) {} - -void work_group_size::operator()(function &fn) { +void work_group_size_pass::run_on_function(function &fn) { auto subgroup_size = fn.subgroup_size(); auto work_group_size = fn.work_group_size(); - shapes_.clear(); - if (subgroup_size == 0 || work_group_size[0] == 0 || work_group_size[1] == 0) { - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); + auto shape_set = std::unordered_set{}; - auto const shapes = std::vector(shapes_.begin(), shapes_.end()); + if (subgroup_size == 0 || work_group_size[0] == 0 || work_group_size[1] == 0) { + walk(fn, [&shape_set](inst_node &i) { + visit( + overloaded{[&shape_set](blas_a2_inst &in) { + auto b = get_memref_type(*in.B()); + if (b->dim() == 1) { + shape_set.insert({b->element_ty(), {b->shape(0), 0}}); + } else if (b->dim() >= 2) { + shape_set.insert({b->element_ty(), {b->shape(0), b->shape(1)}}); + } + }, + [&shape_set](blas_a3_inst &in) { + auto c = get_memref_type(*in.C()); + if (c->dim() == 1) { + shape_set.insert({c->element_ty(), {c->shape(0), 0}}); + } else if (c->dim() >= 2) { + shape_set.insert({c->element_ty(), {c->shape(0), c->shape(1)}}); + } + }, + [](inst_node &) {}}, + i); + }); + + auto const shapes = std::vector(shape_set.begin(), shape_set.end()); if (subgroup_size == 0) { subgroup_size = suggest_subgroup_size(shapes, *info_); @@ -119,11 +105,4 @@ void work_group_size::operator()(function &fn) { } } -/* Program nodes */ -void work_group_size::operator()(program &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } -} - } // namespace tinytc diff --git a/src/pass/work_group_size.hpp b/src/pass/work_group_size.hpp index db350be1..69c31fd3 100644 --- a/src/pass/work_group_size.hpp +++ b/src/pass/work_group_size.hpp @@ -6,40 +6,17 @@ #include "device_info.hpp" #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" -#include "tiling.hpp" - -#include namespace tinytc { -class work_group_size { +class work_group_size_pass { public: - work_group_size(tinytc_core_info const *info); - - /* Inst nodes */ - void operator()(inst_node &); - void operator()(blas_a2_inst &in); - void operator()(blas_a3_inst &in); - void operator()(if_inst &in); - void operator()(loop_inst &in); - void operator()(parallel_inst &p); - - /* Region nodes */ - void operator()(rgn &b); - - /* Func nodes */ - void operator()(prototype &p); - void operator()(function &fn); + work_group_size_pass(tinytc_core_info const *info); - /* Program nodes */ - void operator()(program &p); + void run_on_function(function &fn); private: tinytc_core_info const *info_; - std::unordered_set shapes_; }; } // namespace tinytc diff --git a/src/passes.def b/src/passes.def index 12a5aeac..61f2442f 100644 --- a/src/passes.def +++ b/src/passes.def @@ -3,3 +3,6 @@ FUNCTION_PASS("check-ir", check_ir_pass{}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) +FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) +FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) +FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass(info); }) diff --git a/src/support/walk.cpp b/src/support/walk.cpp index b6a13755..b158f36c 100644 --- a/src/support/walk.cpp +++ b/src/support/walk.cpp @@ -14,10 +14,8 @@ void walk(inst_node &i, std::function void walk(inst_node &i, std::function void walk(inst_node &i, std::function callback) { + for (auto ® : i.child_regions()) { + if (reg) { + if constexpr (Order == walk_order::pre_order) { + callback(reg); + } + for (auto &j : *reg) { + walk(*j, callback); + } + if constexpr (Order == walk_order::post_order) { + callback(reg); + } + } + } +} + void walk(inst_node &i, std::function callback); template void walk(function &fn, std::function callback) { diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir new file mode 100644 index 00000000..4ca88b22 --- /dev/null +++ b/test/opt/insert-lifetime-stop.ir @@ -0,0 +1,61 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt --insert-lifetime-stop < %s | filecheck %s +func @basic() { + %0 = alloca -> memref +; CHECK: %0 = alloca -> memref +; CHECK-NEXT: lifetime_stop %0 +} + +func @use1(%A: memref, %C: memref) { +; CHECK-LABEL: func @use1{{.*}} + %B = alloca -> memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref +; CHECK: gemm.n.n{{.*}} +; CHECK-NEXT: lifetime_stop %B +} + +func @use2(%A: memref, %C: memref) { +; CHECK-LABEL: func @use2{{.*}} + %B = alloca -> memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + %B2 = alloca -> memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n 1.0, %A, %B2, 0.0, %C : f32, memref, memref, f32, memref +; CHECK: %B2 = {{.*}} +; CHECK-NEXT: gemm.n.n{{.*}} +; CHECK-NEXT: lifetime_stop %B +; CHECK: gemm.n.n{{.*}} +; CHECK-NEXT: lifetime_stop %B2 +} + +func @use_alias(%A: memref, %C: memref) { +; CHECK-LABEL: func @use_alias{{.*}} + %B = alloca -> memref + %0 = fuse %B[1,3] : memref + %1 = subview %0[0:8,:] : memref + gemm.n.n 1.0, %A, %1, 0.0, %C : f32, memref, memref>, f32, memref +; CHECK: gemm.n.n{{.*}} +; CHECK-NEXT: lifetime_stop %B +} + +func @region1() { +; CHECK-LABEL: func @region1{{.*}} + %0 = alloca -> memref + for %i=0,4 : index { + %1 = alloca -> memref + for %k=0,4 : index { + %2 = alloca -> memref + gemm.n.n 1.0, %0, %1, 0.0, %2 : f32, memref, memref, f32, memref + axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + } + } +; CHECK: gemm.n.n{{.*}} +; CHECK-NEXT: lifetime_stop %2 +; CHECK-NEXT: axpby.n{{.*}} +; CHECK-NEXT: } +; CHECK-NEXT: lifetime_stop %1 +; CHECK-NEXT: } +; CHECK-NEXT: lifetime_stop %0 +} diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir new file mode 100644 index 00000000..8b033814 --- /dev/null +++ b/test/opt/work-group-size.ir @@ -0,0 +1,26 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt -d pvc --work-group-size < %s | filecheck %s +func @default_pvc() { +; CHECK: func @default_pvc() subgroup_size(32) work_group_size(32,1) { +} + +func @f32_blas() { +; CHECK: func @f32_blas() subgroup_size(32) work_group_size(128,2) { + %0 = alloca -> memref + %1 = alloca -> memref + for %i=0,4 { + axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + } +} + +func @f64_blas() { +; CHECK: func @f64_blas() subgroup_size(16) work_group_size(128,8) { + %0 = alloca -> memref + %1 = alloca -> memref + %2 = alloca -> memref + for %i=0,4 { + gemm.n.n 1.0, %0, %1, 0.0, %2 : f64, memref, memref, f64, memref + } +} From 86bdecf332c520667c36b4efaa593eeb119afc57 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 13 Sep 2024 10:32:48 +0200 Subject: [PATCH 016/297] Fix convert to opencl pass Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 12 +- src/CMakeLists.txt | 5 +- src/{pass => analysis}/equal.cpp | 4 +- src/{pass => analysis}/equal.hpp | 2 + src/compiler.cpp | 5 +- src/error.cpp | 2 +- src/node/inst_node.cpp | 51 +++-- src/parser/parser_impl.yy | 2 +- .../{opencl_ast.cpp => convert_to_opencl.cpp} | 191 +++++++++--------- .../{opencl_ast.hpp => convert_to_opencl.hpp} | 23 +-- src/passes.cpp | 65 ------ src/passes.hpp | 39 ---- src/prog.cpp | 7 +- test/codegen/dope_vector0.ir | 14 +- 14 files changed, 168 insertions(+), 254 deletions(-) rename src/{pass => analysis}/equal.cpp (85%) rename src/{pass => analysis}/equal.hpp (90%) rename src/pass/{opencl_ast.cpp => convert_to_opencl.cpp} (89%) rename src/pass/{opencl_ast.hpp => convert_to_opencl.hpp} (92%) delete mode 100644 src/passes.cpp diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 51377663..5e7b7ee9 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -952,6 +952,15 @@ A dynamic size ("?") means that the size is the mode size inferred from the memr minus the offset. A plain colon is syntactic sugar for "0:?". +Zero sizes are used to encode that a rank-reduction is required, that is, +the rank of size 0 is removed from the output memref type. +A single index is syntactic sugar for offset plus size 0, e.g. %0 is syntactic sugar for %0:0. +(Note that a zero-size rank, e.g. in memref, is non-sense, because any multi-index passed +to the memref would be out-of-bounds. However, a one-sized rank, e.g. memref, might be desirable.) +A dynamic size of zero is undefined behaviour. + + + There is no run-time check whether the indices are within bounds. Offset and size must be of index type. Offset must be non-negative and size must be positive. @@ -966,11 +975,12 @@ The output type is a memref type according to the following rules: subview %0[4:8,8:4] : memref ; Returns memref> -#. **Rank-reduction:** A mode accessed by a single constant or value is removed from the output tensor. +#. **Rank-reduction:** A mode accessed by offset only or a mode with size statically known to be 0 is removed from the output tensor. .. code:: subview %0[2:4, %1] : memref ; Returns memref> + subview %0[2:4, %1:0] : memref ; Returns memref> subview %0[2:4, %1:1] : memref ; Returns memref> #. **Output-mode size:** The size of the output mode is determined by the size field of a slice diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ddcdc24d..a90ba8c5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,7 @@ find_package(BISON 3.8.2 REQUIRED) set(SOURCES analysis/aa_results.cpp analysis/alias.cpp + analysis/equal.cpp binary.cpp codegen_tools.cpp compiler.cpp @@ -35,17 +36,15 @@ set(SOURCES parser.cpp pass/check_ir.cpp #pass/constant_propagation.cpp + pass/convert_to_opencl.cpp pass/dump_ir.cpp - pass/equal.cpp #pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp #pass/lower_linalg.cpp #pass/metadata.cpp - #pass/opencl_ast.cpp pass/slot_tracker.cpp pass/stack.cpp pass/work_group_size.cpp - passes.cpp prog.cpp recipe.cpp recipe/small_gemm_batched.cpp diff --git a/src/pass/equal.cpp b/src/analysis/equal.cpp similarity index 85% rename from src/pass/equal.cpp rename to src/analysis/equal.cpp index 010968d0..02ae23c5 100644 --- a/src/pass/equal.cpp +++ b/src/analysis/equal.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "pass/equal.hpp" +#include "analysis/equal.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" @@ -21,4 +21,6 @@ bool equal::operator()(scalar_data_type const &a, scalar_data_type const &b) { return a.ty() == b.ty(); } +bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } + } // namespace tinytc diff --git a/src/pass/equal.hpp b/src/analysis/equal.hpp similarity index 90% rename from src/pass/equal.hpp rename to src/analysis/equal.hpp index 657716b2..be258b02 100644 --- a/src/pass/equal.hpp +++ b/src/analysis/equal.hpp @@ -18,6 +18,8 @@ class equal { bool operator()(scalar_data_type const &a, scalar_data_type const &b); }; +bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); + } // namespace tinytc #endif // EQUAL_20240208_HPP diff --git a/src/compiler.cpp b/src/compiler.cpp index b14e253f..84047fa0 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -6,6 +6,7 @@ #include "node/program_node.hpp" #include "parser.hpp" #include "pass/check_ir.hpp" +#include "pass/convert_to_opencl.hpp" #include "pass/dump_ir.hpp" #include "pass/insert_lifetime_stop.hpp" #include "pass/stack.hpp" @@ -91,7 +92,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ // propagate_constants(*prg); // dump_ir(std::cout, *prg); // opencl - /*auto ast = generate_opencl_ast(*prg, *info); + auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); clir::make_names_unique(ast); auto oss = std::ostringstream{}; @@ -104,7 +105,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ *src = std::make_unique<::tinytc_source>(oss.str(), prg->loc(), std::move(ext), info->core_features()) - .release();*/ + .release(); }, ctx); } diff --git a/src/error.cpp b/src/error.cpp index 67486bdd..ab6108c5 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -145,7 +145,7 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_ir_multiple_dynamic_modes: return "At most one mode must be dynamic ('?')"; case tinytc_status_ir_invalid_slice: - return "Offset must be non-negative and must not be '?'; size must be positive or '?'"; + return "Offset must be non-negative and must not be '?'; size must be non-negative or '?'"; case tinytc_status_ir_expand_shape_order_too_small: return "Expand shape must have at least 2 entries"; case tinytc_status_ir_expand_shape_mismatch: diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 09aeaee7..48e6aa2d 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -495,7 +495,7 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, op(i++) = val; } for (auto const &val : size_list0) { - op(i++) = val; + op(i++) = val ? val : make_index(0); } } loc(lc); @@ -520,31 +520,30 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, }, [](auto &) {}}, *offset); - if (size) { // if size is given - visit(overloaded{[&](int_imm &i) { - if (i.value() < 1 && !is_dynamic_value(i.value())) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *size); - auto size_value = - visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return is_dynamic_value(m->shape(i)) - ? dynamic - : m->shape(i) - offset.value(); - } - return size.value(); - }, - [&](val &, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return dynamic; - } - return size.value(); - }, - [](auto &, auto &) -> std::int64_t { return dynamic; }}, - *offset, *size); + visit(overloaded{[&](int_imm &i) { + if (i.value() < 0 && !is_dynamic_value(i.value())) { + throw compilation_error(loc(), status::ir_invalid_slice); + } + }, + [](auto &) {}}, + *size); + auto size_value = visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { + if (is_dynamic_value(size.value())) { + return is_dynamic_value(m->shape(i)) + ? dynamic + : m->shape(i) - offset.value(); + } + return size.value(); + }, + [&](val &, int_imm &size) -> std::int64_t { + if (is_dynamic_value(size.value())) { + return dynamic; + } + return size.value(); + }, + [](auto &, auto &) -> std::int64_t { return dynamic; }}, + *offset, *size); + if (size_value > 0 || is_dynamic_value(size_value)) { shape.push_back(size_value); stride.push_back(m->stride(i)); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 2db5bd6b..6e93fa8e 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -21,6 +21,7 @@ } %code { + #include "analysis/equal.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" @@ -29,7 +30,6 @@ #include "node/value_node.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" - #include "passes.hpp" #include "support/util.hpp" #include "support/visit.hpp" diff --git a/src/pass/opencl_ast.cpp b/src/pass/convert_to_opencl.cpp similarity index 89% rename from src/pass/opencl_ast.cpp rename to src/pass/convert_to_opencl.cpp index cdfbeec5..ab071c39 100644 --- a/src/pass/opencl_ast.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "pass/opencl_ast.hpp" +#include "pass/convert_to_opencl.hpp" #include "codegen_tools.hpp" #include "error.hpp" #include "gemm_generator.hpp" @@ -106,14 +106,15 @@ dope_vector dope_vector::from_memref_type(std::string const &prefix, memref_data return dope_vector(std::move(shape), std::move(stride)); } -opencl_ast::opencl_ast(::tinytc_core_info const *info) : info_(std::move(info)) { +convert_to_opencl_pass::convert_to_opencl_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { if (info_ == nullptr) { throw std::invalid_argument("info must not be nullptr"); } declared_vars_.push_back({}); } -auto opencl_ast::get_dope_vector(value_node *v) -> dope_vector & { +auto convert_to_opencl_pass::get_dope_vector(value_node *v) -> dope_vector & { auto dv = dope_vector_.find(std::bit_cast(v)); if (dv == dope_vector_.end()) { throw compilation_error(v->loc(), status::internal_compiler_error, @@ -122,12 +123,12 @@ auto opencl_ast::get_dope_vector(value_node *v) -> dope_vector & { return dv->second; } -void opencl_ast::set_dope_vector(value_node *v, dope_vector dv) { +void convert_to_opencl_pass::set_dope_vector(value_node *v, dope_vector dv) { uintptr_t u = std::bit_cast(v); dope_vector_[u] = std::move(dv); } -clir::var opencl_ast::declare(value_node const &v) { +clir::var convert_to_opencl_pass::declare(value_node const &v) { uintptr_t u = std::bit_cast(&v); for (auto it = declared_vars_.rbegin(); it != declared_vars_.rend(); ++it) { if (it->find(u) != it->end()) { @@ -141,7 +142,8 @@ clir::var opencl_ast::declare(value_node const &v) { return declared_vars_.back()[u]; } -auto opencl_ast::get_memref_type(value_node const &v) const -> const memref_data_type * { +auto convert_to_opencl_pass::get_memref_type(value_node const &v) const + -> const memref_data_type * { auto t = dyn_cast(v.ty().get()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); @@ -149,7 +151,7 @@ auto opencl_ast::get_memref_type(value_node const &v) const -> const memref_data return t; } -auto opencl_ast::get_scalar_type(data_type_node const &ty) -> scalar_type { +auto convert_to_opencl_pass::get_scalar_type(data_type_node const &ty) -> scalar_type { return visit(overloaded{[](scalar_data_type const &i) -> scalar_type { return i.ty(); }, [](memref_data_type const &i) -> scalar_type { return i.element_ty(); }, [&](auto const &i) -> scalar_type { @@ -161,10 +163,10 @@ auto opencl_ast::get_scalar_type(data_type_node const &ty) -> scalar_type { }; /* Data type nodes */ -clir::data_type opencl_ast::operator()(void_data_type const &) { +clir::data_type convert_to_opencl_pass::operator()(void_data_type const &) { return clir::builtin_type::void_t; } -clir::data_type opencl_ast::operator()(group_data_type const &g) { +clir::data_type convert_to_opencl_pass::operator()(group_data_type const &g) { auto ptr_ty = visit(*this, *g.ty()); ptr_ty = clir::visit(overloaded{[](clir::internal::pointer &t) { return clir::pointer_to(clir::pointer_to( @@ -178,21 +180,23 @@ clir::data_type opencl_ast::operator()(group_data_type const &g) { } return ptr_ty; } -clir::data_type opencl_ast::operator()(memref_data_type const &d) { +clir::data_type convert_to_opencl_pass::operator()(memref_data_type const &d) { return clir::pointer_to(d.clir_element_ty()); } -clir::data_type opencl_ast::operator()(scalar_data_type const &s) { return s.clir_ty(); } +clir::data_type convert_to_opencl_pass::operator()(scalar_data_type const &s) { + return s.clir_ty(); +} /* Value nodes */ -clir::expr opencl_ast::operator()(float_imm const &v) { +clir::expr convert_to_opencl_pass::operator()(float_imm const &v) { auto ty = get_scalar_type(*v.ty()); return clir::expr(v.value(), static_cast(size(ty) * 8)); } -clir::expr opencl_ast::operator()(int_imm const &v) { +clir::expr convert_to_opencl_pass::operator()(int_imm const &v) { auto ty = get_scalar_type(*v.ty()); return clir::expr(v.value(), static_cast(size(ty) * 8)); } -clir::expr opencl_ast::operator()(val const &v) { +clir::expr convert_to_opencl_pass::operator()(val const &v) { uintptr_t u = std::bit_cast(&v); for (auto it = declared_vars_.rbegin(); it != declared_vars_.rend(); ++it) { if (auto j = it->find(u); j != it->end()) { @@ -205,7 +209,7 @@ clir::expr opencl_ast::operator()(val const &v) { } /* Stmt nodes */ -std::vector opencl_ast::operator()(alloca_inst const &a) { +std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) { if (a.stack_ptr() < 0) { throw compilation_error(a.loc(), status::internal_compiler_error, "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); @@ -228,7 +232,7 @@ std::vector opencl_ast::operator()(alloca_inst const &a) { return {std::move(result)}; } -std::vector opencl_ast::operator()(axpby_inst const &inst) { +std::vector convert_to_opencl_pass::operator()(axpby_inst const &inst) { auto at = get_memref_type(*inst.A()); auto bt = get_memref_type(*inst.B()); auto alpha_ty = get_scalar_type(*inst.alpha()->ty()); @@ -307,12 +311,12 @@ std::vector opencl_ast::operator()(axpby_inst const &inst) { throw compilation_error(inst.loc(), status::ir_expected_vector_or_matrix); } -std::vector opencl_ast::operator()(barrier_inst const &) { +std::vector convert_to_opencl_pass::operator()(barrier_inst const &) { return {clir::expression_statement(clir::call_builtin( clir::builtin_function::barrier, {clir::cl_mem_fence_flags::CLK_LOCAL_MEM_FENCE}))}; } -std::vector opencl_ast::operator()(arith_inst const &a) { +std::vector convert_to_opencl_pass::operator()(arith_inst const &a) { auto const make = [](arithmetic op, clir::expr a, clir::expr b, scalar_type sty) -> clir::expr { switch (op) { case arithmetic::add: @@ -354,7 +358,7 @@ std::vector opencl_ast::operator()(arith_inst const &a) { make(a.operation(), visit(*this, *a.a()), visit(*this, *a.b()), sty))}; } -std::vector opencl_ast::operator()(arith_unary_inst const &a) { +std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { auto const make = [](arithmetic_unary op, clir::expr a, scalar_type sty) -> clir::expr { switch (op) { case arithmetic_unary::neg: @@ -373,14 +377,14 @@ std::vector opencl_ast::operator()(arith_unary_inst const &a) { make(a.operation(), visit(*this, *a.a()), sty))}; } -std::vector opencl_ast::operator()(cast_inst const &c) { +std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { auto v = declare(*c.result()); auto result_ty = visit(*this, *c.result()->ty()); auto cst = cast(result_ty, visit(*this, *c.a())); return {declaration_assignment(std::move(result_ty), std::move(v), std::move(cst))}; } -std::vector opencl_ast::operator()(compare_inst const &c) { +std::vector convert_to_opencl_pass::operator()(compare_inst const &c) { auto const make = [](cmp_condition cond, clir::expr a, clir::expr b) -> clir::expr { switch (cond) { case cmp_condition::eq: @@ -403,7 +407,7 @@ std::vector opencl_ast::operator()(compare_inst const &c) { make(c.cond(), visit(*this, *c.a()), visit(*this, *c.b())))}; } -std::vector opencl_ast::operator()(expand_inst const &e) { +std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { auto result_var = declare(*e.result()); auto m = get_memref_type(*e.operand()); auto &dv = get_dope_vector(e.operand().get()); @@ -471,7 +475,7 @@ std::vector opencl_ast::operator()(expand_inst const &e) { })); return clinst; } -std::vector opencl_ast::operator()(fuse_inst const &f) { +std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { auto result_var = declare(*f.result()); auto m = get_memref_type(*f.operand()); auto &dv = get_dope_vector(f.operand().get()); @@ -511,7 +515,7 @@ std::vector opencl_ast::operator()(fuse_inst const &f) { return clinst; } -std::vector opencl_ast::operator()(load_inst const &e) { +std::vector convert_to_opencl_pass::operator()(load_inst const &e) { auto op_val = e.operand(); auto rhs = visit(*this, *op_val); @@ -566,23 +570,25 @@ std::vector opencl_ast::operator()(load_inst const &e) { return clinst; } -std::vector opencl_ast::operator()(group_id_inst const &g) { +std::vector convert_to_opencl_pass::operator()(group_id_inst const &g) { auto rhs = clir::get_global_id(2); auto lhs = declare(*g.result()); return { declaration_assignment(visit(*this, *g.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(group_size_inst const &g) { +std::vector convert_to_opencl_pass::operator()(group_size_inst const &g) { auto rhs = clir::get_global_size(2); auto lhs = declare(*g.result()); return { declaration_assignment(visit(*this, *g.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(lifetime_stop_inst const &) { return {}; } +std::vector convert_to_opencl_pass::operator()(lifetime_stop_inst const &) { + return {}; +} -std::vector opencl_ast::operator()(gemm_inst const &g) { +std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { auto a = get_memref_type(*g.A()); auto b = get_memref_type(*g.B()); auto c = get_memref_type(*g.C()); @@ -636,7 +642,7 @@ std::vector opencl_ast::operator()(gemm_inst const &g) { visit(*this, *g.beta()), visit(*this, *g.C()), cdv.stride(0), cdv.stride(1)}))}; } -std::vector opencl_ast::operator()(gemv_inst const &g) { +std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { auto a = get_memref_type(*g.A()); auto b = get_memref_type(*g.B()); auto c = get_memref_type(*g.C()); @@ -689,7 +695,7 @@ std::vector opencl_ast::operator()(gemv_inst const &g) { visit(*this, *g.beta()), visit(*this, *g.C()), cdv.stride(0), 0}))}; } -std::vector opencl_ast::operator()(ger_inst const &g) { +std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { auto at = get_memref_type(*g.A()); auto bt = get_memref_type(*g.B()); auto ct = get_memref_type(*g.C()); @@ -757,7 +763,7 @@ std::vector opencl_ast::operator()(ger_inst const &g) { return {bb.get_product()}; } -std::vector opencl_ast::operator()(for_inst const &p) { +std::vector convert_to_opencl_pass::operator()(for_inst const &p) { auto clinst = std::vector{}; auto lv = declare(*p.loop_var()); @@ -765,14 +771,14 @@ std::vector opencl_ast::operator()(for_inst const &p) { auto start = clir::declaration_assignment(std::move(lv_ty), lv, visit(*this, *p.from())); auto condition = lv < visit(*this, *p.to()); auto step = p.step() ? clir::add_into(lv, visit(*this, *p.step())) : ++lv; - auto body = visit(*this, *p.body()); + auto body = run_on_region(*p.body()); clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); return clinst; } -std::vector opencl_ast::operator()(foreach_inst const &p) { +std::vector convert_to_opencl_pass::operator()(foreach_inst const &p) { auto lv = declare(*p.loop_var()); auto lv_ty = visit(*this, *p.loop_var()->ty()); auto from = visit(*this, *p.from()); @@ -785,12 +791,12 @@ std::vector opencl_ast::operator()(foreach_inst const &p) { bb, trip_count, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), std::move(sg), [&](clir::block_builder &bb, clir::expr block, bool, clir::expr) { bb.add(clir::declaration_assignment(lv_ty, lv, std::move(block) + m + from)); - bb.add(visit(*this, *p.body())); + bb.add(run_on_region(*p.body())); }); return {bb.get_product()}; } -std::vector opencl_ast::operator()(hadamard_inst const &g) { +std::vector convert_to_opencl_pass::operator()(hadamard_inst const &g) { auto at = get_memref_type(*g.A()); auto bt = get_memref_type(*g.B()); auto ct = get_memref_type(*g.C()); @@ -838,7 +844,7 @@ std::vector opencl_ast::operator()(hadamard_inst const &g) { return {bb.get_product()}; } -std::vector opencl_ast::operator()(if_inst const &in) { +std::vector convert_to_opencl_pass::operator()(if_inst const &in) { auto clinst = std::vector{}; yielded_vars_.push_back(std::vector{}); for (auto const &r : in.results()) { @@ -847,27 +853,27 @@ std::vector opencl_ast::operator()(if_inst const &in) { yielded_vars_.back().emplace_back(std::move(v)); } auto ib = clir::if_selection_builder(visit(*this, *in.condition())); - ib.set_then(visit(*this, *in.then())); + ib.set_then(run_on_region(*in.then())); if (in.otherwise()) { - ib.set_otherwise(visit(*this, *in.otherwise())); + ib.set_otherwise(run_on_region(*in.otherwise())); } yielded_vars_.pop_back(); clinst.emplace_back(ib.get_product()); return clinst; } -std::vector opencl_ast::operator()(num_subgroups_inst const &sg) { +std::vector convert_to_opencl_pass::operator()(num_subgroups_inst const &sg) { auto rhs = clir::get_num_sub_groups(); auto lhs = declare(*sg.result()); return { declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(parallel_inst const &p) { - return {visit(*this, *p.body())}; +std::vector convert_to_opencl_pass::operator()(parallel_inst const &p) { + return {run_on_region(*p.body())}; } -std::vector opencl_ast::operator()(size_inst const &s) { +std::vector convert_to_opencl_pass::operator()(size_inst const &s) { auto v = declare(*s.result()); auto &dv = get_dope_vector(s.operand().get()); @@ -875,28 +881,28 @@ std::vector opencl_ast::operator()(size_inst const &s) { dv.shape(s.mode()))}; } -std::vector opencl_ast::operator()(subgroup_id_inst const &sg) { +std::vector convert_to_opencl_pass::operator()(subgroup_id_inst const &sg) { auto rhs = clir::get_sub_group_id(); auto lhs = declare(*sg.result()); return { declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(subgroup_local_id_inst const &sg) { +std::vector convert_to_opencl_pass::operator()(subgroup_local_id_inst const &sg) { auto rhs = clir::get_sub_group_local_id(); auto lhs = declare(*sg.result()); return { declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(subgroup_size_inst const &sg) { +std::vector convert_to_opencl_pass::operator()(subgroup_size_inst const &sg) { auto rhs = clir::get_sub_group_size(); auto lhs = declare(*sg.result()); return { declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; } -std::vector opencl_ast::operator()(subview_inst const &s) { +std::vector convert_to_opencl_pass::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); auto t = get_memref_type(*s.operand()); if (t->dim() != static_cast(s.num_indices())) { @@ -915,21 +921,24 @@ std::vector opencl_ast::operator()(subview_inst const &s) { auto &offset = s.offset_list()[i]; auto &size = s.size_list()[i]; rhs = rhs + visit(*this, *offset) * dv.stride(j); - if (size) { - bool is_size_unknown = visit(overloaded{[&](int_imm const &size) -> bool { - return is_dynamic_value(size.value()); - }, - [](auto const &) -> bool { return false; }}, - *size); - auto size_value = clir::expr{}; - if (is_size_unknown) { - size_value = dv.shape(j) - visit(*this, *offset); - } else { - size_value = visit(*this, *size); - } + + auto size_value = + visit(overloaded{[&](int_imm &s) -> clir::expr { + if (s.value() == 0) { + return nullptr; + } else if (is_dynamic_value(s.value())) { + return dv.shape(j) - visit(*this, *offset); + } + return this->operator()(s); + }, + [&](value_node &s) -> clir::expr { return visit(*this, s); }}, + *size); + + if (size_value) { shape_out.emplace_back(size_value); stride_out.emplace_back(dv.stride(j)); } + ++j; } @@ -947,7 +956,7 @@ std::vector opencl_ast::operator()(subview_inst const &s) { return clinst; } -std::vector opencl_ast::operator()(store_inst const &s) { +std::vector convert_to_opencl_pass::operator()(store_inst const &s) { auto ot = get_memref_type(*s.operand()); if (static_cast(s.index_list().size()) != ot->dim()) { @@ -965,7 +974,7 @@ std::vector opencl_ast::operator()(store_inst const &s) { return {expression_statement(std::move(st))}; } -std::vector opencl_ast::operator()(sum_inst const &inst) { +std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) { auto at = get_memref_type(*inst.A()); auto bt = get_memref_type(*inst.B()); auto &adv = get_dope_vector(inst.A().get()); @@ -1048,7 +1057,7 @@ std::vector opencl_ast::operator()(sum_inst const &inst) { return {bb.get_product()}; } -std::vector opencl_ast::operator()(yield_inst const &in) { +std::vector convert_to_opencl_pass::operator()(yield_inst const &in) { if (yielded_vars_.empty()) { throw compilation_error(in.loc(), status::ir_unexpected_yield); } @@ -1064,10 +1073,10 @@ std::vector opencl_ast::operator()(yield_inst const &in) { } /* Region nodes */ -clir::stmt opencl_ast::operator()(rgn const &b) { +clir::stmt convert_to_opencl_pass::run_on_region(rgn ®) { declared_vars_.push_back({}); auto bb = clir::block_builder{}; - for (auto &s : b.insts()) { + for (auto &s : reg.insts()) { for (auto &cs : visit(*this, *s)) { bb.add(cs); } @@ -1077,9 +1086,23 @@ clir::stmt opencl_ast::operator()(rgn const &b) { } /* Function nodes */ -clir::func opencl_ast::operator()(prototype const &p) { - auto fb = clir::kernel_builder(std::string(p.name())); - for (auto const &v : p.args()) { +auto convert_to_opencl_pass::run_on_function(function &fn) -> clir::func { + stack_high_water_mark_ = 0; + auto const subgroup_size = fn.subgroup_size(); + try { + core_cfg_ = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + auto const work_group_size = fn.work_group_size(); + tiling_[0] = work_group_size[0] / subgroup_size; + tiling_[1] = work_group_size[1]; + + stack_ = clir::var("stack"); + + // Create prototype + auto fb = clir::kernel_builder(std::string(fn.name())); + for (auto const &v : fn.args()) { fb.argument(visit(*this, *v->ty()), declare(*v)); auto dv = visit( overloaded{[&fb, &v](memref_data_type const &) -> std::optional { @@ -1099,26 +1122,11 @@ clir::func opencl_ast::operator()(prototype const &p) { } } - auto const wgs = tiling_.work_group_size(core_cfg_.subgroup_size); - fb.attribute(clir::reqd_work_group_size(wgs[0], wgs[1], 1)); - fb.attribute(clir::intel_reqd_sub_group_size(core_cfg_.subgroup_size)); - return fb.get_product(); -} + fb.attribute(clir::reqd_work_group_size(work_group_size[0], work_group_size[1], 1)); + fb.attribute(clir::intel_reqd_sub_group_size(subgroup_size)); -clir::func opencl_ast::operator()(function const &fn) { - auto const subgroup_size = fn.subgroup_size(); - try { - core_cfg_ = info_->get_core_config(subgroup_size); - } catch (std::out_of_range const &e) { - throw compilation_error(fn.loc(), status::unsupported_subgroup_size); - } - auto const work_group_size = fn.work_group_size(); - tiling_[0] = work_group_size[0] / subgroup_size; - tiling_[1] = work_group_size[1]; + auto body = run_on_region(*fn.body()); - stack_ = clir::var("stack"); - auto proto = visit(*this, *fn.prototype()); - auto body = visit(*this, *fn.body()); if (stack_high_water_mark_ > 0) { auto bb = dynamic_cast(body.get()); if (bb == nullptr) { @@ -1131,26 +1139,19 @@ clir::func opencl_ast::operator()(function const &fn) { stack_high_water_mark_), stack_, {clir::aligned(size(scalar_type::f64) * 8)})); } - return clir::function(std::move(proto), std::move(body)); + return clir::function(fb.get_product(), std::move(body)); } /* Program nodes */ -clir::prog opencl_ast::operator()(program const &p) { - struct name_visitor { - auto operator()(function const &f) -> std::string_view { - return visit(*this, *f.prototype()); - } - auto operator()(prototype const &p) -> std::string_view { return p.name(); } - }; +auto convert_to_opencl_pass::run_on_program(program &p) -> clir::prog { reserved_names_.clear(); for (auto const &fn : p.functions()) { - reserved_names_.insert(std::string(visit(name_visitor{}, *fn))); + reserved_names_.insert(std::string(fn->name())); } prog_builder_ = clir::program_builder{}; for (auto const &fn : p.functions()) { - stack_high_water_mark_ = 0; - prog_builder_.add(visit(*this, *fn)); + prog_builder_.add(run_on_function(*fn)); } return prog_builder_.get_product(); } diff --git a/src/pass/opencl_ast.hpp b/src/pass/convert_to_opencl.hpp similarity index 92% rename from src/pass/opencl_ast.hpp rename to src/pass/convert_to_opencl.hpp index 94f54271..69cd84fd 100644 --- a/src/pass/opencl_ast.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -1,8 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#ifndef OPENCL_AST_20230309_HPP -#define OPENCL_AST_20230309_HPP +#ifndef CONVERT_TO_OPENCL_20240913_HPP +#define CONVERT_TO_OPENCL_20240913_HPP #include "device_info.hpp" #include "node/data_type_node.hpp" @@ -55,9 +55,9 @@ class dope_vector { clir::expr offset_ = clir::expr(std::int64_t(0)); }; -class opencl_ast { +class convert_to_opencl_pass { public: - opencl_ast(::tinytc_core_info const *info); + convert_to_opencl_pass(::tinytc_core_info const *info); /* Data type nodes */ clir::data_type operator()(void_data_type const &); @@ -102,17 +102,12 @@ class opencl_ast { std::vector operator()(sum_inst const &s); std::vector operator()(yield_inst const &in); - /* Region nodes */ - clir::stmt operator()(rgn const &b); - - /* Func nodes */ - clir::func operator()(prototype const &p); - clir::func operator()(function const &fn); - - /* Program nodes */ - clir::prog operator()(program const &p); + auto run_on_program(program &p) -> clir::prog; private: + auto run_on_region(rgn ®) -> clir::stmt; + auto run_on_function(function &fn) -> clir::func; + auto get_dope_vector(value_node *v) -> dope_vector &; void set_dope_vector(value_node *v, dope_vector dv); clir::var declare(value_node const &v); @@ -134,4 +129,4 @@ class opencl_ast { } // namespace tinytc -#endif // OPENCL_AST_20230309_HPP +#endif // CONVERT_TO_OPENCL_20240913_HPP diff --git a/src/passes.cpp b/src/passes.cpp deleted file mode 100644 index d9c3cd53..00000000 --- a/src/passes.cpp +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "passes.hpp" -#include "device_info.hpp" -#include "kernel_metadata.hpp" -#include "node/data_type_node.hpp" -// #include "node/function_node.hpp" -// #include "node/program_node.hpp" -#include "pass/check_ir.hpp" -// #include "pass/constant_propagation.hpp" -#include "pass/dump_ir.hpp" -#include "pass/equal.hpp" -// #include "pass/insert_barrier.hpp" -// #include "pass/lifetime_analysis.hpp" -// #include "pass/lower_linalg.hpp" -// #include "pass/metadata.hpp" -// #include "pass/opencl_ast.hpp" -// #include "pass/stack.hpp" -// #include "pass/work_group_size.hpp" -#include "support/visit.hpp" - -namespace tinytc { - -// void check_ir(tinytc_prog const &p) { return visit(ir_checker{}, p); } - -// void dump_ir(std::ostream &os, tinytc_func const &f) { visit(ir_dumper{os}, f); } -void dump_ir(std::ostream &os, tinytc_prog const &p) { run_function_pass(dump_ir_pass{os}, p); } - -/*clir::prog generate_opencl_ast(tinytc_prog const &p, ::tinytc_core_info const &info) { - return visit(opencl_ast{&info}, p); -} - -auto get_metadata(tinytc_prog const &p) -> std::unordered_map { - auto v = metadata{}; - visit(v, p); - return v.get_result(); -} - -void insert_barriers(tinytc_func &f) { visit(insert_barrier{}, f); } -void insert_barriers(tinytc_prog &p) { visit(insert_barrier{}, p); } - -void insert_lifetime_stop_inst(tinytc_func &f) { visit(lifetime_inserter{}, f); } -void insert_lifetime_stop_inst(tinytc_prog &p) { visit(lifetime_inserter{}, p); }*/ - -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } - -/*void lower_linalg(tinytc_prog &p, ::tinytc_core_info const &info) { - visit(lower_linalg_pass{&info}, p); -} - -void propagate_constants(tinytc_prog &p) { visit(constant_propagation{}, p); } - -void set_stack_ptrs(tinytc_func &f) { visit(stack_ptr{}, f); } -void set_stack_ptrs(tinytc_prog &p) { visit(stack_ptr{}, p); } - -void set_work_group_size(tinytc_func &f, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, f); -} -void set_work_group_size(tinytc_prog &p, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, p); -}*/ - -} // namespace tinytc - diff --git a/src/passes.hpp b/src/passes.hpp index 004c0324..d8961911 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -4,50 +4,11 @@ #ifndef PASSES_20240314_HPP #define PASSES_20240314_HPP -#include "kernel_metadata.hpp" #include "node/program_node.hpp" #include "tinytc/types.h" -#include -#include -#include -#include - namespace tinytc { -//! Check whether some IR rules are respected -// void check_ir(tinytc_prog const &p); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_func const &f); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_prog const &p); -//! Generate OpenCL AST -/*clir::prog generate_opencl_ast(tinytc_prog const &p, tinytc_core_info const &info); -//! Get kernel metadata -auto get_metadata(tinytc_prog const &p) -> std::unordered_map; -//! Insert barriers where necessary -void insert_barriers(tinytc_func &f); -//! Insert barriers where necessary -void insert_barriers(tinytc_prog &p); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_func &f); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_prog &p);*/ -//! Check whether data types a and b are equal -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); -//! Implement linear algebra instructions -/*void lower_linalg(tinytc_prog &p, tinytc_core_info const &info); -//! Constant propagation -void propagate_constants(tinytc_prog &p); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_func &f); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_prog &p); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_func &f, tinytc_core_info const &info); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_prog &p, tinytc_core_info const &info);*/ - template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { for (auto const &func : p.functions()) { pass.run_on_function(*func); diff --git a/src/prog.cpp b/src/prog.cpp index 7ee2c714..3de2f4a2 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -4,6 +4,7 @@ #include "error.hpp" #include "location.hpp" #include "node/program_node.hpp" +#include "pass/dump_ir.hpp" #include "passes.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" @@ -62,7 +63,7 @@ tinytc_status_t tinytc_prog_dump(const_tinytc_prog_t prg) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { dump_ir(std::cerr, *prg); }); + return exception_to_status_code([&] { run_function_pass(dump_ir_pass{std::cerr}, *prg); }); } tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { @@ -74,7 +75,7 @@ tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *f if (!stream.good()) { throw status::file_io_error; } - dump_ir(stream, *prg); + run_function_pass(dump_ir_pass{stream}, *prg); }); } @@ -85,7 +86,7 @@ tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str) return exception_to_status_code([&] { auto const text = [&] { auto oss = std::ostringstream{}; - dump_ir(oss, *prg); + run_function_pass(dump_ir_pass{oss}, *prg); return std::move(oss).str(); }(); auto const length = text.size() + 1; // Need to include terminating null character diff --git a/test/codegen/dope_vector0.ir b/test/codegen/dope_vector0.ir index bb285ac4..67ed61c7 100644 --- a/test/codegen/dope_vector0.ir +++ b/test/codegen/dope_vector0.ir @@ -2,9 +2,17 @@ ; SPDX-License-Identifier: BSD-3-Clause ; RUN: %tinytc-oc < %s | filecheck %s -; CHECK: void kernel(global float* K0, long K0_shape0, long K0_shape1, long K0_stride1, long offset, long size) func @kernel(%K0: memref, %offset: index, %size: index) { - ; CHECK-NEXT: global float* x0 = K0 + 4ll * 1 + offset * K0_stride1; - ; CHECK-NEXT: long x0_shape0 = size; %0 = subview %K0[4:%size, %offset] : memref +; CHECK: void kernel({{.*}} +; CHECK-NEXT: global float* x0 = K0 + 4ll * 1 + offset * K0_stride1; +; CHECK-NEXT: long x0_shape0 = size; +} + +func @kernel2(%K0: memref, %offset: index, %size: index) { + %0 = subview %K0[%offset, 4:%size] : memref +; CHECK: void kernel2({{.*}} +; CHECK-NEXT: global float* x0 = K0 + offset * 1 + 4ll * K0_stride1; +; CHECK-NEXT: long x0_shape0 = size; +; CHECK-NEXT: long x0_stride0 = K0_stride1; } From 906bc0ce5a69117fa79f77cdd8b7e27a160aa62c Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 16 Sep 2024 19:19:44 +0200 Subject: [PATCH 017/297] Add address space attribute; fix insert barrier pass; need to add test for insert barrier. Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 21 ++- docs/api/builder_capi.yaml | 2 + docs/api/builder_cxxapi.rst | 21 ++- docs/api/builder_cxxapi.yaml | 2 + docs/api/core_capi.rst | 14 ++ docs/api/core_cxxapi.rst | 14 ++ docs/manual/tensor-ir.rst | 42 +++++- include/tinytc/tinytc.h | 4 + include/tinytc/tinytc.hpp | 42 ++++-- include/tinytc/types.h | 13 +- include/tinytc/types.hpp | 10 +- src/CMakeLists.txt | 3 +- src/codegen_tools.cpp | 24 ++-- src/compiler.cpp | 3 + src/data_type.cpp | 9 +- src/error.cpp | 6 + src/gemm_generator.cpp | 9 +- src/inst.cpp | 11 ++ src/node/data_type_node.cpp | 5 +- src/node/data_type_node.hpp | 17 +-- src/node/inst_node.cpp | 4 +- src/node/inst_node.hpp | 14 +- src/parser/lexer.re | 5 + src/parser/parser_impl.yy | 50 ++++++- src/pass/convert_to_opencl.cpp | 57 +++++--- src/pass/dump_ir.cpp | 13 +- src/pass/insert_barrier.cpp | 222 +++++++++++++++--------------- src/pass/insert_barrier.hpp | 64 +++------ src/pass/insert_lifetime_stop.cpp | 4 +- src/pass/metadata.cpp | 30 ---- src/pass/metadata.hpp | 34 ----- src/passes.def | 1 + src/recipe/small_gemm_batched.cpp | 17 ++- src/recipe/tall_and_skinny.cpp | 12 +- src/scalar_type.cpp | 10 ++ src/scalar_type.hpp | 1 + src/support/walk.hpp | 24 ++-- test/opt/insert-lifetime-stop.ir | 18 +-- test/opt/work-group-size.ir | 15 +- 39 files changed, 506 insertions(+), 361 deletions(-) delete mode 100644 src/pass/metadata.cpp delete mode 100644 src/pass/metadata.hpp diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 07205980..1a9fbe17 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -10,6 +10,8 @@ Common * Enumerations + * :ref:`tinytc_address_space_t` + * :ref:`tinytc_arithmetic_t` * :ref:`tinytc_arithmetic_unary_t` @@ -26,6 +28,8 @@ Common * Functions + * :ref:`tinytc_address_space_to_string` + * :ref:`tinytc_arithmetic_to_string` * :ref:`tinytc_arithmetic_unary_to_string` @@ -75,6 +79,11 @@ Common Common Enumerations ------------------- +tinytc_address_space_t +...................... + +.. doxygenenum:: tinytc_address_space_t + tinytc_arithmetic_t ................... @@ -111,6 +120,11 @@ TINYTC_DYNAMIC Common Functions ---------------- +tinytc_address_space_to_string +.............................. + +.. doxygenfunction:: tinytc_address_space_to_string + tinytc_arithmetic_to_string ........................... @@ -272,8 +286,6 @@ Function * :ref:`tinytc_function_create` - * :ref:`tinytc_function_prototype_create` - * :ref:`tinytc_function_set_subgroup_size` * :ref:`tinytc_function_set_work_group_size` @@ -290,11 +302,6 @@ tinytc_function_create .. doxygenfunction:: tinytc_function_create -tinytc_function_prototype_create -................................ - -.. doxygenfunction:: tinytc_function_prototype_create - tinytc_function_set_subgroup_size ................................. diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 71b07c27..a33a1619 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -3,6 +3,7 @@ Builder C-API: Common: enum: + - tinytc_address_space_t - tinytc_arithmetic_t - tinytc_arithmetic_unary_t - tinytc_cmp_condition_t @@ -11,6 +12,7 @@ Builder C-API: define: - TINYTC_DYNAMIC function: + - tinytc_address_space_to_string - tinytc_arithmetic_to_string - tinytc_arithmetic_unary_to_string - tinytc_cmp_condition_to_string diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 41f0cb28..d59b5992 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -10,6 +10,8 @@ Common * Enumerations + * :ref:`address_space` + * :ref:`arithmetic` * :ref:`arithmetic_unary` @@ -24,6 +26,8 @@ Common * :ref:`is_dynamic_value` + * :ref:`to_string(address_space)` + * :ref:`to_string(arithmetic)` * :ref:`to_string(arithmetic_unary)` @@ -53,6 +57,11 @@ Common Common Enumerations ------------------- +address_space +............. + +.. doxygenenum:: tinytc::address_space + arithmetic .......... @@ -86,6 +95,11 @@ is_dynamic_value .. doxygenfunction:: tinytc::is_dynamic_value +to_string(address_space) +........................ + +.. doxygenfunction:: tinytc::to_string(address_space) + to_string(arithmetic) ..................... @@ -217,8 +231,6 @@ Function * :ref:`make_function` - * :ref:`make_function_prototype` - * :ref:`set_work_group_size` * :ref:`set_subgroup_size` @@ -237,11 +249,6 @@ make_function .. doxygenfunction:: tinytc::make_function -make_function_prototype -....................... - -.. doxygenfunction:: tinytc::make_function_prototype - set_work_group_size ................... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 5cbe0308..1092232c 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -3,6 +3,7 @@ Builder C++-API: Common: enum: + - tinytc::address_space - tinytc::arithmetic - tinytc::arithmetic_unary - tinytc::cmp_condition @@ -10,6 +11,7 @@ Builder C++-API: - tinytc::transpose function: - tinytc::is_dynamic_value + - tinytc::to_string(address_space) - tinytc::to_string(arithmetic) - tinytc::to_string(arithmetic_unary) - tinytc::to_string(cmp_condition) diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 81c1901d..45052d7b 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -241,6 +241,10 @@ Compiler * Functions + * :ref:`tinytc_run_function_pass` + + * :ref:`tinytc_list_function_passes` + * :ref:`tinytc_prog_compile_to_opencl` Compiler Enumerations @@ -254,6 +258,16 @@ tinytc_bundle_format_t Compiler Functions ------------------ +tinytc_run_function_pass +........................ + +.. doxygenfunction:: tinytc_run_function_pass + +tinytc_list_function_passes +........................... + +.. doxygenfunction:: tinytc_list_function_passes + tinytc_prog_compile_to_opencl ............................. diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index ba68c04e..d1919835 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -116,11 +116,25 @@ Compiler * Functions + * :ref:`run_function_pass` + + * :ref:`list_function_passes` + * :ref:`compile_to_opencl` Compiler Functions ------------------ +run_function_pass +................. + +.. doxygenfunction:: tinytc::run_function_pass + +list_function_passes +.................... + +.. doxygenfunction:: tinytc::list_function_passes + compile_to_opencl ................. diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 5e7b7ee9..b35a29c8 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -151,9 +151,10 @@ Memref type .. code:: abnf - memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ">" + memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ["," address-space] ">" constant-or-dynamic = integer-constant / "?" tensor-shape = *("x" constant-or-dynamic) + address-space = "global" / "local" A memref is a reference to a region of memory. In analogy to the C/C++-language, the memref can be thought of as a pointer, @@ -180,6 +181,13 @@ E.g. the memory layout of ``memref`` is ``strided<1,5,30>``. We note that ``memref`` and ``memref>`` are the same type. +Memrefs have an optional address space attribute. +The global address space referse to memory objects allocated from the global memory pool +that is shared by all work groups. +The local memory space is shared by all work-items of the work-group but inaccessible to another work-group. +The default address space is "global", memrefs with "local" address space are returned by +the alloca instruction. + Memory layout ............. @@ -272,7 +280,8 @@ A memref of the memref-type. Restrictions ~~~~~~~~~~~~ -The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +- The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +- The address space must be "local". Axpby ..... @@ -590,6 +599,34 @@ Op Allowed type Description .not integer-type Bitwise not ==== ============ ============================================================================== +Barrier +....... + +.. code:: abnf + + barrier-instruction = "barrier" [".global"] [".local"] + +Overview +~~~~~~~~ + +**Note:** Barriers are inserted automatically in collective regions, but not in SPMD regions. +Manual barrier insertion should only be only necessesary in SPMD regions. + + +Control barrier. +The barrier must be encountered by all work-items. +A work-item in a work-group is not allowed to continue until all work-items in the work-group +have reached the barrier. + +Aditional memory fences are controlled by the following attributes: + +========= ====================================================================================== +Attribute Description +========= ====================================================================================== +.global Ensure that global memory accesses become visible to the work-group. +.local Ensure that local memory accesses become visible to the work-group. +========= ====================================================================================== + Cast .... @@ -1047,7 +1084,6 @@ Additional instructions .. code:: abnf - barrier-instruction = "barrier" lifetime-stop-instruction = "lifetime_stop" local-identifier SPMD instructions diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index f38b3b8b..bc0447c3 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -73,6 +73,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, * @param stride_size [in][optional] number of elements in stride array; must be either 0 for * automatic stride calculation or must match shape_size; must be 0 if stride == nullptr * @param stride [in][optional][range(0, stride_size)] stride array + * @param addrspace [in][optional] Address space; default is tinytc_address_space_global * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -81,6 +82,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, uint32_t shape_size, const int64_t *shape, uint32_t stride_size, const int64_t *stride, + const tinytc_address_space_t addrspace, const tinytc_location_t *loc); /** @@ -207,6 +209,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, cha /////// Instructions /////// //////////////////////////// +//! Convert address space to string +TINYTC_EXPORT char const *tinytc_address_space_to_string(tinytc_address_space_t as); //! Convert arithmetic operation type to string TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); //! Convert arithmetic operation type to string (unary) diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 6bc03307..816a1d58 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -330,18 +330,21 @@ inline data_type make_scalar(scalar_type scalar_ty, location const &loc = {}) { * @param scalar_ty Element type * @param shape Tensor shape * @param stride Tensor stride + * @param addrspace Address space * @param loc Source code location * * @return Data type */ inline data_type make_memref(scalar_type scalar_ty, std::vector const &shape, std::vector const &stride = {}, + const address_space addrspace = address_space::global, location const &loc = {}) { tinytc_data_type_t mt; - CHECK_STATUS_LOC(tinytc_memref_type_create(&mt, static_cast(scalar_ty), - shape.size(), shape.data(), stride.size(), - stride.data(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_memref_type_create(&mt, static_cast(scalar_ty), shape.size(), + shape.data(), stride.size(), stride.data(), + static_cast(addrspace), &loc), + loc); return data_type{mt}; } @@ -461,8 +464,8 @@ inline auto make_imm(float imm, location const &loc = {}) -> value { * * @return Value */ -inline auto make_imm(double imm, scalar_type type = scalar_type::f64, location const &loc = {}) - -> value { +inline auto make_imm(double imm, scalar_type type = scalar_type::f64, + location const &loc = {}) -> value { tinytc_value_t val; CHECK_STATUS_LOC( tinytc_float_imm_create(&val, imm, static_cast(type), &loc), loc); @@ -579,6 +582,17 @@ inline auto make_dynamic(location const &loc = {}) -> value { /////////// Inst /////////// //////////////////////////// +/** + * @brief Convert address space to string + * + * @param as Address space + * + * @return C-string + */ +inline char const *to_string(address_space as) { + return ::tinytc_address_space_to_string(static_cast<::tinytc_address_space_t>(as)); +} + /** * @brief Convert arithmetic operation type to string * @@ -1437,8 +1451,8 @@ class region_builder { * * @return Values returned by instruction */ - [[maybe_unused]] inline auto add_multivalued(inst i, std::string const &name = "") - -> std::vector { + [[maybe_unused]] inline auto + add_multivalued(inst i, std::string const &name = "") -> std::vector { auto results = i.get_values(); if (name.size() > 0) { int counter = 0; @@ -1553,8 +1567,8 @@ class region_builder { */ template auto ifelse(value const &condition, F &&then, G &&otherwise, - std::vector const &return_type_list = {}, location const &loc = {}) - -> std::vector { + std::vector const &return_type_list = {}, + location const &loc = {}) -> std::vector { auto bb1 = region_builder{}; then(bb1); auto bb2 = region_builder{}; @@ -1792,8 +1806,8 @@ inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_ * @return Core info */ inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, - std::int32_t num_threads_per_eu, std::vector sgs) - -> core_info { + std::int32_t num_threads_per_eu, + std::vector sgs) -> core_info { tinytc_core_info_t info; CHECK_STATUS(tinytc_core_info_intel_create(&info, ip_version, num_eus_per_subslice, num_threads_per_eu, sgs.size(), sgs.data())); @@ -2331,8 +2345,8 @@ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type ty, std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, - std::int32_t M_block_size = 0, source_context ctx = {}) - -> tall_and_skinny { + std::int32_t M_block_size = 0, + source_context ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 8d0170bf..ff1ff8b8 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -40,6 +40,7 @@ typedef enum { tinytc_status_invalid_kernel_arguments = 0xd, ///< Kernel got invalid arguments tinytc_status_unsupported_device = 0xe, ///< Unsupported device tinytc_status_invalid_core_info = 0xf, ///< Invalid core info object + tinytc_status_unknown_pass_name = 0x10, ///< Invalid compiler pass name // IR errors tinytc_status_ir_out_of_bounds = 0x100, ///< Out of bounds access tinytc_status_ir_invalid_shape = 0x101, ///< Invalid tensor shape @@ -60,7 +61,9 @@ typedef enum { tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape tinytc_status_ir_collective_called_from_spmd = 0x111, ///< Collective instruction from SPMD tinytc_status_ir_fp_unsupported = 0x112, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x113, ///< SPMD instruction from collective + tinytc_status_ir_spmd_called_from_collective = 0x113, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST @@ -262,6 +265,12 @@ typedef enum { tinytc_transpose_T = 1 ///< Transpose } tinytc_transpose_t; +//! Address space +typedef enum { + tinytc_address_space_global = 0x1, ///< Global memory + tinytc_address_space_local = 0x2 ///< Local memory, returned by alloca +} tinytc_address_space_t; + //! Core features that may be optionally enabled typedef enum { /** @@ -277,6 +286,8 @@ typedef enum { //! Type for combination of core feature flags typedef uint32_t tinytc_core_feature_flags_t; +//! + /** * @brief IP versions for Intel GPUs * diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 6465805a..cfcf5b04 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -50,6 +50,7 @@ enum class status { invalid_kernel_arguments = tinytc_status_invalid_kernel_arguments, unsupported_device = tinytc_status_unsupported_device, invalid_core_info = tinytc_status_invalid_core_info, + unknown_pass_name = tinytc_status_unknown_pass_name, // IR errors ir_out_of_bounds = tinytc_status_ir_out_of_bounds, ir_invalid_shape = tinytc_status_ir_invalid_shape, @@ -71,7 +72,8 @@ enum class status { ir_collective_called_from_spmd = tinytc_status_ir_collective_called_from_spmd, ir_fp_unsupported = tinytc_status_ir_fp_unsupported, ir_spmd_called_from_collective = tinytc_status_ir_spmd_called_from_collective, - // Level Zero errors + ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, + ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, @@ -246,6 +248,12 @@ enum class transpose { T = tinytc_transpose_T ///< transpose }; +//! Address space +enum class address_space { + global = tinytc_address_space_global, ///< Global memory + local = tinytc_address_space_local ///< Local memory, returned by alloca +}; + //! @brief Cf. @ref tinytc_core_feature_flag_t enum class core_feature_flag { large_register_file = tinytc_core_feature_flag_large_register_file }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a90ba8c5..025e7b87 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -38,10 +38,9 @@ set(SOURCES #pass/constant_propagation.cpp pass/convert_to_opencl.cpp pass/dump_ir.cpp - #pass/insert_barrier.cpp + pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp #pass/lower_linalg.cpp - #pass/metadata.cpp pass/slot_tracker.cpp pass/stack.cpp pass/work_group_size.cpp diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index b79b130d..3c89c7d0 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -89,7 +89,7 @@ auto get_block_rw_config(scalar_type ty) { return block_rw_config{builtin_type::void_t, nullptr, nullptr, nullptr}; } -expr sub_group_block_read_helper(expr pointer, scalar_type ty, address_space as) { +expr sub_group_block_read_helper(expr pointer, scalar_type ty, clir::address_space as) { const auto cfg = get_block_rw_config(ty); if (cfg.sub_group_block_read == nullptr) { return pointer[get_sub_group_local_id()]; @@ -98,7 +98,7 @@ expr sub_group_block_read_helper(expr pointer, scalar_type ty, address_space as) auto inst = (*cfg.sub_group_block_read)(std::move(pointer)); return (*cfg.as_type)(std::move(inst)); } -expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, address_space as) { +expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, clir::address_space as) { const auto cfg = get_block_rw_config(ty); if (cfg.sub_group_block_write == nullptr) { return pointer[get_sub_group_local_id()] = std::move(data); @@ -108,8 +108,8 @@ expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, addre return (*cfg.sub_group_block_write)(std::move(pointer), std::move(data)); } -void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, - expr value, scalar_type beta_ty, expr beta) { +void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, + clir::address_space as, expr value, scalar_type beta_ty, expr beta) { if (is_atomic) { atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), beta_ty, std::move(beta)); } else { @@ -118,8 +118,8 @@ void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, a } } -void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, address_space as, expr value, - scalar_type beta_ty, expr beta) { +void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, clir::address_space as, + expr value, scalar_type beta_ty, expr beta) { int mode = -1; visit(overloaded{ [&](clir::internal::int_imm &c) { @@ -331,8 +331,8 @@ expr matrix_block_description::condition(int m_block, std::int32_t subgroup_size } auto read_matrix_block_regular(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, char const *block_name) - -> std::unique_ptr { + core_config const &core_cfg, + char const *block_name) -> std::unique_ptr { assert(M_mode == 0 || M_mode == 1); const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; @@ -365,8 +365,8 @@ auto read_matrix_block_regular(block_builder &bb, matrix_block_description const } auto read_matrix_block_vector(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, char const *block_name) - -> std::unique_ptr { + core_config const &core_cfg, + char const *block_name) -> std::unique_ptr { assert(M_mode == 0 || M_mode == 1); const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; @@ -393,8 +393,8 @@ auto read_matrix_block_vector(block_builder &bb, matrix_block_description const } auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, char const *block_name) - -> std::unique_ptr { + core_config const &core_cfg, + char const *block_name) -> std::unique_ptr { assert(M_mode == 0 || M_mode == 1); if (d.is_unit_stride(1 - M_mode) && !is_complex_type(d.ty) && diff --git a/src/compiler.cpp b/src/compiler.cpp index 84047fa0..08d4a496 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -8,6 +8,7 @@ #include "pass/check_ir.hpp" #include "pass/convert_to_opencl.hpp" #include "pass/dump_ir.hpp" +#include "pass/insert_barrier.hpp" #include "pass/insert_lifetime_stop.hpp" #include "pass/stack.hpp" #include "pass/work_group_size.hpp" @@ -17,6 +18,7 @@ #include "source.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" +#include "tinytc/types.hpp" #include #include @@ -52,6 +54,7 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr #include "passes.def" #undef FUNCTION_PASS #undef FUNCTION_PASS_WITH_INFO + throw status::unknown_pass_name; }, ctx); } diff --git a/src/data_type.cpp b/src/data_type.cpp index 96e0a281..82b55d74 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -4,11 +4,11 @@ #include "error.hpp" #include "location.hpp" #include "node/data_type_node.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "support/util.hpp" #include #include @@ -33,6 +33,7 @@ tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, tinytc_scalar_ tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, uint32_t shape_size, const int64_t *shape, uint32_t stride_size, const int64_t *stride, + const tinytc_address_space_t addrspace, const tinytc_location_t *loc) { if (dt == nullptr) { return tinytc_status_invalid_arguments; @@ -44,9 +45,9 @@ tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_ if (stride_size > 0) { stride_vec.insert(stride_vec.end(), stride, stride + stride_size); } - *dt = std::make_unique(enum_cast(scalar_ty), - std::move(shape_vec), std::move(stride_vec), - get_optional(loc)) + *dt = std::make_unique( + enum_cast(scalar_ty), std::move(shape_vec), std::move(stride_vec), + enum_cast(addrspace), get_optional(loc)) .release(); }); } diff --git a/src/error.cpp b/src/error.cpp index ab6108c5..93f43f03 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -115,6 +115,8 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_invalid_core_info: return "Invalid core info object (e.g. max work group size is 0 or subgroup sizes vector " "is empty)"; + case tinytc_status_unknown_pass_name: + return "Unknown compiler pass name"; // IR case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; @@ -156,6 +158,10 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Floating point type unsupported for instruction"; case tinytc_status_ir_spmd_called_from_collective: return "SPMD instruction must not be called from collective region"; + case tinytc_status_ir_expected_local_address_space: + return "A memref with local address space is expected"; + case tinytc_status_ir_expected_global_address_space: + return "A memref with global address space is expected"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index aa25060b..7e96ee66 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -126,7 +126,8 @@ auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int3 class generator { public: generator(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, address_space As, address_space Bs, address_space Cs) + core_config const &core_cfg, clir::address_space As, clir::address_space Bs, + clir::address_space Cs) : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), Cspace(Cs) {} bool use_double_buffering() const; @@ -142,7 +143,7 @@ class generator { gemm_configuration const gemm_cfg; local_tiling const tiling; core_config const core_cfg; - address_space Aspace, Bspace, Cspace; + clir::address_space Aspace, Bspace, Cspace; int row_blocks_in_register = 1; int cols_in_register = 1; var c_acc, c_acc_im, m; @@ -425,8 +426,8 @@ ::clir::func generator::function(std::string_view name) { } ::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, address_space As, - address_space Bs, address_space Cs) { + core_config const &core_cfg, std::string_view name, + clir::address_space As, clir::address_space Bs, clir::address_space Cs) { return generator{gemm_cfg, tiling, core_cfg, As, Bs, Cs}.function(name); } diff --git a/src/inst.cpp b/src/inst.cpp index 8cb37e8c..33e2c27a 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -21,6 +21,17 @@ using namespace tinytc; extern "C" { + +char const *tinytc_address_space_to_string(tinytc_address_space_t as) { + switch (as) { + case tinytc_address_space_global: + return "global"; + case tinytc_address_space_local: + return "local"; + } + return "unknown"; +} + char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op) { switch (op) { case tinytc_arithmetic_add: diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index fb7ea76a..483c057b 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -11,9 +11,10 @@ namespace tinytc { memref_data_type::memref_data_type(scalar_type type, std::vector shape, - std::vector stride, location const &lc) + std::vector stride, address_space addrspace, + location const &lc) : data_type_node(DTK::memref), element_ty_(std::move(type)), shape_(std::move(shape)), - stride_(std::move(stride)) { + stride_(std::move(stride)), addrspace_(addrspace) { loc(lc); for (auto const &s : shape_) { if (s < 0 && !is_dynamic_value(s)) { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 74d7d83e..03afa140 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -10,9 +10,6 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include -#include - #include #include #include @@ -63,13 +60,10 @@ class memref_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } memref_data_type(scalar_type type, std::vector shape, - std::vector stride = {}, location const &lc = {}); + std::vector stride = {}, + address_space addrspace = address_space::global, location const &lc = {}); inline scalar_type element_ty() const { return element_ty_; } - inline clir::data_type clir_element_ty() const { return to_clir_ty(element_ty_, addrspace_); } - inline clir::data_type clir_atomic_element_ty() const { - return to_clir_atomic_ty(element_ty_, addrspace_); - } inline std::int64_t dim() const { return shape_.size(); } inline auto const &shape() const { return shape_; } inline std::int64_t shape(std::int64_t i) const { return shape_[i]; } @@ -78,8 +72,8 @@ class memref_data_type : public data_type_node { inline std::int64_t size_in_bytes() const { return is_dynamic() ? dynamic : size(element_ty_) * stride_.back() * shape_.back(); } - inline clir::address_space addrspace() const { return addrspace_; } - inline void addrspace(clir::address_space space) { addrspace_ = space; } + inline auto addrspace() const -> address_space { return addrspace_; } + inline void addrspace(address_space space) { addrspace_ = space; } inline bool is_dynamic_shape() const { return std::any_of(shape_.begin(), shape_.end(), is_dynamic_value); @@ -95,7 +89,7 @@ class memref_data_type : public data_type_node { scalar_type element_ty_; std::vector shape_, stride_; - clir::address_space addrspace_ = clir::address_space::global_t; + address_space addrspace_ = address_space::global; }; class scalar_data_type : public data_type_node { @@ -107,7 +101,6 @@ class scalar_data_type : public data_type_node { } inline scalar_type ty() const { return ty_; } - inline clir::data_type clir_ty() const { return to_clir_ty(ty_); } private: scalar_type ty_; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 48e6aa2d..9c97e1d2 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -85,7 +85,9 @@ alloca_inst::alloca_inst(data_type ty, location const &lc) if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } - memref->addrspace(clir::address_space::local_t); + if (memref->addrspace() != address_space::local) { + throw compilation_error(loc(), status::ir_expected_local_address_space); + } } axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index b4998981..c489a2ad 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -387,7 +387,19 @@ class arith_unary_inst : public standard_inst<1, 1> { class barrier_inst : public standard_inst<0, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::barrier; } - inline barrier_inst() : standard_inst{IK::barrier} {} + inline barrier_inst(std::int32_t fence_flags, location const &lc = {}) + : standard_inst{IK::barrier}, fence_flags_(fence_flags) { + loc(lc); + } + + inline auto fence_flags() const -> std::int32_t { return fence_flags_; } + inline auto fence_flags(std::int32_t fence_flags) { fence_flags_ = fence_flags; } + inline auto has_fence(address_space as) const { + return (fence_flags_ & static_cast(as)) > 0; + } + + private: + std::int32_t fence_flags_; }; class cast_inst : public standard_inst<1, 1> { diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 66187da1..76037996 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -88,6 +88,10 @@ lex: ".n" { adv_loc(); return parser::make_NOTRANS(loc_); } ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } + "local" { adv_loc(); return parser::make_LOCAL(loc_); } + "global" { adv_loc(); return parser::make_GLOBAL(loc_); } + ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } + ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } // constants "true" { adv_loc(); return parser::make_INTEGER_CONSTANT(1, loc_); } @@ -124,6 +128,7 @@ lex: // instructions "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } "arith" { adv_loc(); return parser::make_ARITH(loc_); } + "barrier" { adv_loc(); return parser::make_BARRIER(loc_); } "gemm" { adv_loc(); return parser::make_GEMM(loc_); } "gemv" { adv_loc(); return parser::make_GEMV(loc_); } "ger" { adv_loc(); return parser::make_GER(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 6e93fa8e..c3f5873a 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -103,12 +103,17 @@ NOTRANS ".n" TRANS ".t" ATOMIC ".atomic" + LOCAL "local" + GLOBAL "global" + LOCAL_ATTR ".local" + GLOBAL_ATTR ".global" MEMREF "memref" GROUP "group" OFFSET "offset" STRIDED "strided" AXPBY "axpby" ARITH "arith" + BARRIER "barrier" GEMM "gemm" GEMV "gemv" GER "ger" @@ -155,6 +160,7 @@ %nterm data_type %nterm scalar_type %nterm memref_type +%nterm optional_address_space %nterm > mode_list %nterm > optional_stride_list %nterm > stride_list @@ -171,6 +177,9 @@ %nterm <::tinytc::value> identifier_or_constant %nterm > optional_identifier_or_constant_list %nterm > identifier_or_constant_list +%nterm barrier_inst +%nterm optional_global_attr +%nterm optional_local_attr %nterm gemm_inst %nterm gemv_inst %nterm ger_inst @@ -302,11 +311,12 @@ scalar_type: ; memref_type: - MEMREF LCHEV scalar_type mode_list RCHEV { + MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { try { $$ = data_type { std::make_unique($scalar_type, std::move($mode_list), - std::vector{}, @memref_type) + std::vector{}, $optional_address_space, + @memref_type) .release() }; } catch (compilation_error const &e) { @@ -314,7 +324,7 @@ memref_type: YYERROR; } } - | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV RCHEV { + | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV optional_address_space RCHEV { if ($mode_list.size() != $optional_stride_list.size()) { auto loc = @scalar_type; loc.end = @optional_stride_list.end; @@ -323,7 +333,8 @@ memref_type: try { $$ = data_type { std::make_unique($scalar_type, std::move($mode_list), - std::move($optional_stride_list), @memref_type) + std::move($optional_stride_list), + $optional_address_space, @memref_type) .release() }; } catch (compilation_error const &e) { @@ -338,6 +349,12 @@ mode_list: | mode_list TIMES constant_or_dynamic { $$ = std::move($1); $$.push_back($constant_or_dynamic); } ; +optional_address_space: + %empty { $$ = address_space::global; } + | COMMA GLOBAL { $$ = address_space::global; } + | COMMA LOCAL { $$ = address_space::local; } +; + optional_stride_list: %empty {} | stride_list { $$ = std::move($1); } @@ -393,6 +410,7 @@ instructions: instruction: axpby_inst + | barrier_inst | gemm_inst | gemv_inst | ger_inst @@ -451,6 +469,30 @@ identifier_or_constant_list: } ; +barrier_inst: + BARRIER optional_global_attr optional_local_attr { + int32_t fence_flags = 0; + fence_flags |= $optional_global_attr; + fence_flags |= $optional_local_attr; + try { + $$ = inst { std::make_unique(fence_flags, @barrier_inst).release() }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + +optional_global_attr: + %empty { $$ = 0; } + | GLOBAL_ATTR { $$ = tinytc_address_space_global; } +; + +optional_local_attr: + %empty { $$ = 0; } + | LOCAL_ATTR { $$ = tinytc_address_space_local; } +; + gemm_inst: GEMM transpose[ta] transpose[tb] atomic identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index ab071c39..4586a253 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -181,10 +181,10 @@ clir::data_type convert_to_opencl_pass::operator()(group_data_type const &g) { return ptr_ty; } clir::data_type convert_to_opencl_pass::operator()(memref_data_type const &d) { - return clir::pointer_to(d.clir_element_ty()); + return clir::pointer_to(to_clir_ty(d.element_ty(), to_clir_address_space(d.addrspace()))); } clir::data_type convert_to_opencl_pass::operator()(scalar_data_type const &s) { - return s.clir_ty(); + return to_clir_ty(s.ty()); } /* Value nodes */ @@ -219,7 +219,7 @@ std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } - auto ptr_ty = clir::pointer_to(t->clir_element_ty()); + auto ptr_ty = operator()(*t); auto result = declaration_assignment(ptr_ty, std::move(result_var), clir::cast(ptr_ty, stack_ + a.stack_ptr())); stack_high_water_mark_ = std::max(stack_high_water_mark_, @@ -255,8 +255,9 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins auto a = Ab[(block + m) * adv.stride(pA)]; auto b = bb.declare_assign((*this)(*bt), "b", Bb + (block + m) * bdv.stride(0)); const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, std::move(a)); - store_helper(bb, inst.atomic(), b, bt->element_ty(), bt->addrspace(), - std::move(a_scaled), beta_ty, beta); + store_helper(bb, inst.atomic(), b, bt->element_ty(), + to_clir_address_space(bt->addrspace()), std::move(a_scaled), + beta_ty, beta); }; if (is_remainder) { bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) @@ -273,8 +274,8 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins if (bt->dim() == 0) { auto bb = clir::block_builder{}; const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, A[0]); - store_helper(bb, inst.atomic(), B, bt->element_ty(), bt->addrspace(), std::move(a_scaled), - beta_ty, std::move(beta)); + store_helper(bb, inst.atomic(), B, bt->element_ty(), to_clir_address_space(bt->addrspace()), + std::move(a_scaled), beta_ty, std::move(beta)); return {bb.get_product()}; } @@ -311,9 +312,16 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins throw compilation_error(inst.loc(), status::ir_expected_vector_or_matrix); } -std::vector convert_to_opencl_pass::operator()(barrier_inst const &) { - return {clir::expression_statement(clir::call_builtin( - clir::builtin_function::barrier, {clir::cl_mem_fence_flags::CLK_LOCAL_MEM_FENCE}))}; +std::vector convert_to_opencl_pass::operator()(barrier_inst const &b) { + clir::expr fence = 0; + if (b.has_fence(address_space::global)) { + fence = fence | clir::cl_mem_fence_flags::CLK_GLOBAL_MEM_FENCE; + } + if (b.has_fence(address_space::local)) { + fence = fence | clir::cl_mem_fence_flags::CLK_LOCAL_MEM_FENCE; + } + return {clir::expression_statement( + clir::call_builtin(clir::builtin_function::barrier, {std::move(fence)}))}; } std::vector convert_to_opencl_pass::operator()(arith_inst const &a) { @@ -630,8 +638,9 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { name = cfg.identifier("gemm" + std::to_string(++name_counter)); } if (has_gemm_.find(name) == has_gemm_.end()) { - auto f = generate_gemm(cfg, tiling_, core_cfg_, name, a->addrspace(), b->addrspace(), - c->addrspace()); + auto f = generate_gemm(cfg, tiling_, core_cfg_, name, to_clir_address_space(a->addrspace()), + to_clir_address_space(b->addrspace()), + to_clir_address_space(c->addrspace())); prog_builder_.add(std::move(f)); } has_gemm_.emplace(name); @@ -683,8 +692,9 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { name = cfg.identifier("gemv" + std::to_string(++name_counter)); } if (has_gemm_.find(name) == has_gemm_.end()) { - auto f = generate_gemm(cfg, tiling_, core_cfg_, name, a->addrspace(), b->addrspace(), - c->addrspace()); + auto f = generate_gemm(cfg, tiling_, core_cfg_, name, to_clir_address_space(a->addrspace()), + to_clir_address_space(b->addrspace()), + to_clir_address_space(c->addrspace())); prog_builder_.add(std::move(f)); } has_gemm_.emplace(name); @@ -745,8 +755,8 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), alpha, std::move(ab)); store_helper(bb, g.atomic(), c, ct->element_ty(), - ct->addrspace(), std::move(ab_scaled), beta_ty, - beta); + to_clir_address_space(ct->addrspace()), + std::move(ab_scaled), beta_ty, beta); }; if (is_remainder) { bb.add(clir::if_selection_builder( @@ -830,8 +840,9 @@ std::vector convert_to_opencl_pass::operator()(hadamard_inst const & to_clir_ty(ct->element_ty()), "ab", multiply(at->element_ty(), bt->element_ty(), std::move(a), b)); const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), alpha, std::move(ab)); - store_helper(bb, g.atomic(), c, ct->element_ty(), ct->addrspace(), - std::move(ab_scaled), beta_ty, beta); + store_helper(bb, g.atomic(), c, ct->element_ty(), + to_clir_address_space(ct->addrspace()), std::move(ab_scaled), beta_ty, + beta); }; if (is_remainder) { bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) @@ -1016,8 +1027,9 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) clir::get_sub_group_local_id() == 0) .then([&](clir::block_builder &bb) { const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, sum); - store_helper(bb, inst.atomic(), B, bt->element_ty(), bt->addrspace(), - std::move(sum_scaled), beta_ty, beta); + store_helper(bb, inst.atomic(), B, bt->element_ty(), + to_clir_address_space(bt->addrspace()), std::move(sum_scaled), + beta_ty, beta); }) .get_product()); } else if (bt->dim() == 1) { @@ -1040,8 +1052,9 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) .get_product()); auto b = bb.declare_assign((*this)(*bt), "b", B + (block + m) * bdv.stride(0)); const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, acc); - store_helper(bb, inst.atomic(), b, bt->element_ty(), bt->addrspace(), - std::move(sum_scaled), beta_ty, beta); + store_helper(bb, inst.atomic(), b, bt->element_ty(), + to_clir_address_space(bt->addrspace()), std::move(sum_scaled), + beta_ty, beta); }; if (is_remainder) { bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 4cd250ad..b85ba967 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -38,6 +38,9 @@ void dump_ir_pass::operator()(memref_data_type const &d) { do_with_infix(d.stride().begin(), d.stride().end(), [&](auto const &a) { val(a); }); *os_ << ">"; } + if (d.addrspace() != address_space::global) { + *os_ << "," << to_string(d.addrspace()); + } *os_ << ">"; } void dump_ir_pass::operator()(scalar_data_type const &s) { *os_ << to_string(s.ty()); } @@ -134,7 +137,15 @@ void dump_ir_pass::operator()(arith_unary_inst const &a) { visit(*this, *a.a()->ty()); } -void dump_ir_pass::operator()(barrier_inst const &) { *os_ << "barrier"; } +void dump_ir_pass::operator()(barrier_inst const &b) { + *os_ << "barrier"; + if (b.has_fence(address_space::global)) { + *os_ << ".global"; + } + if (b.has_fence(address_space::local)) { + *os_ << ".local"; + } +} void dump_ir_pass::operator()(cast_inst const &c) { visit(*this, *c.result()); diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 8fa2bab8..f210ebdc 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -2,7 +2,10 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/insert_barrier.hpp" -#include "pass/alias_analysis.hpp" +#include "analysis/alias.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" @@ -15,142 +18,133 @@ namespace tinytc { -/* Data type nodes */ -bool insert_barrier::operator()(void_data_type &) { return false; } -bool insert_barrier::operator()(group_data_type &b) { return visit(*this, *b.ty()); } -bool insert_barrier::operator()(memref_data_type &m) { - return m.addrspace() == clir::address_space::local_t; -} -bool insert_barrier::operator()(scalar_data_type &) { return false; } - -/* Value nodes */ -value_node *insert_barrier::operator()(float_imm &) { return nullptr; } -value_node *insert_barrier::operator()(int_imm &) { return nullptr; } -value_node *insert_barrier::operator()(val &v) { - if (visit(*this, *v.ty())) { - return &v; +auto intersects(std::unordered_set<::tinytc_value const *> const &a, + std::unordered_set<::tinytc_value const *> const &b, aa_results const &aa) { + for (auto &av : a) { + for (auto &bv : b) { + if (aa.alias(*av, *bv)) { + return true; + } + } } - return nullptr; + return false; } -/* Inst nodes */ -std::unordered_set insert_barrier::operator()(inst_node &) { return {}; } - -std::unordered_set insert_barrier::operator()(blas_a2_inst &g) { - auto rw = std::unordered_set{}; - rw.emplace(visit(*this, *g.A())); - rw.emplace(visit(*this, *g.B())); - return rw; +void insert_barrier_pass::reads_writes::clear(address_space as) { + const auto space = address_space_to_index(as); + reads[space].clear(); + writes[space].clear(); } -std::unordered_set insert_barrier::operator()(blas_a3_inst &inst) { - auto rw = std::unordered_set{}; - rw.emplace(visit(*this, *inst.A())); - rw.emplace(visit(*this, *inst.B())); - rw.emplace(visit(*this, *inst.C())); - return rw; +void insert_barrier_pass::reads_writes::merge(reads_writes &&other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].merge(std::move(other.reads[i])); + } + for (std::size_t i = 0; i < reads.size(); ++i) { + writes[i].merge(std::move(other.writes[i])); + } } -std::unordered_set insert_barrier::operator()(loop_inst &p) { - return visit(*this, *p.body()); +void insert_barrier_pass::reads_writes::emplace_read(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + reads[space].emplace(val); } - -std::unordered_set insert_barrier::operator()(alloca_inst &) { return {}; } - -std::unordered_set insert_barrier::operator()(barrier_inst &) { - last_instruction_was_barrier_ = true; - return {}; +void insert_barrier_pass::reads_writes::emplace_write(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + writes[space].emplace(val); } -std::unordered_set insert_barrier::operator()(expand_inst &) { return {}; } -std::unordered_set insert_barrier::operator()(fuse_inst &) { return {}; } - -std::unordered_set insert_barrier::operator()(load_inst &e) { - auto rw = std::unordered_set{}; - auto t = dyn_cast(e.operand()->ty().get()); - if (t) { - rw.emplace(visit(*this, *e.operand())); - } - return rw; +bool insert_barrier_pass::reads_writes::raw(address_space as, reads_writes const &rw, + aa_results const &aa) { + const auto space = address_space_to_index(as); + return intersects(rw.reads[space], writes[space], aa); } - -std::unordered_set insert_barrier::operator()(if_inst &in) { - auto s = visit(*this, *in.then()); - if (in.otherwise()) { - s.merge(visit(*this, *in.otherwise())); - } - return s; +bool insert_barrier_pass::reads_writes::war(address_space as, reads_writes const &rw, + aa_results const &aa) { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], reads[space], aa); } - -std::unordered_set insert_barrier::operator()(lifetime_stop_inst &) { return {}; } - -std::unordered_set insert_barrier::operator()(parallel_inst &p) { - return visit(*this, *p.body()); +bool insert_barrier_pass::reads_writes::waw(address_space as, reads_writes const &rw, + aa_results const &aa) { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], writes[space], aa); +} +bool insert_barrier_pass::reads_writes::raw_war_or_waw(address_space as, reads_writes const &rw, + aa_results const &aa) { + return raw(as, rw, aa) || war(as, rw, aa) || waw(as, rw, aa); } -std::unordered_set insert_barrier::operator()(size_inst &) { return {}; } - -std::unordered_set insert_barrier::operator()(store_inst &s) { - auto rw = std::unordered_set{}; - rw.emplace(visit(*this, *s.operand())); - return rw; +auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) -> std::size_t { + for (std::size_t i = 0; i < address_spaces.size(); ++i) { + if (as == address_spaces[i]) { + return i; + } + } + throw internal_compiler_error{}; } -std::unordered_set insert_barrier::operator()(subview_inst &) { return {}; } -std::unordered_set insert_barrier::operator()(yield_inst &) { return {}; } - -/* Region nodes */ -std::unordered_set insert_barrier::operator()(rgn &b) { - auto const intersects = [this](std::unordered_set const &a, - std::unordered_set const &b) { - for (auto &aa : a) { - if (aa != nullptr) { - for (auto &bb : b) { - if (aa_.alias(*aa, *bb)) { - return true; - } +auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa) -> reads_writes { + auto invisible_rw = reads_writes{}; + for (auto it = reg.begin(); it != reg.end(); ++it) { + if (auto *barrier = dyn_cast(it->get()); barrier) { + for (auto &as : reads_writes::address_spaces) { + if (barrier->has_fence(as)) { + invisible_rw.clear(as); } } + } else { + auto rw = reads_writes{}; + for (auto &subreg : (*it)->child_regions()) { + rw.merge(run_on_region(*subreg, aa)); + } + + auto const emplace_read = [&rw](value const &v) { + if (auto *m = dyn_cast(v->ty().get()); m) { + rw.emplace_read(m->addrspace(), v.get()); + } + }; + auto const emplace_write = [&rw](value const &v) { + if (auto *m = dyn_cast(v->ty().get()); m) { + rw.emplace_write(m->addrspace(), v.get()); + } + }; + visit(overloaded{[&](blas_a2_inst &in) { + emplace_read(in.A()); + emplace_write(in.B()); + }, + [&](blas_a3_inst &in) { + emplace_read(in.A()); + emplace_read(in.B()); + emplace_write(in.C()); + }, + [&](load_inst &in) { emplace_read(in.operand()); }, + [&](store_inst &in) { emplace_read(in.operand()); }, + [](inst_node &) {}}, + **it); + + std::int32_t fence_flags = 0; + for (auto &as : reads_writes::address_spaces) { + if (invisible_rw.raw_war_or_waw(as, rw, aa)) { + fence_flags |= static_cast(as); + invisible_rw.clear(as); + } + } + if (fence_flags != 0) { + it = reg.insert(it, inst{std::make_unique(fence_flags).release()}); + ++it; // skip over barrier + } + + invisible_rw.merge(std::move(rw)); } - return false; - }; - - auto rw = std::unordered_set{}; - auto insts = std::vector{}; - insts.reserve(b.insts().size()); - for (auto &s : b.insts()) { - auto my_rw = visit(*this, *s); - if (intersects(my_rw, rw)) { - insts.emplace_back(inst{std::make_unique().release()}); - rw.clear(); - } - insts.emplace_back(s); - if (last_instruction_was_barrier_) { - last_instruction_was_barrier_ = false; - rw.clear(); - } - rw.merge(my_rw); } - b.insts(std::move(insts)); - return rw; -} -/* Function nodes */ -void insert_barrier::operator()(prototype &) {} - -void insert_barrier::operator()(function &fn) { - auto aa = alias_analyser{}; - aa(fn); - aa_ = aa.get_result(); - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); + return invisible_rw; } -/* Program nodes */ -void insert_barrier::operator()(program &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } +/* Function nodes */ +void insert_barrier_pass::run_on_function(function &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(*fn.body(), aa); } } // namespace tinytc diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index 9d9ac4b8..a4feddb7 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -4,62 +4,42 @@ #ifndef INSERT_BARRIER_20230310_HPP #define INSERT_BARRIER_20230310_HPP -#include "node/data_type_node.hpp" +#include "analysis/aa_results.hpp" #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" #include "node/region_node.hpp" -#include "node/value_node.hpp" -#include "pass/aa_results.hpp" +#include #include namespace tinytc { -class insert_barrier { +class insert_barrier_pass { public: - /* Data type nodes */ - bool operator()(void_data_type &); - bool operator()(group_data_type &b); - bool operator()(memref_data_type &m); - bool operator()(scalar_data_type &s); + void run_on_function(function &fn); - /* Value nodes */ - value_node *operator()(int_imm &v); - value_node *operator()(float_imm &v); - value_node *operator()(val &v); + private: + class reads_writes { + public: + constexpr static std::array address_spaces = {address_space::global, + address_space::local}; - /* Stmt nodes */ - std::unordered_set operator()(inst_node &inst); - std::unordered_set operator()(blas_a2_inst &inst); - std::unordered_set operator()(blas_a3_inst &inst); - std::unordered_set operator()(loop_inst &p); - std::unordered_set operator()(alloca_inst &a); - std::unordered_set operator()(barrier_inst &b); - std::unordered_set operator()(expand_inst &e); - std::unordered_set operator()(fuse_inst &f); - std::unordered_set operator()(load_inst &e); - std::unordered_set operator()(if_inst &in); - std::unordered_set operator()(lifetime_stop_inst &); - std::unordered_set operator()(parallel_inst &p); - std::unordered_set operator()(size_inst &s); - std::unordered_set operator()(store_inst &s); - std::unordered_set operator()(subview_inst &s); - std::unordered_set operator()(yield_inst &y); + void clear(address_space as); + void merge(reads_writes &&other); + void emplace_read(address_space as, ::tinytc_value const *val); + void emplace_write(address_space as, ::tinytc_value const *val); - /* Region nodes */ - std::unordered_set operator()(rgn &b); + bool raw(address_space as, reads_writes const &rw, aa_results const &aa); + bool war(address_space as, reads_writes const &rw, aa_results const &aa); + bool waw(address_space as, reads_writes const &rw, aa_results const &aa); + bool raw_war_or_waw(address_space as, reads_writes const &rw, aa_results const &aa); - /* Func nodes */ - void operator()(prototype &p); - void operator()(function &fn); + private: + auto address_space_to_index(address_space as) -> std::size_t; - /* Program nodes */ - void operator()(program &p); + std::array, address_spaces.size()> reads, writes; + }; - private: - aa_results aa_; - bool last_instruction_was_barrier_ = false; + auto run_on_region(rgn ®, aa_results const &aa) -> reads_writes; }; } // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index dd74d4ad..ca221b85 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -33,9 +33,7 @@ auto insert_lifetime_stop_pass::run_on_region(rgn ®, aa_results const &aa) for (; prev_it != reg.begin(); --prev_it) { auto &i = *(prev_it - 1); for (auto &subreg : i->child_regions()) { - if (subreg) { - rgn_ops.merge(run_on_region(*subreg, aa)); - } + rgn_ops.merge(run_on_region(*subreg, aa)); } for (auto &v : i->operands()) { if (isa(*v->ty())) { diff --git a/src/pass/metadata.cpp b/src/pass/metadata.cpp deleted file mode 100644 index d4d8c76a..00000000 --- a/src/pass/metadata.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "pass/metadata.hpp" -#include "node/function_node.hpp" -#include "node/program_node.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include - -namespace tinytc { - -/* Function nodes */ -void metadata::operator()(function const &fn) { - auto m = kernel_metadata{}; - m.subgroup_size = fn.subgroup_size(); - m.work_group_size = fn.work_group_size(); - metadata_[std::string(fn.name())] = m; -} - -/* Program nodes */ -void metadata::operator()(program const &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } -} - -} // namespace tinytc diff --git a/src/pass/metadata.hpp b/src/pass/metadata.hpp deleted file mode 100644 index 67bc6c93..00000000 --- a/src/pass/metadata.hpp +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef METADATA_20240412_HPP -#define METADATA_20240412_HPP - -#include "kernel_metadata.hpp" -#include "node/function_node.hpp" -#include "node/program_node.hpp" - -#include -#include - -namespace tinytc { - -class metadata { - public: - /* Func nodes */ - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); - - inline auto get_result() const -> std::unordered_map { - return metadata_; - } - - private: - std::unordered_map metadata_; -}; - -} // namespace tinytc - -#endif // METADATA_20240412_HPP diff --git a/src/passes.def b/src/passes.def index 61f2442f..2988565a 100644 --- a/src/passes.def +++ b/src/passes.def @@ -3,6 +3,7 @@ FUNCTION_PASS("check-ir", check_ir_pass{}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) +FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass(info); }) diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index d74557af..a2c4895b 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -89,14 +89,17 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, my_loc()), - "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, my_loc()), - "B", my_loc()); + auto A = + fb.argument(make_memref(ty_, {selA(M, K), selA(K, M), dynamic}, + {1, ldA, strideA}, address_space::global, my_loc()), + "A", my_loc()); + auto B = + fb.argument(make_memref(ty_, {selB(K, N), selB(N, K), dynamic}, + {1, ldB, strideB}, address_space::global, my_loc()), + "B", my_loc()); auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, my_loc()), + auto C = fb.argument(make_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, + address_space::global, my_loc()), "C", my_loc()); auto beta = is_beta_nonzero ? std::move(beta_arg) : make_imm(0.0, ty_, my_loc()); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 521bf295..5ca0bf3b 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -139,10 +139,16 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {M, K}, {1, ldA}, my_loc()), "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {K, N}, {1, ldB}, my_loc()), "B", my_loc()); + auto A = + fb.argument(make_memref(ty_, {M, K}, {1, ldA}, address_space::global, my_loc()), + "A", my_loc()); + auto B = + fb.argument(make_memref(ty_, {K, N}, {1, ldB}, address_space::global, my_loc()), + "B", my_loc()); auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N}, {1, ldC}, my_loc()), "C", my_loc()); + auto C = + fb.argument(make_memref(ty_, {M, N}, {1, ldC}, address_space::global, my_loc()), + "C", my_loc()); fb.subgroup_size(sgs); auto const wgs = tiling.work_group_size(sgs); fb.work_group_size(wgs[0], wgs[1]); diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index e55da42a..c5407de5 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -118,6 +118,16 @@ clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as, clir:: return clir::data_type(base_type(ty), as, q); } +clir::address_space to_clir_address_space(address_space as) { + switch (as) { + case address_space::global: + return clir::address_space::global_t; + case address_space::local: + return clir::address_space::local_t; + } + return clir::address_space::global_t; +} + } // namespace tinytc char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index d813973e..2b1c6d5a 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -22,6 +22,7 @@ clir::data_type to_clir_ty(scalar_type ty, short size, clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); +clir::address_space to_clir_address_space(address_space as); } // namespace tinytc diff --git a/src/support/walk.hpp b/src/support/walk.hpp index 3b5d212b..b0e3f623 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -34,10 +34,8 @@ template void walk(inst_node &i, std::function(*j, callback); - } + for (auto &j : *reg) { + walk(*j, callback); } } if constexpr (Order == walk_order::post_order) { @@ -47,16 +45,14 @@ template void walk(inst_node &i, std::function void walk(inst_node &i, std::function callback) { for (auto ® : i.child_regions()) { - if (reg) { - if constexpr (Order == walk_order::pre_order) { - callback(reg); - } - for (auto &j : *reg) { - walk(*j, callback); - } - if constexpr (Order == walk_order::post_order) { - callback(reg); - } + if constexpr (Order == walk_order::pre_order) { + callback(reg); + } + for (auto &j : *reg) { + walk(*j, callback); + } + if constexpr (Order == walk_order::post_order) { + callback(reg); } } } diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 4ca88b22..b2595e05 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -3,14 +3,14 @@ ; RUN: %tinytc-opt --insert-lifetime-stop < %s | filecheck %s func @basic() { - %0 = alloca -> memref -; CHECK: %0 = alloca -> memref + %0 = alloca -> memref +; CHECK: %0 = alloca -> memref ; CHECK-NEXT: lifetime_stop %0 } func @use1(%A: memref, %C: memref) { ; CHECK-LABEL: func @use1{{.*}} - %B = alloca -> memref + %B = alloca -> memref gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B @@ -18,9 +18,9 @@ func @use1(%A: memref, %C: memref) { func @use2(%A: memref, %C: memref) { ; CHECK-LABEL: func @use2{{.*}} - %B = alloca -> memref + %B = alloca -> memref gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref - %B2 = alloca -> memref + %B2 = alloca -> memref gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref gemm.n.n 1.0, %A, %B2, 0.0, %C : f32, memref, memref, f32, memref ; CHECK: %B2 = {{.*}} @@ -32,7 +32,7 @@ func @use2(%A: memref, %C: memref) { func @use_alias(%A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} - %B = alloca -> memref + %B = alloca -> memref %0 = fuse %B[1,3] : memref %1 = subview %0[0:8,:] : memref gemm.n.n 1.0, %A, %1, 0.0, %C : f32, memref, memref>, f32, memref @@ -42,11 +42,11 @@ func @use_alias(%A: memref, %C: memref) { func @region1() { ; CHECK-LABEL: func @region1{{.*}} - %0 = alloca -> memref + %0 = alloca -> memref for %i=0,4 : index { - %1 = alloca -> memref + %1 = alloca -> memref for %k=0,4 : index { - %2 = alloca -> memref + %2 = alloca -> memref gemm.n.n 1.0, %0, %1, 0.0, %2 : f32, memref, memref, f32, memref axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref } diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir index 8b033814..04f0f81f 100644 --- a/test/opt/work-group-size.ir +++ b/test/opt/work-group-size.ir @@ -8,19 +8,20 @@ func @default_pvc() { func @f32_blas() { ; CHECK: func @f32_blas() subgroup_size(32) work_group_size(128,2) { - %0 = alloca -> memref - %1 = alloca -> memref + %0 = alloca -> memref + %1 = alloca -> memref for %i=0,4 { - axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref } } func @f64_blas() { ; CHECK: func @f64_blas() subgroup_size(16) work_group_size(128,8) { - %0 = alloca -> memref - %1 = alloca -> memref - %2 = alloca -> memref + %0 = alloca -> memref + %1 = alloca -> memref + %2 = alloca -> memref for %i=0,4 { - gemm.n.n 1.0, %0, %1, 0.0, %2 : f64, memref, memref, f64, memref + gemm.n.n 1.0, %0, %1, 0.0, %2 + : f64, memref, memref, f64, memref } } From 14a6bd2dcc9caf2d3206290bcb79aa0ea96a60af Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 17 Sep 2024 18:52:38 +0200 Subject: [PATCH 018/297] Update insert barrier pass to not visit spmd regions. Still need to account for control flow of for loop. Signed-off-by: Carsten Uphoff --- src/node/function_node.hpp | 2 + src/node/inst_node.hpp | 17 ++-- src/node/region_node.hpp | 13 +++- src/pass/check_ir.cpp | 5 +- src/pass/check_ir.hpp | 3 - src/pass/insert_barrier.cpp | 32 ++++---- src/pass/insert_barrier.hpp | 3 +- test/opt/insert-barrier.ir | 151 ++++++++++++++++++++++++++++++++++++ 8 files changed, 198 insertions(+), 28 deletions(-) create mode 100644 test/opt/insert-barrier.ir diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index 99be945c..a1813e85 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -5,6 +5,7 @@ #define FUNCTION_NODE_20230310_HPP #include "location.hpp" +#include "node/region_node.hpp" #include "reference_counted.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" @@ -27,6 +28,7 @@ struct tinytc_func : tinytc::reference_counted { : name_(std::move(name)), args_(std::move(args)), body_(std::move(body)), work_group_size_{0, 0}, subgroup_size_{0} { loc(lc); + body_->kind(tinytc::region_kind::collective); } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index c489a2ad..b10319b9 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -5,6 +5,7 @@ #define INST_NODE_20230327_HPP #include "error.hpp" +#include "node/region_node.hpp" #include "reference_counted.hpp" #include "support/type_list.hpp" #include "support/util.hpp" @@ -25,8 +26,8 @@ enum class inst_execution_kind { mixed, ///< mixed instruction on uniform or varying data collective, ///< collective instruction on uniform data, distributed among work-items spmd ///< SPMD instruction on varying data - }; + enum class IK { alloca, arith, @@ -525,13 +526,7 @@ class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } inline for_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : loop_inst{IK::for_loop, - std::move(loop_var), - std::move(from), - std::move(to), - {}, - std::move(body), - loc} {} + : for_inst{std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), loc} {} inline for_inst(value loop_var, value from, value to, value step, region body, location const &loc = {}) : loop_inst{IK::for_loop, @@ -553,7 +548,9 @@ class foreach_inst : public loop_inst { std::move(to), {}, std::move(body), - loc} {} + loc} { + child_region(0)->kind(region_kind::spmd); + } }; class hadamard_inst : public blas_a3_inst { @@ -589,6 +586,8 @@ class parallel_inst : public standard_inst<0, 0, 1> { inline parallel_inst(region body, location const &lc = {}) : standard_inst{IK::parallel} { child_region(0) = std::move(body); loc(lc); + + child_region(0)->kind(region_kind::spmd); } inline auto body() const -> region const & { return child_region(0); } }; diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index c30a3c30..9704260d 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -12,16 +12,26 @@ #include #include +namespace tinytc { + +//! Instruction classification +enum class region_kind { mixed, collective, spmd }; + +} // namespace tinytc + struct tinytc_region : tinytc::reference_counted { public: using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; inline tinytc_region(std::vector insts = {}, tinytc::location const &lc = {}) - : insts_(std::move(insts)) { + : insts_(std::move(insts)), kind_(tinytc::region_kind::mixed) { loc(lc); } + inline auto kind() const noexcept -> tinytc::region_kind { return kind_; } + inline void kind(tinytc::region_kind kind) noexcept { kind_ = kind; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -45,6 +55,7 @@ struct tinytc_region : tinytc::reference_counted { private: std::vector insts_; + tinytc::region_kind kind_; tinytc::location loc_; }; diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index 6c2a65b1..e1df796e 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -3,6 +3,8 @@ #include "pass/check_ir.hpp" #include "error.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" #include "support/walk.hpp" @@ -16,7 +18,8 @@ namespace tinytc { void check_ir_pass::run_on_function(function &fn) { walk(fn, [this](inst_node const &i, walk_stage const &stage) { - const bool child_region_is_spmd_region = isa(i) || isa(i); + const bool child_region_is_spmd_region = + i.num_child_regions() > 0 && i.child_region(0)->kind() == region_kind::spmd; if (stage.is_before_all_regions()) { if (i.kind() == inst_execution_kind::collective && inside_spmd_region_) { diff --git a/src/pass/check_ir.hpp b/src/pass/check_ir.hpp index b6aad450..70ba8a76 100644 --- a/src/pass/check_ir.hpp +++ b/src/pass/check_ir.hpp @@ -5,9 +5,6 @@ #define CHECK_IR_20240222_HPP #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" namespace tinytc { diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index f210ebdc..e57ab6fc 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -83,10 +83,11 @@ auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) throw internal_compiler_error{}; } -auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa) -> reads_writes { +auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa, + const bool insert_barriers) -> reads_writes { auto invisible_rw = reads_writes{}; for (auto it = reg.begin(); it != reg.end(); ++it) { - if (auto *barrier = dyn_cast(it->get()); barrier) { + if (auto *barrier = dyn_cast(it->get()); insert_barriers && barrier) { for (auto &as : reads_writes::address_spaces) { if (barrier->has_fence(as)) { invisible_rw.clear(as); @@ -95,7 +96,9 @@ auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa) -> reads } else { auto rw = reads_writes{}; for (auto &subreg : (*it)->child_regions()) { - rw.merge(run_on_region(*subreg, aa)); + const bool insert_barriers_sub = + insert_barriers && subreg->kind() != region_kind::spmd; + rw.merge(run_on_region(*subreg, aa, insert_barriers_sub)); } auto const emplace_read = [&rw](value const &v) { @@ -118,20 +121,23 @@ auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa) -> reads emplace_write(in.C()); }, [&](load_inst &in) { emplace_read(in.operand()); }, - [&](store_inst &in) { emplace_read(in.operand()); }, + [&](store_inst &in) { emplace_write(in.operand()); }, [](inst_node &) {}}, **it); - std::int32_t fence_flags = 0; - for (auto &as : reads_writes::address_spaces) { - if (invisible_rw.raw_war_or_waw(as, rw, aa)) { - fence_flags |= static_cast(as); - invisible_rw.clear(as); + if (insert_barriers) { + std::int32_t fence_flags = 0; + for (auto &as : reads_writes::address_spaces) { + if (invisible_rw.raw_war_or_waw(as, rw, aa)) { + fence_flags |= static_cast(as); + invisible_rw.clear(as); + } + } + if (fence_flags != 0) { + it = + reg.insert(it, inst{std::make_unique(fence_flags).release()}); + ++it; // skip over barrier } - } - if (fence_flags != 0) { - it = reg.insert(it, inst{std::make_unique(fence_flags).release()}); - ++it; // skip over barrier } invisible_rw.merge(std::move(rw)); diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index a4feddb7..20990ce0 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -39,7 +39,8 @@ class insert_barrier_pass { std::array, address_spaces.size()> reads, writes; }; - auto run_on_region(rgn ®, aa_results const &aa) -> reads_writes; + auto run_on_region(rgn ®, aa_results const &aa, + const bool insert_barriers = true) -> reads_writes; }; } // namespace tinytc diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir new file mode 100644 index 00000000..e8bf6903 --- /dev/null +++ b/test/opt/insert-barrier.ir @@ -0,0 +1,151 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt --insert-barrier < %s | filecheck %s +func @rar(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %A, %B, %b, %D : f32, memref, memref, f32, memref +; CHECK-LABEL: func @rar({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %D{{.*}} +} + +func @raw(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %C, %B, %b, %D : f32, memref, memref, f32, memref +; CHECK-LABEL: func @raw({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.global +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D{{.*}} +} + +func @war(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref + gemm.n.n %a, %D, %B, %b, %C : f32, memref, memref, f32, memref +; CHECK-LABEL: func @war({{.*}} +; CHECK: gemm.n.n %a, %C, %B, %b, %A{{.*}} +; CHECK-NEXT: barrier.global +; CHECK-NEXT: gemm.n.n %a, %D, %B, %b, %C{{.*}} +} + +func @waw(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref +; CHECK-LABEL: func @waw({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.global +; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %C{{.*}} +} + +func @raw_local(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { + %C = alloca -> memref + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %C, %B, %b, %D : f32, memref, memref, f32, memref +; CHECK-LABEL: func @raw_local({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.local +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D{{.*}} +} + +func @raw_local_war_global(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { + %C = alloca -> memref + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref +; CHECK-LABEL: func @raw_local_war_global({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.global.local +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A{{.*}} +} + +func @respect_manual_barrier(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { + %C = alloca -> memref + gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + barrier.global.local + gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref +; CHECK-LABEL: func @respect_manual_barrier({{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.global.local +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A{{.*}} +} + +func @war_alias(%a: f32, %b: f32, %A: memref, %C: memref) { + %B = alloca -> memref + %0 = subview %B[:,0:8] : memref + axpby.n %a, %B, %b, %C : f32, memref, f32, memref + axpby.n %a, %A, %b, %0 : f32, memref, f32, memref +; CHECK-LABEL: func @war_alias({{.*}} +; CHECK: axpby.n %a, %B, %b, %C{{.*}} +; CHECK-NEXT: barrier.local +; CHECK-NEXT: axpby.n %a, %A, %b, %0{{.*}} +} + +func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + %0 = cmp.gt %a, 42.0 : f32 + if %0 { + axpby.n %a, %A, %b, %B : f32, memref, f32, memref + axpby.n %a, %B, %b, %C : f32, memref, f32, memref + } else { + axpby.n %a, %C, %b, %D : f32, memref, f32, memref + } + axpby.n %a, %A, %b, %B : f32, memref, f32, memref +; CHECK-LABEL: func @if({{.*}} +; CHECK: if %0 { +; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} +; CHECK-NEXT: barrier.global +; CHECK-NEXT: axpby.n %a, %B, %b, %C{{.*}} +; CHECK-NEXT: } else { +; CHECK-NEXT: axpby.n %a, %C, %b, %D{{.*}} +; CHECK-NEXT: } +; CHECK-NEXT: barrier.global +; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} +} + +func @region1() { + %0 = alloca -> memref + for %i=0,4 : index { + %1 = alloca -> memref + for %k=0,4 : index { + %2 = alloca -> memref + gemm.n.n 1.0, %0, %1, 0.0, %2 + : f32, memref, memref, f32, memref + axpby.n 1.0, %1, 0.0, %0 : f32, memref, f32, memref + } + axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + } +; CHECK-LABEL: func @region1({{.*}} +; CHECK: for %i=0,4 : index { +; CHECK-NEXT: %1 = alloca -> memref +; CHECK-NEXT: for %k=0,4 : index { +; CHECK-NEXT: %2 = alloca -> memref +; CHECK-NEXT: barrier.local +; CHECK-NEXT: gemm.n.n 0x1p+0, %0, %1, 0x0p+0, %2{{.*}} +; CHECK-NEXT: barrier.local +; CHECK-NEXT: axpby.n 0x1p+0, %1, 0x0p+0, %0{{.*}} +; CHECK-NEXT: } +; CHECK-NEXT: barrier.local +; CHECK-NEXT: axpby.n 0x1p+0, %0, 0x0p+0, %1{{.*}} +; CHECK-NEXT: } +} + +func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref) { + parallel { + %1 = subgroup_id + %2 = cmp.eq %1, 0 : i32 + if %2 { + %3 = load %A[3,4] : memref + store %3, %A[3,4] : memref + } + } + %0 = load %A[3,4] : memref +; CHECK-LABEL: func @no_barrier_spmd({{.*}} +; CHECK: parallel { +; CHECK-NEXT: %1 = subgroup_id +; CHECK-NEXT: %2 = cmp.eq %1, 0 : i32 +; CHECK-NEXT: if %2 { +; CHECK-NEXT: %3 = load %A[3,4] : memref +; CHECK-NEXT: store %3, %A[3,4] : memref +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: barrier.global +; CHECK-NEXT: %0 = load %A[3,4] : memref +} From ce956062b4f12b1ac17cabc1d320ef548cbb60ba Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 18 Sep 2024 11:28:12 +0200 Subject: [PATCH 019/297] Make internal naming more consistent Signed-off-by: Carsten Uphoff --- src/analysis/alias.cpp | 2 +- src/analysis/alias.hpp | 2 +- src/codegen_tools.hpp | 12 ++++++------ src/func.cpp | 4 ++-- src/gemm_generator.hpp | 4 ++-- src/node/function_node.hpp | 2 +- src/node/program_node.hpp | 2 +- src/node/region_node.hpp | 2 +- src/parser/parser_impl.yy | 17 +++++++++-------- src/pass/check_ir.cpp | 2 +- src/pass/check_ir.hpp | 2 +- src/pass/constant_propagation.cpp | 4 ++-- src/pass/constant_propagation.hpp | 4 ++-- src/pass/convert_to_opencl.cpp | 6 +++--- src/pass/convert_to_opencl.hpp | 6 +++--- src/pass/dump_ir.cpp | 4 ++-- src/pass/dump_ir.hpp | 4 ++-- src/pass/insert_barrier.cpp | 4 ++-- src/pass/insert_barrier.hpp | 4 ++-- src/pass/insert_lifetime_stop.cpp | 4 ++-- src/pass/insert_lifetime_stop.hpp | 6 +++--- src/pass/lower_linalg.cpp | 4 ++-- src/pass/lower_linalg.hpp | 5 ++--- src/pass/slot_tracker.cpp | 2 +- src/pass/slot_tracker.hpp | 2 +- src/pass/stack.cpp | 2 +- src/pass/stack.hpp | 2 +- src/pass/work_group_size.cpp | 2 +- src/pass/work_group_size.hpp | 2 +- src/prog.cpp | 2 +- src/region.cpp | 2 +- src/support/walk.hpp | 5 +++-- src/sycl/kernel.cpp | 8 ++++---- src/sycl/recipe_handler.cpp | 8 ++++---- src/tiling.cpp | 8 ++++---- src/tiling.hpp | 8 ++++---- src/ze/error.hpp | 4 ++-- 37 files changed, 82 insertions(+), 81 deletions(-) diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp index e51add20..812f3496 100644 --- a/src/analysis/alias.cpp +++ b/src/analysis/alias.cpp @@ -67,7 +67,7 @@ void alias_analysis_visitor::operator()(subview_inst const &s) { alias_[s.result().get()] = source; } -auto alias_analysis::run_on_function(function &fn) -> aa_results { +auto alias_analysis::run_on_function(function_node &fn) -> aa_results { auto visitor = alias_analysis_visitor{}; walk(fn, [&visitor](inst_node &i) { visit(visitor, i); }); diff --git a/src/analysis/alias.hpp b/src/analysis/alias.hpp index 02314ed0..65b029d4 100644 --- a/src/analysis/alias.hpp +++ b/src/analysis/alias.hpp @@ -11,7 +11,7 @@ namespace tinytc { class alias_analysis { public: - auto run_on_function(function &fn) -> aa_results; + auto run_on_function(function_node &fn) -> aa_results; }; } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index a7037bc9..84725a5a 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -106,16 +106,16 @@ struct matrix_block_description { }; auto read_matrix_block_regular(clir::block_builder &bb, matrix_block_description const &d, - int M_mode, core_config const &core_cfg, char const *block_name) - -> std::unique_ptr; + int M_mode, core_config const &core_cfg, + char const *block_name) -> std::unique_ptr; auto read_matrix_block_vector(clir::block_builder &bb, matrix_block_description const &d, - int M_mode, core_config const &core_cfg, char const *block_name) - -> std::unique_ptr; + int M_mode, core_config const &core_cfg, + char const *block_name) -> std::unique_ptr; // Read MbxKb block auto read_matrix_block(clir::block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, char const *block_name) - -> std::unique_ptr; + core_config const &core_cfg, + char const *block_name) -> std::unique_ptr; // Write MbxKb block void write_matrix_block(clir::block_builder &bb, block_accessor const &block, diff --git a/src/func.cpp b/src/func.cpp index fcb12ac2..27013e46 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -31,8 +31,8 @@ tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, uin for (uint32_t i = 0; i < arg_list_size; ++i) { arg_vec.emplace_back(value(arg_list[i], true)); } - *fun = std::make_unique(std::string(name), std::move(arg_vec), region{body, true}, - get_optional(loc)) + *fun = std::make_unique(std::string(name), std::move(arg_vec), + region{body, true}, get_optional(loc)) .release(); }); } diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp index cf0ad937..b24e9b45 100644 --- a/src/gemm_generator.hpp +++ b/src/gemm_generator.hpp @@ -108,8 +108,8 @@ ::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling cons */ auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, std::int32_t register_space, - std::pair max_fill_fraction = {1, 2}) - -> std::pair; + std::pair max_fill_fraction = { + 1, 2}) -> std::pair; } // namespace tinytc diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index a1813e85..cec8bdc7 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -72,7 +72,7 @@ struct tinytc_func : tinytc::reference_counted { namespace tinytc { -using function = ::tinytc_func; +using function_node = ::tinytc_func; } // namespace tinytc diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index aa3a3dc1..c7462db0 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -50,7 +50,7 @@ struct tinytc_prog : tinytc::reference_counted { namespace tinytc { -using program = ::tinytc_prog; +using program_node = ::tinytc_prog; } // namespace tinytc diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 9704260d..15995958 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -60,7 +60,7 @@ struct tinytc_region : tinytc::reference_counted { }; namespace tinytc { -using rgn = ::tinytc_region; +using region_node = ::tinytc_region; } // namespace tinytc #endif // REGION_NODE_20230908_HPP diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index c3f5873a..5ac82a31 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -155,8 +155,8 @@ %nterm func %nterm > arguments %nterm <::tinytc::value> argument -%nterm >> attributes -%nterm > attribute +%nterm >> attributes +%nterm > attribute %nterm data_type %nterm scalar_type %nterm memref_type @@ -230,7 +230,7 @@ %% prog: func_list { - auto p = prog { std::make_unique(std::move($func_list), @prog).release() }; + auto p = prog { std::make_unique(std::move($func_list), @prog).release() }; ctx.program(p); $$ = std::move(p); } @@ -246,8 +246,9 @@ func: } GLOBAL_IDENTIFIER LPAREN arguments RPAREN attributes region { auto loc = @FUNC; loc.end = @RPAREN.end; - auto func_node = - std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), std::move($region), loc).release(); + auto func_node = std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), + std::move($region), loc) + .release(); for (auto &attr : $attributes) { attr(*func_node); } @@ -287,14 +288,14 @@ attribute: } auto const wgs = std::array{static_cast($m), static_cast($n)}; - $$ = [=](function &f) { f.work_group_size(wgs); }; + $$ = [=](function_node &f) { f.work_group_size(wgs); }; } | SUBGROUP_SIZE LPAREN INTEGER_CONSTANT RPAREN { if ($INTEGER_CONSTANT <= 0) { throw parser::syntax_error(@INTEGER_CONSTANT, "Must be a non-negative number"); } auto const sgs = static_cast($INTEGER_CONSTANT); - $$ = [=](function &f) { f.subgroup_size(sgs); }; + $$ = [=](function_node &f) { f.subgroup_size(sgs); }; } ; @@ -391,7 +392,7 @@ region: LBRACE { ctx.push_scope(); } instructions RBRACE { - $$ = region{std::make_unique(std::move($instructions), @region).release()}; + $$ = region{std::make_unique(std::move($instructions), @region).release()}; ctx.pop_scope(); } ; diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index e1df796e..b4885ea5 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -16,7 +16,7 @@ namespace tinytc { -void check_ir_pass::run_on_function(function &fn) { +void check_ir_pass::run_on_function(function_node &fn) { walk(fn, [this](inst_node const &i, walk_stage const &stage) { const bool child_region_is_spmd_region = i.num_child_regions() > 0 && i.child_region(0)->kind() == region_kind::spmd; diff --git a/src/pass/check_ir.hpp b/src/pass/check_ir.hpp index 70ba8a76..c58b4a90 100644 --- a/src/pass/check_ir.hpp +++ b/src/pass/check_ir.hpp @@ -10,7 +10,7 @@ namespace tinytc { class check_ir_pass { public: - void run_on_function(function &fn); + void run_on_function(function_node &fn); private: bool inside_spmd_region_ = false; diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index 68acd373..d6a589fa 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -153,14 +153,14 @@ void constant_propagation::operator()(arith_inst &arith) { void constant_propagation::operator()(parallel_inst &p) { visit(*this, *p.body()); } /* Region nodes */ -void constant_propagation::operator()(rgn &b) { +void constant_propagation::operator()(region_node &b) { for (auto &s : b.insts()) { visit(*this, *s); } } /* Function nodes */ -void constant_propagation::operator()(function &fn) { visit(*this, *fn.body()); } +void constant_propagation::operator()(function_node &fn) { visit(*this, *fn.body()); } /* Program nodes */ void constant_propagation::operator()(program &p) { diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp index 737e6929..502f5eaa 100644 --- a/src/pass/constant_propagation.hpp +++ b/src/pass/constant_propagation.hpp @@ -23,10 +23,10 @@ class constant_propagation { void operator()(parallel_inst &p); /* Region nodes */ - void operator()(rgn &b); + void operator()(region_node &b); /* Func nodes */ - void operator()(function &fn); + void operator()(function_node &fn); /* Program nodes */ void operator()(program &p); diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 4586a253..143663d4 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1086,7 +1086,7 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } /* Region nodes */ -clir::stmt convert_to_opencl_pass::run_on_region(rgn ®) { +clir::stmt convert_to_opencl_pass::run_on_region(region_node ®) { declared_vars_.push_back({}); auto bb = clir::block_builder{}; for (auto &s : reg.insts()) { @@ -1099,7 +1099,7 @@ clir::stmt convert_to_opencl_pass::run_on_region(rgn ®) { } /* Function nodes */ -auto convert_to_opencl_pass::run_on_function(function &fn) -> clir::func { +auto convert_to_opencl_pass::run_on_function(function_node &fn) -> clir::func { stack_high_water_mark_ = 0; auto const subgroup_size = fn.subgroup_size(); try { @@ -1156,7 +1156,7 @@ auto convert_to_opencl_pass::run_on_function(function &fn) -> clir::func { } /* Program nodes */ -auto convert_to_opencl_pass::run_on_program(program &p) -> clir::prog { +auto convert_to_opencl_pass::run_on_program(program_node &p) -> clir::prog { reserved_names_.clear(); for (auto const &fn : p.functions()) { reserved_names_.insert(std::string(fn->name())); diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 69cd84fd..82899d75 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -102,11 +102,11 @@ class convert_to_opencl_pass { std::vector operator()(sum_inst const &s); std::vector operator()(yield_inst const &in); - auto run_on_program(program &p) -> clir::prog; + auto run_on_program(program_node &p) -> clir::prog; private: - auto run_on_region(rgn ®) -> clir::stmt; - auto run_on_function(function &fn) -> clir::func; + auto run_on_region(region_node ®) -> clir::stmt; + auto run_on_function(function_node &fn) -> clir::func; auto get_dope_vector(value_node *v) -> dope_vector &; void set_dope_vector(value_node *v, dope_vector dv); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index b85ba967..5abc268f 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -358,7 +358,7 @@ void dump_ir_pass::operator()(yield_inst const &y) { do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }); } -void dump_ir_pass::dump_region(rgn const ®) { +void dump_ir_pass::dump_region(region_node const ®) { *os_ << "{" << std::endl; ++lvl_; auto ind = indent(); @@ -371,7 +371,7 @@ void dump_ir_pass::dump_region(rgn const ®) { *os_ << indent() << "}"; } -void dump_ir_pass::run_on_function(function const &fn) { +void dump_ir_pass::run_on_function(function_node const &fn) { *os_ << "func @" << fn.name() << "("; std::string infix = ",\n "; infix += std::string(fn.name().size(), ' '); diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index f55f93b9..bd2f99e4 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -64,10 +64,10 @@ class dump_ir_pass { void operator()(sum_inst const &s); void operator()(yield_inst const &y); - void run_on_function(function const &fn); + void run_on_function(function_node const &fn); private: - void dump_region(rgn const ®); + void dump_region(region_node const ®); void dump_blas_a2(blas_a2_inst const &g); void dump_blas_a3(blas_a3_inst const &g); diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index e57ab6fc..292fd6f9 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -83,7 +83,7 @@ auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) throw internal_compiler_error{}; } -auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa, +auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, const bool insert_barriers) -> reads_writes { auto invisible_rw = reads_writes{}; for (auto it = reg.begin(); it != reg.end(); ++it) { @@ -148,7 +148,7 @@ auto insert_barrier_pass::run_on_region(rgn ®, aa_results const &aa, } /* Function nodes */ -void insert_barrier_pass::run_on_function(function &fn) { +void insert_barrier_pass::run_on_function(function_node &fn) { auto aa = alias_analysis{}.run_on_function(fn); run_on_region(*fn.body(), aa); } diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index 20990ce0..2f74b41c 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -15,7 +15,7 @@ namespace tinytc { class insert_barrier_pass { public: - void run_on_function(function &fn); + void run_on_function(function_node &fn); private: class reads_writes { @@ -39,7 +39,7 @@ class insert_barrier_pass { std::array, address_spaces.size()> reads, writes; }; - auto run_on_region(rgn ®, aa_results const &aa, + auto run_on_region(region_node ®, aa_results const &aa, const bool insert_barriers = true) -> reads_writes; }; diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index ca221b85..281bcf25 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -15,7 +15,7 @@ namespace tinytc { -auto insert_lifetime_stop_pass::run_on_region(rgn ®, aa_results const &aa) +auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const &aa) -> std::unordered_set { if (reg.empty()) { return {}; @@ -60,7 +60,7 @@ auto insert_lifetime_stop_pass::run_on_region(rgn ®, aa_results const &aa) return rgn_ops; } -void insert_lifetime_stop_pass::run_on_function(function &fn) { +void insert_lifetime_stop_pass::run_on_function(function_node &fn) { auto aa = alias_analysis{}.run_on_function(fn); run_on_region(*fn.body(), aa); } diff --git a/src/pass/insert_lifetime_stop.hpp b/src/pass/insert_lifetime_stop.hpp index 2c91129e..ec827fa3 100644 --- a/src/pass/insert_lifetime_stop.hpp +++ b/src/pass/insert_lifetime_stop.hpp @@ -16,11 +16,11 @@ namespace tinytc { class insert_lifetime_stop_pass { public: - void run_on_function(function &fn); + void run_on_function(function_node &fn); private: - auto run_on_region(rgn ®, aa_results const &aa) - -> std::unordered_set<::tinytc_value const *>; + auto run_on_region(region_node ®, + aa_results const &aa) -> std::unordered_set<::tinytc_value const *>; }; } // namespace tinytc diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 05b44707..e1c59181 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -158,7 +158,7 @@ inst lower_linalg_pass::operator()(ger_inst &g) { } /* Region nodes */ -void lower_linalg_pass::operator()(rgn &b) { +void lower_linalg_pass::operator()(region_node &b) { for (auto &s : b.insts()) { if (auto lowered_inst = visit(*this, *s); lowered_inst) { s = lowered_inst; @@ -169,7 +169,7 @@ void lower_linalg_pass::operator()(rgn &b) { /* Function nodes */ void lower_linalg_pass::operator()(prototype &) {} -void lower_linalg_pass::operator()(function &fn) { +void lower_linalg_pass::operator()(function_node &fn) { auto const subgroup_size = fn.subgroup_size(); try { core_cfg_ = info_->get_core_config(subgroup_size); diff --git a/src/pass/lower_linalg.hpp b/src/pass/lower_linalg.hpp index ca8157b8..f4dcffcb 100644 --- a/src/pass/lower_linalg.hpp +++ b/src/pass/lower_linalg.hpp @@ -41,11 +41,10 @@ class lower_linalg_pass { inst operator()(parallel_inst &p); /* Region nodes */ - void operator()(rgn &b); + void operator()(region_node &b); /* Func nodes */ - void operator()(prototype &p); - void operator()(function &fn); + void operator()(function_node &fn); /* Program nodes */ void operator()(program &p); diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index e4a75854..209a00bc 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -17,7 +17,7 @@ void slot_tracker::set_slot(value_node const &v) { } } -void slot_tracker::run_on_function(function &fn) { +void slot_tracker::run_on_function(function_node &fn) { slot_ = 0; for (auto const &arg : fn.args()) { set_slot(*arg); diff --git a/src/pass/slot_tracker.hpp b/src/pass/slot_tracker.hpp index 9a52dfd4..86abaca2 100644 --- a/src/pass/slot_tracker.hpp +++ b/src/pass/slot_tracker.hpp @@ -17,7 +17,7 @@ namespace tinytc { class slot_tracker { public: - void run_on_function(function &fn); + void run_on_function(function_node &fn); auto get_slot(value_node const &v) -> std::int64_t; diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 718591aa..4e8bfbb9 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -18,7 +18,7 @@ namespace tinytc { -void set_stack_ptr_pass::run_on_function(function &fn) { +void set_stack_ptr_pass::run_on_function(function_node &fn) { struct allocation { value_node *value; std::int64_t start, stop; diff --git a/src/pass/stack.hpp b/src/pass/stack.hpp index 72b6f7ad..ef10177f 100644 --- a/src/pass/stack.hpp +++ b/src/pass/stack.hpp @@ -10,7 +10,7 @@ namespace tinytc { class set_stack_ptr_pass { public: - void run_on_function(function &fn); + void run_on_function(function_node &fn); }; } // namespace tinytc diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp index 62792619..c042ed9d 100644 --- a/src/pass/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -37,7 +37,7 @@ work_group_size_pass::work_group_size_pass(::tinytc_core_info const *info) } } -void work_group_size_pass::run_on_function(function &fn) { +void work_group_size_pass::run_on_function(function_node &fn) { auto subgroup_size = fn.subgroup_size(); auto work_group_size = fn.work_group_size(); diff --git a/src/pass/work_group_size.hpp b/src/pass/work_group_size.hpp index 69c31fd3..324d6c4e 100644 --- a/src/pass/work_group_size.hpp +++ b/src/pass/work_group_size.hpp @@ -13,7 +13,7 @@ class work_group_size_pass { public: work_group_size_pass(tinytc_core_info const *info); - void run_on_function(function &fn); + void run_on_function(function_node &fn); private: tinytc_core_info const *info_; diff --git a/src/prog.cpp b/src/prog.cpp index 3de2f4a2..0f34e076 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -36,7 +36,7 @@ tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size for (uint32_t i = 0; i < fun_list_size; ++i) { fun_vec.emplace_back(func(fun_list[i], true)); } - *prg = std::make_unique(std::move(fun_vec), get_optional(loc)).release(); + *prg = std::make_unique(std::move(fun_vec), get_optional(loc)).release(); }); } diff --git a/src/region.cpp b/src/region.cpp index ea9a507d..97aa2cd4 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -29,7 +29,7 @@ tinytc_status_t tinytc_region_create(tinytc_region_t *reg, uint32_t instruction_ for (uint32_t i = 0; i < instruction_list_size; ++i) { inst_vec.emplace_back(inst(instruction_list[i], true)); } - *reg = std::make_unique(std::move(inst_vec), get_optional(loc)).release(); + *reg = std::make_unique(std::move(inst_vec), get_optional(loc)).release(); }); } diff --git a/src/support/walk.hpp b/src/support/walk.hpp index b0e3f623..233b2380 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -59,13 +59,14 @@ template void walk(inst_node &i, std::function callback); -template void walk(function &fn, std::function callback) { +template +void walk(function_node &fn, std::function callback) { for (auto &i : *fn.body()) { walk(*i, callback); } } -inline void walk(function &fn, +inline void walk(function_node &fn, std::function callback) { for (auto &i : *fn.body()) { walk(*i, callback); diff --git a/src/sycl/kernel.cpp b/src/sycl/kernel.cpp index 3c46c482..ec2a5caa 100644 --- a/src/sycl/kernel.cpp +++ b/src/sycl/kernel.cpp @@ -72,8 +72,8 @@ auto make_kernel_bundle(context const &ctx, device const &dev, source const &src std::move(source_ctx)); } auto make_kernel_bundle(context const &ctx, device const &dev, prog prg, - tinytc_core_feature_flags_t core_features, source_context source_ctx) - -> kernel_bundle { + tinytc_core_feature_flags_t core_features, + source_context source_ctx) -> kernel_bundle { return dispatch(dev.get_backend(), ctx, dev, std::move(prg), core_features, std::move(source_ctx)); } @@ -106,8 +106,8 @@ template <> struct kernel_dispatcher { } }; -auto make_kernel(kernel_bundle const &bundle, char const *name) - -> kernel { +auto make_kernel(kernel_bundle const &bundle, + char const *name) -> kernel { return dispatch(bundle.get_backend(), bundle, name); } diff --git a/src/sycl/recipe_handler.cpp b/src/sycl/recipe_handler.cpp index 26b8a97e..799542e3 100644 --- a/src/sycl/recipe_handler.cpp +++ b/src/sycl/recipe_handler.cpp @@ -75,8 +75,8 @@ auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, reci return sycl_recipe_handler{handler}; } -auto make_recipe_handler(sycl::queue const &q, recipe const &rec, source_context source_ctx) - -> sycl_recipe_handler { +auto make_recipe_handler(sycl::queue const &q, recipe const &rec, + source_context source_ctx) -> sycl_recipe_handler { tinytc_recipe_handler_t handler = std::make_unique(q.get_context(), q.get_device(), rec, std::move(source_ctx)) @@ -104,8 +104,8 @@ auto sycl_recipe_handler::submit(sycl::queue q, sycl::event const &dep_event) -> }); } -auto sycl_recipe_handler::submit(sycl::queue q, std::vector const &dep_events) - -> sycl::event { +auto sycl_recipe_handler::submit(sycl::queue q, + std::vector const &dep_events) -> sycl::event { return q.submit([&](sycl::handler &h) { h.depends_on(dep_events); parallel_for(h); diff --git a/src/tiling.cpp b/src/tiling.cpp index 1576eea6..6d33e14f 100644 --- a/src/tiling.cpp +++ b/src/tiling.cpp @@ -18,8 +18,8 @@ auto blas_shape::operator==(blas_shape const &other) const -> bool { } auto blas_shape::operator!=(blas_shape const &other) const -> bool { return !(*this == other); } -auto suggest_subgroup_size(std::vector const &shapes, ::tinytc_core_info const &info) - -> std::int32_t { +auto suggest_subgroup_size(std::vector const &shapes, + ::tinytc_core_info const &info) -> std::int32_t { std::size_t max_size = 1u; for (auto &shape : shapes) { max_size = std::max(max_size, size(shape.ty)); @@ -65,8 +65,8 @@ auto suggest_subgroup_size(std::vector const &shapes, ::tinytc_core_ return sensible_subgroup_sizes.back(); } -auto suggest_local_tiling(std::vector const &shapes, core_config const &core_cfg) - -> local_tiling { +auto suggest_local_tiling(std::vector const &shapes, + core_config const &core_cfg) -> local_tiling { if (shapes.empty()) { return {1, 1}; } diff --git a/src/tiling.hpp b/src/tiling.hpp index e89b075b..204f7ea4 100644 --- a/src/tiling.hpp +++ b/src/tiling.hpp @@ -55,8 +55,8 @@ struct blas_shape { * @param shapes Shapes that occur in kernel * @param info Core info */ -auto suggest_subgroup_size(std::vector const &shapes, ::tinytc_core_info const &info) - -> std::int32_t; +auto suggest_subgroup_size(std::vector const &shapes, + ::tinytc_core_info const &info) -> std::int32_t; /** * @brief Suggest a local tiling based on blas size @@ -75,8 +75,8 @@ auto suggest_local_tiling(blas_shape const &bshape, core_config const &core_cfg) * * @return */ -auto suggest_local_tiling(std::vector const &shapes, core_config const &core_cfg) - -> local_tiling; +auto suggest_local_tiling(std::vector const &shapes, + core_config const &core_cfg) -> local_tiling; /** * @brief Suggest both, subgroup size and tiling, based on blas sizes. diff --git a/src/ze/error.hpp b/src/ze/error.hpp index 1c123447..f6f581c6 100644 --- a/src/ze/error.hpp +++ b/src/ze/error.hpp @@ -12,8 +12,8 @@ namespace tinytc { template -auto exception_to_status_code_ze(F &&f, tinytc_source_context_t context = nullptr) - -> tinytc_status_t { +auto exception_to_status_code_ze(F &&f, + tinytc_source_context_t context = nullptr) -> tinytc_status_t { try { f(); } catch (status const &st) { From e650070442c0903205f88b7ea270a556ab84fb99 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 20 Sep 2024 13:36:27 +0200 Subject: [PATCH 020/297] Add control flow graph analysis Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 5 +- include/tinytc/tinytc.hpp | 2 +- src/CMakeLists.txt | 2 + src/analysis/cfg.cpp | 72 +++++++++++++++++++++++ src/analysis/cfg.hpp | 54 ++++++++++++++++++ src/compiler.cpp | 1 + src/inst.cpp | 2 +- src/node/data_type_node.hpp | 1 + src/node/function_node.hpp | 2 +- src/node/inst_node.cpp | 24 ++++++-- src/node/inst_node.hpp | 2 + src/node/program_node.hpp | 2 +- src/node/region_node.hpp | 2 +- src/node/value_node.hpp | 2 +- src/parser/parser_impl.yy | 27 ++++----- src/pass/convert_to_opencl.cpp | 2 + src/pass/dump_cfg.cpp | 39 +++++++++++++ src/pass/dump_cfg.hpp | 25 ++++++++ src/pass/dump_ir.cpp | 38 ++++++++----- src/pass/dump_ir.hpp | 6 +- src/pass/insert_barrier.cpp | 98 +++++++++++++++++++++++++++++--- src/pass/insert_barrier.hpp | 13 +++-- src/passes.def | 1 + src/support/util.hpp | 5 -- src/support/visit.hpp | 5 ++ test/opt/insert-barrier.ir | 23 ++++++++ test/opt/insert-lifetime-stop.ir | 2 + tools/opt/args.cpp | 2 +- 28 files changed, 396 insertions(+), 63 deletions(-) create mode 100644 src/analysis/cfg.cpp create mode 100644 src/analysis/cfg.hpp create mode 100644 src/pass/dump_cfg.cpp create mode 100644 src/pass/dump_cfg.hpp diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index bc0447c3..1ae7e217 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -740,8 +740,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc * @endcode * * @param instr [out] pointer to the inst object created - * @param yield_list_size [in] length of yielded values list; must be at least 1 - * @param yield_list [in][range(1, yield_list_size)] yielded values array + * @param yield_list_size [in] length of yielded values list + * @param yield_list [in][range(0, yield_list_size)] yielded values array; can be nullptr if + * yield_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 816a1d58..506a712a 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1262,7 +1262,7 @@ inline inst make_yield(std::vector const &yield_list, location const &loc tinytc_inst_t instr; auto len = yield_list.size(); if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); + throw std::out_of_range("yield list too long"); } tinytc_value_t *yl = const_cast(reinterpret_cast(yield_list.data())); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 025e7b87..39061b60 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,7 @@ find_package(BISON 3.8.2 REQUIRED) set(SOURCES analysis/aa_results.cpp analysis/alias.cpp + analysis/cfg.cpp analysis/equal.cpp binary.cpp codegen_tools.cpp @@ -37,6 +38,7 @@ set(SOURCES pass/check_ir.cpp #pass/constant_propagation.cpp pass/convert_to_opencl.cpp + pass/dump_cfg.cpp pass/dump_ir.cpp pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp new file mode 100644 index 00000000..a4bb1b2f --- /dev/null +++ b/src/analysis/cfg.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/cfg.hpp" +#include "node/inst_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" + +namespace tinytc { + +auto control_flow_graph::node_queue() const -> std::queue { + auto q = std::queue{}; + for (auto &[key, neighbors] : adj_) { + q.push(key); + } + return q; +} + +auto get_control_flow_graph(region_node &topreg) -> control_flow_graph { + auto cfg = control_flow_graph{}; + + const auto add_region = [&cfg](region_node ®, + auto &add_region_ref) -> std::pair { + if (reg.empty()) { + return {}; + } + + auto start = reg.begin()->get(); + cfg.add_node(start); + + auto pred_nodes = std::queue{}; + pred_nodes.push(start); + + for (auto it = reg.begin() + 1; it != reg.end(); ++it) { + inst_node *node = it->get(); + cfg.add_node(node); + + for (; !pred_nodes.empty(); pred_nodes.pop()) { + cfg.add_edge(pred_nodes.front(), node); + } + + if ((*it)->num_child_regions() > 0) { + for (auto &subreg : (*it)->child_regions()) { + auto [substart, subexit] = add_region_ref(*subreg, add_region_ref); + cfg.add_edge(node, substart); + if (isa(**it)) { + cfg.add_edge(subexit, node); + pred_nodes.push(node); + } else { + pred_nodes.push(subexit); + } + } + } else { + pred_nodes.push(node); + } + } + + // every region must have exactly one exit node and the exit node must be last + // @todo: NOT guaranteed for parallel_inst and function yet! + if (pred_nodes.size() != 1) { + throw internal_compiler_error{}; + } + + return std::make_pair(std::move(start), std::move(pred_nodes.front())); + }; + + add_region(topreg, add_region); + + return cfg; +} + +} // namespace tinytc diff --git a/src/analysis/cfg.hpp b/src/analysis/cfg.hpp new file mode 100644 index 00000000..1ac7221b --- /dev/null +++ b/src/analysis/cfg.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CFG_20240919_HPP +#define CFG_20240919_HPP + +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "support/util.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +class control_flow_graph { + public: + inline void add_node(inst_node *a) { adj_[a] = adjacency_list{}; } + inline void add_edge(inst_node *a, inst_node *b) { + adj_[a].succ.push_back(b); + adj_[b].pred.push_back(b); + } + + auto node_queue() const -> std::queue; + + inline auto pred_begin(inst_node *a) { return adj_[a].pred.begin(); } + inline auto pred_end(inst_node *a) { return adj_[a].pred.end(); } + inline auto + predecessors(inst_node *a) -> iterator_range_wrapper::iterator> { + return {pred_begin(a), pred_end(a)}; + } + + inline auto succ_begin(inst_node *a) { return adj_[a].succ.begin(); } + inline auto succ_end(inst_node *a) { return adj_[a].succ.end(); } + inline auto + successors(inst_node *a) -> iterator_range_wrapper::iterator> { + return {succ_begin(a), succ_end(a)}; + } + + private: + struct adjacency_list { + std::vector pred; + std::vector succ; + }; + std::unordered_map adj_; +}; + +auto get_control_flow_graph(region_node ®) -> control_flow_graph; + +} // namespace tinytc + +#endif // CFG_20240919_HPP diff --git a/src/compiler.cpp b/src/compiler.cpp index 08d4a496..74cea869 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -7,6 +7,7 @@ #include "parser.hpp" #include "pass/check_ir.hpp" #include "pass/convert_to_opencl.hpp" +#include "pass/dump_cfg.hpp" #include "pass/dump_ir.hpp" #include "pass/insert_barrier.hpp" #include "pass/insert_lifetime_stop.hpp" diff --git a/src/inst.cpp b/src/inst.cpp index 33e2c27a..777433e4 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -459,7 +459,7 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, tinytc_value_t *yield_list, const tinytc_location_t *loc) { - if (instr == nullptr || yield_list_size == 0 || yield_list == nullptr) { + if (instr == nullptr || (yield_list_size != 0 && yield_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 03afa140..1b989fba 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -26,6 +26,7 @@ struct tinytc_data_type : tinytc::reference_counted { using leaves = tinytc::data_type_nodes; inline tinytc_data_type(tinytc::DTK tid) : tid_(tid) {} + virtual ~tinytc_data_type() = default; inline auto type_id() const -> tinytc::DTK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index cec8bdc7..85408955 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -21,7 +21,7 @@ using value_range = iterator_range_wrapper; using const_value_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_func : tinytc::reference_counted { +struct tinytc_func final : tinytc::reference_counted { public: inline tinytc_func(std::string name, std::vector args, tinytc::region body, tinytc::location const &lc = {}) diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 9c97e1d2..8965a224 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -52,14 +52,14 @@ blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, va op(op_C) = std::move(C); } -loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step0, region body, +loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step0, region body0, location const &lc) : standard_inst{tid, step0 ? 4 : 3} { op(op_loop_var) = std::move(loop_var0); op(op_from) = std::move(from0); op(op_to) = std::move(to0); op(op_step) = std::move(step0); - child_region(0) = std::move(body); + child_region(0) = std::move(body0); loc(lc); auto lvt = get_scalar_type(loc(), loop_var()); @@ -74,6 +74,11 @@ loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step if (lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || !step_ok) { throw compilation_error(loc(), status::ir_scalar_mismatch); } + + region_node &body = *child_region(0); + if (body.empty() || !isa(**(body.end() - 1))) { + body.insert(body.end(), make_yield({}, lc)); + } } alloca_inst::alloca_inst(data_type ty, location const &lc) @@ -461,16 +466,23 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu } } -if_inst::if_inst(value condition, region then, region otherwise, +if_inst::if_inst(value condition, region then0, region otherwise0, std::vector const &return_types, location const &lc) - : standard_inst{IK::if_, 1, static_cast(return_types.size()), otherwise ? 2 : 1} { + : standard_inst{IK::if_, 1, static_cast(return_types.size()), otherwise0 ? 2 : 1} { op(0) = std::move(condition); - child_region(child_region_then) = std::move(then); - child_region(child_region_otherwise) = std::move(otherwise); + child_region(child_region_then) = std::move(then0); + child_region(child_region_otherwise) = std::move(otherwise0); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { result(i) = make_value(return_types[i]); } + + for (std::int64_t i = 0; i < num_child_regions(); ++i) { + region_node &body = *child_region(i); + if (body.empty() || !isa(**(body.end() - 1))) { + body.insert(body.end(), make_yield({}, lc)); + } + } } size_inst::size_inst(value op0, std::int64_t mode, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index b10319b9..7fbf9a90 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -91,6 +91,8 @@ struct tinytc_inst : tinytc::reference_counted { using leaves = tinytc::inst_nodes; inline tinytc_inst(tinytc::IK tid) : tid_(tid) {} + virtual ~tinytc_inst() = default; + inline auto type_id() const -> tinytc::IK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index c7462db0..42c93bd6 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -18,7 +18,7 @@ using func_range = iterator_range_wrapper; using const_func_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_prog : tinytc::reference_counted { +struct tinytc_prog final : tinytc::reference_counted { public: inline tinytc_prog(std::vector funcs, tinytc::location const &lc = {}) : funcs_(std::move(funcs)) { diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 15995958..f51a139d 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -19,7 +19,7 @@ enum class region_kind { mixed, collective, spmd }; } // namespace tinytc -struct tinytc_region : tinytc::reference_counted { +struct tinytc_region final : tinytc::reference_counted { public: using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 6e3cf561..2dc33ea0 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -22,7 +22,7 @@ struct tinytc_value : tinytc::reference_counted { using leaves = tinytc::value_nodes; inline tinytc_value(tinytc::VK tid) : tid_(tid) {} - inline virtual ~tinytc_value() {} + virtual ~tinytc_value() = default; inline auto type_id() const -> tinytc::VK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 5ac82a31..93e4abf1 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -461,6 +461,7 @@ identifier_or_constant: optional_identifier_or_constant_list: %empty {} | identifier_or_constant_list { $$ = std::move($1); } +; identifier_or_constant_list: identifier_or_constant { $$.push_back(std::move($identifier_or_constant)); } @@ -629,23 +630,15 @@ for_loop_var_type: var_definition: identifier_list EQUALS valued_inst { $$ = std::move($valued_inst); - if ($identifier_list.size() == 1) { - if (!$$->result()) { - throw syntax_error(@identifier_list, "Instruction does not return value"); - } - $$->result()->name($identifier_list[0]); - ctx.val($identifier_list[0], $$->result(), @identifier_list); - } else { - auto results = $$->result_begin(); - if ($$->num_results() != static_cast($identifier_list.size())) { - throw syntax_error( - @identifier_list, - "Number of identifiers does not equal number of returned values"); - } - for (std::int64_t i = 0; i < $$->num_results(); ++i) { - results[i]->name($identifier_list[i]); - ctx.val($identifier_list[i], results[i], @identifier_list); - } + if (static_cast($identifier_list.size()) != $$->num_results()) { + throw syntax_error( + @identifier_list, + "Number of identifiers does not equal number of returned values"); + } + auto results = $$->result_begin(); + for (std::int64_t i = 0; i < $$->num_results(); ++i) { + results[i]->name($identifier_list[i]); + ctx.val($identifier_list[i], results[i], @identifier_list); } } ; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 143663d4..05e366fd 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -775,6 +775,7 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { std::vector convert_to_opencl_pass::operator()(for_inst const &p) { auto clinst = std::vector{}; + yielded_vars_.push_back(std::vector{}); auto lv = declare(*p.loop_var()); auto lv_ty = visit(*this, *p.loop_var()->ty()); @@ -785,6 +786,7 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); + yielded_vars_.pop_back(); return clinst; } diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp new file mode 100644 index 00000000..269629cd --- /dev/null +++ b/src/pass/dump_cfg.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_cfg.hpp" +#include "analysis/cfg.hpp" +#include "pass/dump_ir.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc { + +dump_cfg_pass::dump_cfg_pass(std::ostream &os) : os_(&os) {} + +void dump_cfg_pass::run_on_function(function_node const &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + + *os_ << "digraph " << fn.name() << " {" << std::endl; + + auto cfg = get_control_flow_graph(*fn.body()); + auto q = cfg.node_queue(); + for (; !q.empty(); q.pop()) { + auto &node = q.front(); + + *os_ << reinterpret_cast(node) << " [label=\""; + dump_ir.run_on_instruction(*node); + *os_ << "\"]" << std::endl; + + for (auto &neigh : cfg.successors(node)) { + *os_ << reinterpret_cast(node) << " -> " + << reinterpret_cast(neigh) << std::endl; + } + } + + *os_ << "}" << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_cfg.hpp b/src/pass/dump_cfg.hpp new file mode 100644 index 00000000..08cfeea3 --- /dev/null +++ b/src/pass/dump_cfg.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_BACKWARD_CFG_20240919_HPP +#define DUMP_BACKWARD_CFG_20240919_HPP + +#include "node/function_node.hpp" + +#include + +namespace tinytc { + +class dump_cfg_pass { + public: + dump_cfg_pass(std::ostream &os); + + void run_on_function(function_node const &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_BACKWARD_CFG_20240919_HPP diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 5abc268f..31c1d407 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -12,7 +12,7 @@ namespace tinytc { -dump_ir_pass::dump_ir_pass(std::ostream &os) : os_(&os) {} +dump_ir_pass::dump_ir_pass(std::ostream &os, int level_limit) : os_(&os), lvl_limit_(level_limit) {} /* Data type nodes */ void dump_ir_pass::operator()(void_data_type const &) { *os_ << "void"; } @@ -353,22 +353,31 @@ void dump_ir_pass::operator()(sum_inst const &a) { void dump_ir_pass::operator()(yield_inst const &y) { *os_ << "yield "; - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }); - *os_ << " : "; - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }); + if (y.num_operands() > 0) { + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }, ", "); + *os_ << " : "; + do_with_infix( + y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }, ", "); + } else { + *os_ << ":"; + } } void dump_ir_pass::dump_region(region_node const ®) { - *os_ << "{" << std::endl; - ++lvl_; - auto ind = indent(); - for (auto const &i : reg) { - *os_ << ind; - visit(*this, *i); - *os_ << std::endl; + if (lvl_ < lvl_limit_) { + *os_ << "{" << std::endl; + ++lvl_; + auto ind = indent(); + for (auto const &i : reg) { + *os_ << ind; + visit(*this, *i); + *os_ << std::endl; + } + --lvl_; + *os_ << indent() << "}"; + } else { + *os_ << "{...}"; } - --lvl_; - *os_ << indent() << "}"; } void dump_ir_pass::run_on_function(function_node const &fn) { @@ -396,4 +405,7 @@ void dump_ir_pass::run_on_function(function_node const &fn) { *os_ << std::endl; } +void dump_ir_pass::run_on_region(region_node const ®) { dump_region(reg); } +void dump_ir_pass::run_on_instruction(inst_node const &in) { visit(*this, in); } + } // namespace tinytc diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index bd2f99e4..28c5f99c 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -12,6 +12,7 @@ #include "node/value_node.hpp" #include "pass/slot_tracker.hpp" +#include #include #include @@ -19,7 +20,7 @@ namespace tinytc { class dump_ir_pass { public: - dump_ir_pass(std::ostream &os); + dump_ir_pass(std::ostream &os, int level_limit = std::numeric_limits::max()); /* Data type nodes */ void operator()(void_data_type const &); @@ -65,6 +66,8 @@ class dump_ir_pass { void operator()(yield_inst const &y); void run_on_function(function_node const &fn); + void run_on_region(region_node const ®); + void run_on_instruction(inst_node const &in); private: void dump_region(region_node const ®); @@ -82,6 +85,7 @@ class dump_ir_pass { } inline auto indent() { return std::string(2 * lvl_, ' '); } std::ostream *os_; + int lvl_limit_; int lvl_ = 0; slot_tracker tracker_; diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 292fd6f9..945bdd7a 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -3,6 +3,7 @@ #include "pass/insert_barrier.hpp" #include "analysis/alias.hpp" +#include "analysis/cfg.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" #include "node/value_node.hpp" @@ -36,11 +37,20 @@ void insert_barrier_pass::reads_writes::clear(address_space as) { writes[space].clear(); } +void insert_barrier_pass::reads_writes::merge(reads_writes const &other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].insert(other.reads[i].begin(), other.reads[i].end()); + } + for (std::size_t i = 0; i < writes.size(); ++i) { + writes[i].insert(other.writes[i].begin(), other.writes[i].end()); + } +} + void insert_barrier_pass::reads_writes::merge(reads_writes &&other) { for (std::size_t i = 0; i < reads.size(); ++i) { reads[i].merge(std::move(other.reads[i])); } - for (std::size_t i = 0; i < reads.size(); ++i) { + for (std::size_t i = 0; i < writes.size(); ++i) { writes[i].merge(std::move(other.writes[i])); } } @@ -53,28 +63,37 @@ void insert_barrier_pass::reads_writes::emplace_write(address_space as, ::tinytc const auto space = address_space_to_index(as); writes[space].emplace(val); } +auto insert_barrier_pass::reads_writes::read_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return reads[space].size(); +} +auto insert_barrier_pass::reads_writes::write_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return writes[space].size(); +} bool insert_barrier_pass::reads_writes::raw(address_space as, reads_writes const &rw, - aa_results const &aa) { + aa_results const &aa) const { const auto space = address_space_to_index(as); return intersects(rw.reads[space], writes[space], aa); } bool insert_barrier_pass::reads_writes::war(address_space as, reads_writes const &rw, - aa_results const &aa) { + aa_results const &aa) const { const auto space = address_space_to_index(as); return intersects(rw.writes[space], reads[space], aa); } bool insert_barrier_pass::reads_writes::waw(address_space as, reads_writes const &rw, - aa_results const &aa) { + aa_results const &aa) const { const auto space = address_space_to_index(as); return intersects(rw.writes[space], writes[space], aa); } bool insert_barrier_pass::reads_writes::raw_war_or_waw(address_space as, reads_writes const &rw, - aa_results const &aa) { + aa_results const &aa) const { return raw(as, rw, aa) || war(as, rw, aa) || waw(as, rw, aa); } -auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) -> std::size_t { +auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) const + -> std::size_t { for (std::size_t i = 0; i < address_spaces.size(); ++i) { if (as == address_spaces[i]) { return i; @@ -85,6 +104,70 @@ auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, const bool insert_barriers) -> reads_writes { + // irw = reads and writes invisible to other threads + auto irw_in = std::unordered_map{}; + auto irw_out = std::unordered_map{}; + + auto const get_rw = [](inst_node &in) -> reads_writes { + auto rw = reads_writes{}; + auto const emplace_read = [&rw](value const &v) { + if (auto *m = dyn_cast(v->ty().get()); m) { + rw.emplace_read(m->addrspace(), v.get()); + } + }; + auto const emplace_write = [&rw](value const &v) { + if (auto *m = dyn_cast(v->ty().get()); m) { + rw.emplace_write(m->addrspace(), v.get()); + } + }; + + visit(overloaded{[&](blas_a2_inst &in) { + emplace_read(in.A()); + emplace_write(in.B()); + }, + [&](blas_a3_inst &in) { + emplace_read(in.A()); + emplace_read(in.B()); + emplace_write(in.C()); + }, + [&](load_inst &in) { emplace_read(in.operand()); }, + [&](store_inst &in) { emplace_write(in.operand()); }, [](inst_node &) {}}, + in); + return rw; + }; + + auto const get_cardinal = [](reads_writes const &rw) { + return std::array{ + rw.read_cardinal(address_space::global), rw.read_cardinal(address_space::local), + rw.write_cardinal(address_space::global), rw.write_cardinal(address_space::local)}; + }; + + auto cfg = get_control_flow_graph(reg); + auto q = cfg.node_queue(); + while (!q.empty()) { + auto n = q.front(); + q.pop(); + + auto &in = irw_in[n]; + auto &out = irw_out[n]; + for (auto &p : cfg.predecessors(n)) { + in.merge(irw_out[p]); + } + + auto out_size_before_update = get_cardinal(out); + out = get_rw(*n); + out.merge(in); + // out has changed, need to enqueue successors again + if (out_size_before_update != get_cardinal(out)) { + for (auto &s : cfg.successors(n)) { + q.push(s); + } + } + } +} + +/*auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, + const bool insert_barriers) -> reads_writes { auto invisible_rw = reads_writes{}; for (auto it = reg.begin(); it != reg.end(); ++it) { if (auto *barrier = dyn_cast(it->get()); insert_barriers && barrier) { @@ -95,6 +178,7 @@ auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, } } else { auto rw = reads_writes{}; + for (auto &subreg : (*it)->child_regions()) { const bool insert_barriers_sub = insert_barriers && subreg->kind() != region_kind::spmd; @@ -145,7 +229,7 @@ auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, } return invisible_rw; -} +}*/ /* Function nodes */ void insert_barrier_pass::run_on_function(function_node &fn) { diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index 2f74b41c..3b028be0 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -24,17 +24,20 @@ class insert_barrier_pass { address_space::local}; void clear(address_space as); + void merge(reads_writes const &other); void merge(reads_writes &&other); void emplace_read(address_space as, ::tinytc_value const *val); void emplace_write(address_space as, ::tinytc_value const *val); + auto read_cardinal(address_space as) const -> std::size_t; + auto write_cardinal(address_space as) const -> std::size_t; - bool raw(address_space as, reads_writes const &rw, aa_results const &aa); - bool war(address_space as, reads_writes const &rw, aa_results const &aa); - bool waw(address_space as, reads_writes const &rw, aa_results const &aa); - bool raw_war_or_waw(address_space as, reads_writes const &rw, aa_results const &aa); + bool raw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool war(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool waw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool raw_war_or_waw(address_space as, reads_writes const &rw, aa_results const &aa) const; private: - auto address_space_to_index(address_space as) -> std::size_t; + auto address_space_to_index(address_space as) const -> std::size_t; std::array, address_spaces.size()> reads, writes; }; diff --git a/src/passes.def b/src/passes.def index 2988565a..b9251672 100644 --- a/src/passes.def +++ b/src/passes.def @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause FUNCTION_PASS("check-ir", check_ir_pass{}) +FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) diff --git a/src/support/util.hpp b/src/support/util.hpp index c9a4b275..2f26cd13 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -9,11 +9,6 @@ namespace tinytc { -template struct overloaded : Ts... { - using Ts::operator()...; -}; -template overloaded(Ts...) -> overloaded; - template auto enum_cast(V val) { return T{std::underlying_type_t(val)}; } diff --git a/src/support/visit.hpp b/src/support/visit.hpp index a49b2042..d876769f 100644 --- a/src/support/visit.hpp +++ b/src/support/visit.hpp @@ -13,6 +13,11 @@ namespace tinytc { +template struct overloaded : Ts... { + using Ts::operator()...; +}; +template overloaded(Ts...) -> overloaded; + namespace detail { /** * @brief Computes \prod_{i=0}^{MaxMode-1} Size_i, where Size_0 = Head, and Size_i = Tail_i for i > diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index e8bf6903..263571e7 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -100,6 +100,29 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< ; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} } +func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { + %0 = cmp.gt %a, 42.0 : f32 + axpby.n %a, %B, %b, %A : f32, memref, f32, memref + if %0 { + axpby.n %a, %A, %b, %B : f32, memref, f32, memref + axpby.n %a, %B, %b, %C : f32, memref, f32, memref + } else { + axpby.n %a, %C, %b, %D : f32, memref, f32, memref + } + axpby.n %a, %A, %b, %B : f32, memref, f32, memref +; CHECK-LABEL: func @if({{.*}} +; CHECK: if %0 { +; CHECK-NEXT: barrier.global +; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} +; CHECK-NEXT: barrier.global +; CHECK-NEXT: axpby.n %a, %B, %b, %C{{.*}} +; CHECK-NEXT: } else { +; CHECK-NEXT: axpby.n %a, %C, %b, %D{{.*}} +; CHECK-NEXT: } +; CHECK-NEXT: barrier.global +; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} +} + func @region1() { %0 = alloca -> memref for %i=0,4 : index { diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index b2595e05..9f926eff 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -54,8 +54,10 @@ func @region1() { ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %2 ; CHECK-NEXT: axpby.n{{.*}} +; CHECK-NEXT: yield : ; CHECK-NEXT: } ; CHECK-NEXT: lifetime_stop %1 +; CHECK-NEXT: yield : ; CHECK-NEXT: } ; CHECK-NEXT: lifetime_stop %0 } diff --git a/tools/opt/args.cpp b/tools/opt/args.cpp index 8bcd6e8f..01b4bff2 100644 --- a/tools/opt/args.cpp +++ b/tools/opt/args.cpp @@ -76,7 +76,7 @@ args arg_parser::parse_args(int argc, char **argv) { a.info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); } - if (a.pass_names.empty() || std::strcmp(a.pass_names.back().c_str(), "dump-ir") != 0) { + if (a.pass_names.empty() || std::strncmp(a.pass_names.back().c_str(), "dump", 4) != 0) { a.pass_names.emplace_back(std::string("dump-ir")); } From 40d63c27366c166fb9012b9bfd72b7771a7142f0 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 23 Sep 2024 11:31:10 +0200 Subject: [PATCH 021/297] Make instruction unique instead of shared; region stores instruction in intrusive linked list Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 20 ++-- docs/api/builder_capi.yaml | 4 +- include/tinytc/tinytc.h | 37 +++--- include/tinytc/tinytc.hpp | 70 ++++++------ src/analysis/cfg.cpp | 12 +- src/inst.cpp | 19 +-- src/node/inst_node.cpp | 28 ++++- src/node/inst_node.hpp | 26 ++--- src/node/region_node.hpp | 21 ++-- src/parser/parser_impl.yy | 67 +++++------ src/pass/convert_to_opencl.cpp | 2 +- src/pass/dump_ir.cpp | 2 +- src/pass/insert_lifetime_stop.cpp | 17 +-- src/region.cpp | 23 ++-- src/support/ilist.hpp | 26 +++++ src/support/ilist_base.hpp | 184 ++++++++++++++++++++++++++++++ src/support/walk.cpp | 2 +- src/support/walk.hpp | 8 +- 18 files changed, 387 insertions(+), 181 deletions(-) create mode 100644 src/support/ilist.hpp create mode 100644 src/support/ilist_base.hpp diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 1a9fbe17..4dd4075c 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -387,9 +387,7 @@ Instruction * :ref:`tinytc_inst_get_values` - * :ref:`tinytc_inst_release` - - * :ref:`tinytc_inst_retain` + * :ref:`tinytc_inst_destroy` Instruction Functions --------------------- @@ -544,15 +542,10 @@ tinytc_inst_get_values .. doxygenfunction:: tinytc_inst_get_values -tinytc_inst_release +tinytc_inst_destroy ................... -.. doxygenfunction:: tinytc_inst_release - -tinytc_inst_retain -.................. - -.. doxygenfunction:: tinytc_inst_retain +.. doxygenfunction:: tinytc_inst_destroy Program ======= @@ -611,6 +604,8 @@ Region * :ref:`tinytc_region_create` + * :ref:`tinytc_region_add_instruction` + * :ref:`tinytc_region_release` * :ref:`tinytc_region_retain` @@ -623,6 +618,11 @@ tinytc_region_create .. doxygenfunction:: tinytc_region_create +tinytc_region_add_instruction +............................. + +.. doxygenfunction:: tinytc_region_add_instruction + tinytc_region_release ..................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index a33a1619..dd89d046 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -82,8 +82,7 @@ Builder C-API: - tinytc_yield_inst_create - tinytc_inst_get_value - tinytc_inst_get_values - - tinytc_inst_release - - tinytc_inst_retain + - tinytc_inst_destroy Program: function: - tinytc_program_create @@ -95,6 +94,7 @@ Builder C-API: Region: function: - tinytc_region_create + - tinytc_region_add_instruction - tinytc_region_release - tinytc_region_retain Value: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 1ae7e217..cd3149da 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -753,24 +753,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc); /** - * @brief Release inst object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param instr [inout] inst object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_release(tinytc_inst_t instr); - -/** - * @brief Increase reference count of inst object by 1 + * @brief Delete inst object * * @param instr [inout] inst object - * - * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_retain(tinytc_inst_t instr); +TINYTC_EXPORT void tinytc_inst_destroy(tinytc_inst_t instr); /** * @brief Get value produced by instruction @@ -808,17 +795,27 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, * @brief Create region * * @param reg [out] pointer to the region object created - * @param instruction_list_size [in] length of instruction array - * @param instruction_list [in][range(0, instruction_list_size)] instruction array; can be nullptr - * if instruction_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_region_create(tinytc_region_t *reg, - uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, const tinytc_location_t *loc); + +/** + * @brief Append instruction to region + * + * The region takes ownership of the instruction. + * An instruction must not be added to multiple regions. + * + * @param reg [inout] region object + * @param instruction [in,pass_ownership] instruction + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, + tinytc_inst_t instruction); + /** * @brief Release region object * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 506a712a..11973c3a 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -638,20 +638,15 @@ inline char const *to_string(transpose t) { } namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_retain(handle); - } - static auto release(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_release(handle); - } +template <> struct unique_handle_traits { + static void destroy(tinytc_inst_t handle) { return tinytc_inst_destroy(handle); } }; } // namespace internal //! @brief Reference-counting wrapper for tinytc_inst_t -class inst : public shared_handle { +class inst : public unique_handle { public: - using shared_handle::shared_handle; + using unique_handle::unique_handle; /** * @brief Get result value @@ -705,26 +700,27 @@ template <> struct shared_handle_traits { class region : public shared_handle { public: using shared_handle::shared_handle; + + /** + * @brief Append instruction to region + * + * @param instruction instruction; region takes ownership + */ + inline void add_instruction(inst instruction) { + CHECK_STATUS(tinytc_region_add_instruction(get(), instruction.release())); + } }; /** * @brief Make region * - * @param instructions Vector of instructions * @param loc Source code location * * @return Region */ -inline region make_region(std::vector &instructions, location const &loc = {}) { +inline region make_region(location const &loc = {}) { tinytc_region_t reg; - static_assert(internal::inst_reinterpret_allowed); - if (instructions.size() > std::numeric_limits::max()) { - throw std::out_of_range("Instruction list too long"); - } - CHECK_STATUS_LOC(tinytc_region_create(®, instructions.size(), - reinterpret_cast(instructions.data()), - &loc), - loc); + CHECK_STATUS_LOC(tinytc_region_create(®, &loc), loc); return region{reg}; } @@ -1416,15 +1412,18 @@ inline prog make_program(std::vector &fun_list, location const &loc = {}) class region_builder { public: /** - * @brief Returns built product + * @brief ctor * * @param loc Source code location + */ + region_builder(location const &loc = {}) : reg_{make_region(loc)} {} + + /** + * @brief Returns built product * * @return Region */ - inline auto get_product(location const &loc = {}) -> region { - return make_region(instructions_, loc); - } + inline auto get_product() && -> region { return std::move(reg_); } /** * @brief Add instruction @@ -1439,7 +1438,7 @@ class region_builder { if (result && name.size() > 0) { result.name(name); } - instructions_.emplace_back(std::move(i)); + reg_.add_instruction(std::move(i)); return result; } @@ -1460,7 +1459,7 @@ class region_builder { result.name(name + std::to_string(counter++)); } } - instructions_.emplace_back(std::move(i)); + reg_.add_instruction(std::move(i)); return results; } @@ -1506,7 +1505,8 @@ class region_builder { } auto bb = region_builder{}; f(bb, loop_var); - add(::tinytc::make_for(std::move(loop_var), from, to, step, bb.get_product(), loc)); + add(::tinytc::make_for(std::move(loop_var), from, to, step, std::move(bb).get_product(), + loc)); } /** * @brief Build foreach-loop with functor f(region_builder&) -> void @@ -1528,7 +1528,8 @@ class region_builder { } auto bb = region_builder{}; f(bb); - add(::tinytc::make_foreach(std::move(loop_var), from, to, bb.get_product(), loc)); + add(::tinytc::make_foreach(std::move(loop_var), from, to, std::move(bb).get_product(), + loc)); } /** @@ -1548,8 +1549,8 @@ class region_builder { location const &loc = {}) -> std::vector { auto bb = region_builder{}; then(bb); - return add_multivalued(::tinytc::make_if(std::move(condition), bb.get_product(), region{}, - return_type_list, loc)); + return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb).get_product(), + region{}, return_type_list, loc)); } /** * @brief Build if/else with functors then(region_builder&) -> void and @@ -1573,12 +1574,13 @@ class region_builder { then(bb1); auto bb2 = region_builder{}; otherwise(bb2); - return add_multivalued(::tinytc::make_if(std::move(condition), bb1.get_product(), - bb2.get_product(), return_type_list, loc)); + return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb1).get_product(), + std::move(bb2).get_product(), return_type_list, + loc)); } private: - std::vector instructions_; + region reg_; }; //! Builder for functions @@ -1653,9 +1655,9 @@ class function_builder { * @param loc Source code location */ template void body(F &&f, location const &loc = {}) { - auto bb = region_builder{}; + auto bb = region_builder{loc}; f(bb); - body_ = bb.get_product(loc); + body_ = std::move(bb).get_product(); } private: diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp index a4bb1b2f..ae9bde46 100644 --- a/src/analysis/cfg.cpp +++ b/src/analysis/cfg.cpp @@ -25,25 +25,25 @@ auto get_control_flow_graph(region_node &topreg) -> control_flow_graph { return {}; } - auto start = reg.begin()->get(); + auto start = reg.begin().get(); cfg.add_node(start); auto pred_nodes = std::queue{}; pred_nodes.push(start); - for (auto it = reg.begin() + 1; it != reg.end(); ++it) { - inst_node *node = it->get(); + for (auto it = ++reg.begin(); it != reg.end(); ++it) { + inst_node *node = it.get(); cfg.add_node(node); for (; !pred_nodes.empty(); pred_nodes.pop()) { cfg.add_edge(pred_nodes.front(), node); } - if ((*it)->num_child_regions() > 0) { - for (auto &subreg : (*it)->child_regions()) { + if (it->num_child_regions() > 0) { + for (auto &subreg : it->child_regions()) { auto [substart, subexit] = add_region_ref(*subreg, add_region_ref); cfg.add_edge(node, substart); - if (isa(**it)) { + if (isa(*it)) { cfg.add_edge(subexit, node); pred_nodes.push(node); } else { diff --git a/src/inst.cpp b/src/inst.cpp index 777433e4..4e7c0829 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -472,24 +472,7 @@ tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_li }); } -tinytc_status_t tinytc_inst_release(tinytc_inst_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_inst_retain(tinytc_inst_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} +void tinytc_inst_destroy(tinytc_inst_t obj) { delete obj; } tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, tinytc_value_t *result) { if (instr == nullptr || result == nullptr) { diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 8965a224..dd8659a2 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -4,6 +4,7 @@ #include "node/inst_node.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/region_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" @@ -76,8 +77,8 @@ loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step } region_node &body = *child_region(0); - if (body.empty() || !isa(**(body.end() - 1))) { - body.insert(body.end(), make_yield({}, lc)); + if (body.empty() || !isa(*(--body.end()))) { + body.insert(body.end(), std::make_unique(std::vector{}, lc).release()); } } @@ -441,6 +442,17 @@ ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool } } +foreach_inst::foreach_inst(value loop_var, value from, value to, region body, location const &loc) + : loop_inst{IK::foreach_loop, + std::move(loop_var), + std::move(from), + std::move(to), + {}, + std::move(body), + loc} { + child_region(0)->kind(region_kind::spmd); +} + hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, location const &lc) : blas_a3_inst(IK::hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), @@ -479,12 +491,20 @@ if_inst::if_inst(value condition, region then0, region otherwise0, for (std::int64_t i = 0; i < num_child_regions(); ++i) { region_node &body = *child_region(i); - if (body.empty() || !isa(**(body.end() - 1))) { - body.insert(body.end(), make_yield({}, lc)); + if (body.empty() || !isa(*(--body.end()))) { + body.insert(body.end(), + std::make_unique(std::vector{}, lc).release()); } } } +parallel_inst::parallel_inst(region body, location const &lc) : standard_inst{IK::parallel} { + child_region(0) = std::move(body); + loc(lc); + + child_region(0)->kind(region_kind::spmd); +} + size_inst::size_inst(value op0, std::int64_t mode, location const &lc) : standard_inst{IK::size}, mode_(mode) { op(0) = std::move(op0); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 7fbf9a90..90a510f2 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -5,8 +5,8 @@ #define INST_NODE_20230327_HPP #include "error.hpp" -#include "node/region_node.hpp" #include "reference_counted.hpp" +#include "support/ilist.hpp" #include "support/type_list.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" @@ -86,7 +86,7 @@ using const_region_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_inst : tinytc::reference_counted { +struct tinytc_inst : tinytc::ilist_node, tinytc::reference_counted { public: using leaves = tinytc::inst_nodes; @@ -227,6 +227,11 @@ namespace tinytc { using inst_node = ::tinytc_inst; +template <> struct ilist_traits { + static void on_insert(inst_node *) {} + static void on_erase(inst_node *node) { tinytc_inst_destroy(node); } +}; + template class object_container { public: object_container(std::int64_t num_objects) { @@ -543,16 +548,7 @@ class for_inst : public loop_inst { class foreach_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } - inline foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : loop_inst{IK::foreach_loop, - std::move(loop_var), - std::move(from), - std::move(to), - {}, - std::move(body), - loc} { - child_region(0)->kind(region_kind::spmd); - } + foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}); }; class hadamard_inst : public blas_a3_inst { @@ -585,12 +581,8 @@ class num_subgroups_inst : public standard_inst<0, 1> { class parallel_inst : public standard_inst<0, 0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } - inline parallel_inst(region body, location const &lc = {}) : standard_inst{IK::parallel} { - child_region(0) = std::move(body); - loc(lc); + parallel_inst(region body, location const &lc = {}); - child_region(0)->kind(region_kind::spmd); - } inline auto body() const -> region const & { return child_region(0); } }; diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index f51a139d..1ae7a14d 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -4,7 +4,9 @@ #ifndef REGION_NODE_20230908_HPP #define REGION_NODE_20230908_HPP +#include "node/inst_node.hpp" #include "reference_counted.hpp" +#include "support/ilist.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" @@ -21,11 +23,10 @@ enum class region_kind { mixed, collective, spmd }; struct tinytc_region final : tinytc::reference_counted { public: - using iterator = std::vector::iterator; - using const_iterator = std::vector::const_iterator; + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; - inline tinytc_region(std::vector insts = {}, tinytc::location const &lc = {}) - : insts_(std::move(insts)), kind_(tinytc::region_kind::mixed) { + inline tinytc_region(tinytc::location const &lc = {}) : kind_(tinytc::region_kind::mixed) { loc(lc); } @@ -43,19 +44,17 @@ struct tinytc_region final : tinytc::reference_counted { inline auto insts() const -> tinytc::iterator_range_wrapper { return {begin(), end()}; } - inline void insts(std::vector insts) { insts_ = std::move(insts); } + inline void push_back(tinytc_inst_t i) { insts_.push_back(i); } inline auto erase(iterator pos) -> iterator { return insts_.erase(pos); } - inline auto insert(iterator pos, tinytc::inst const &i) -> iterator { - return insts_.insert(pos, i); - } - inline auto insert(iterator pos, tinytc::inst &&i) -> iterator { - return insts_.insert(pos, std::move(i)); + inline auto insert(iterator pos, tinytc_inst_t i) -> iterator { return insts_.insert(pos, i); } + inline auto insert_after(iterator pos, tinytc_inst_t i) -> iterator { + return insts_.insert_after(pos, i); } inline auto empty() const -> bool { return insts_.empty(); } private: - std::vector insts_; tinytc::region_kind kind_; + tinytc::ilist insts_; tinytc::location loc_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 93e4abf1..159519c3 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -392,7 +392,10 @@ region: LBRACE { ctx.push_scope(); } instructions RBRACE { - $$ = region{std::make_unique(std::move($instructions), @region).release()}; + $$ = region{std::make_unique(@region).release()}; + for (auto& i : $instructions) { + $$.add_instruction(std::move(i)); + } ctx.pop_scope(); } ; @@ -410,20 +413,20 @@ instructions: ; instruction: - axpby_inst - | barrier_inst - | gemm_inst - | gemv_inst - | ger_inst - | for_inst - | foreach_inst - | hadamard_inst - | if_inst - | parallel_inst - | var_definition - | store_inst - | sum_inst - | yield_inst + axpby_inst { $$ = std::move($1); } + | barrier_inst { $$ = std::move($1); } + | gemm_inst { $$ = std::move($1); } + | gemv_inst { $$ = std::move($1); } + | ger_inst { $$ = std::move($1); } + | for_inst { $$ = std::move($1); } + | foreach_inst { $$ = std::move($1); } + | hadamard_inst { $$ = std::move($1); } + | if_inst { $$ = std::move($1); } + | parallel_inst { $$ = std::move($1); } + | var_definition { $$ = std::move($1); } + | store_inst { $$ = std::move($1); } + | sum_inst { $$ = std::move($1); } + | yield_inst { $$ = std::move($1); } ; axpby_inst: @@ -709,23 +712,23 @@ yield_inst: ; valued_inst: - alloca_inst - | arith_inst - | arith_unary_inst - | cast_inst - | compare_inst - | expand_inst - | fuse_inst - | group_id_inst - | group_size_inst - | if_inst - | load_inst - | num_subgroups_inst - | size_inst - | subgroup_id_inst - | subgroup_local_id_inst - | subgroup_size_inst - | subview_inst + alloca_inst { $$ = std::move($1); } + | arith_inst { $$ = std::move($1); } + | arith_unary_inst { $$ = std::move($1); } + | cast_inst { $$ = std::move($1); } + | compare_inst { $$ = std::move($1); } + | expand_inst { $$ = std::move($1); } + | fuse_inst { $$ = std::move($1); } + | group_id_inst { $$ = std::move($1); } + | group_size_inst { $$ = std::move($1); } + | if_inst { $$ = std::move($1); } + | load_inst { $$ = std::move($1); } + | num_subgroups_inst { $$ = std::move($1); } + | size_inst { $$ = std::move($1); } + | subgroup_id_inst { $$ = std::move($1); } + | subgroup_local_id_inst { $$ = std::move($1); } + | subgroup_size_inst { $$ = std::move($1); } + | subview_inst { $$ = std::move($1); } ; alloca_inst: diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 05e366fd..dd01a219 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1092,7 +1092,7 @@ clir::stmt convert_to_opencl_pass::run_on_region(region_node ®) { declared_vars_.push_back({}); auto bb = clir::block_builder{}; for (auto &s : reg.insts()) { - for (auto &cs : visit(*this, *s)) { + for (auto &cs : visit(*this, s)) { bb.add(cs); } } diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 31c1d407..e9edbe45 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -370,7 +370,7 @@ void dump_ir_pass::dump_region(region_node const ®) { auto ind = indent(); for (auto const &i : reg) { *os_ << ind; - visit(*this, *i); + visit(*this, i); *os_ << std::endl; } --lvl_; diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 281bcf25..494e374b 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -23,24 +23,24 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const auto allocas = std::vector{}; for (auto &i : reg) { - if (auto alloca = dyn_cast(i.get()); alloca != nullptr) { + if (auto alloca = dyn_cast(&i); alloca != nullptr) { allocas.emplace_back(alloca->result(0)); } } auto rgn_ops = std::unordered_set{}; auto prev_it = reg.end(); - for (; prev_it != reg.begin(); --prev_it) { - auto &i = *(prev_it - 1); - for (auto &subreg : i->child_regions()) { + while (prev_it != reg.begin()) { + auto &i = *(--prev_it); + for (auto &subreg : i.child_regions()) { rgn_ops.merge(run_on_region(*subreg, aa)); } - for (auto &v : i->operands()) { + for (auto &v : i.operands()) { if (isa(*v->ty())) { rgn_ops.insert(aa.root(*v)); } } - for (auto &v : i->results()) { + for (auto &v : i.results()) { if (isa(*v->ty())) { rgn_ops.insert(aa.root(*v)); } @@ -49,8 +49,9 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const auto alloca_it = allocas.begin(); while (alloca_it != allocas.end()) { if (rgn_ops.contains(alloca_it->get())) { - prev_it = reg.insert( - prev_it, inst{std::make_unique(*alloca_it).release()}); + prev_it = reg.insert_after( + prev_it, std::make_unique(*alloca_it).release()); + --prev_it; alloca_it = allocas.erase(alloca_it); } else { ++alloca_it; diff --git a/src/region.cpp b/src/region.cpp index 97aa2cd4..26869adc 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -17,20 +17,19 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_region_create(tinytc_region_t *reg, uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, - const tinytc_location_t *loc) { - if (reg == nullptr || (instruction_list_size > 0 && instruction_list == nullptr)) { +tinytc_status_t tinytc_region_create(tinytc_region_t *reg, const tinytc_location_t *loc) { + if (reg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - auto inst_vec = std::vector(); - inst_vec.reserve(instruction_list_size); - for (uint32_t i = 0; i < instruction_list_size; ++i) { - inst_vec.emplace_back(inst(instruction_list[i], true)); - } - *reg = std::make_unique(std::move(inst_vec), get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *reg = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction) { + if (reg == nullptr || instruction == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { reg->push_back(instruction); }); } tinytc_status_t tinytc_region_release(tinytc_region_t obj) { diff --git a/src/support/ilist.hpp b/src/support/ilist.hpp new file mode 100644 index 00000000..3406904f --- /dev/null +++ b/src/support/ilist.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ILIST_20240923_HPP +#define ILIST_20240923_HPP + +#include "support/ilist_base.hpp" + +namespace tinytc { + +template struct ilist_traits; + +template > +class ilist : public ilist_base { + public: + ilist() = default; + + ilist(ilist const &other) = delete; + ilist(ilist &&other) = default; + ilist &operator=(ilist const &other) = delete; + ilist &operator=(ilist &&other) = default; +}; + +} // namespace tinytc + +#endif // ILIST_20240923_HPP diff --git a/src/support/ilist_base.hpp b/src/support/ilist_base.hpp new file mode 100644 index 00000000..6dfd0806 --- /dev/null +++ b/src/support/ilist_base.hpp @@ -0,0 +1,184 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ILIST_BASE_20240923_HPP +#define ILIST_BASE_20240923_HPP + +#include +#include +#include +#include + +namespace tinytc { + +template class ilist_node { + public: + auto prev() const -> T * { return prev_; } + void prev(T *prev) { prev_ = prev; } + auto next() const -> T * { return next_; } + void next(T *next) { next_ = next; } + + auto sentinel() const -> bool { return sentinel_; } + void set_sentinel() { sentinel_ = true; } + + private: + T *prev_ = nullptr, *next_ = nullptr; + bool sentinel_ = false; +}; + +template class ilist_iterator { + public: + using base_type = std::conditional_t, ilist_node>; + using base_pointer = base_type *; + using value_type = std::conditional_t; + using pointer = value_type *; + using reference = value_type &; + using difference_type = std::ptrdiff_t; + + ilist_iterator() : pos_{nullptr} {} + ilist_iterator(base_pointer pos) : pos_{std::move(pos)} {} + + auto operator*() const -> reference { return *static_cast(pos_); } + auto operator->() const -> pointer { return get(); } + auto get() const -> pointer { return static_cast(pos_); } + auto &operator++() { + pos_ = static_cast(pos_->next()); + return *this; + } + auto operator++(int) { + auto old_pos = pos_; + pos_ = static_cast(pos_->next()); + return ilist_iterator{old_pos}; + } + auto &operator--() { + pos_ = static_cast(pos_->prev()); + return *this; + } + auto operator--(int) { + auto old_pos = pos_; + pos_ = static_cast(pos_->prev()); + return ilist_iterator{old_pos}; + } + auto operator==(ilist_iterator const &other) const -> bool { return pos_ == other.pos_; } + auto operator!=(ilist_iterator const &other) const -> bool { return pos_ != other.pos_; } + + private: + base_pointer pos_; +}; + +template struct ilist_dummy_callback { + static void on_insert(NodeT *) {} + static void on_erase(NodeT *) {} +}; + +template > +requires requires(NodeT *node) { + std::is_base_of_v, NodeT>; + IListCallback::on_insert(node); + IListCallback::on_erase(node); +} +class ilist_base { + public: + using value_type = NodeT; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = value_type *; + using reference = value_type &; + using const_reference = const value_type &; + using iterator = ilist_iterator; + using const_iterator = ilist_iterator; + static_assert(std::bidirectional_iterator); + + ilist_base() { + sentinel_.set_sentinel(); + // let's go in a circle - yay! + sentinel_.prev(static_cast(&sentinel_)); + sentinel_.next(static_cast(&sentinel_)); + } + ~ilist_base() { clear(); } + + auto begin() -> iterator { return ++iterator{&sentinel_}; } + auto begin() const -> const_iterator { return cbegin(); } + auto cbegin() const -> const_iterator { return ++const_iterator{&sentinel_}; } + auto end() -> iterator { return iterator{&sentinel_}; } + auto end() const -> const_iterator { return cend(); } + auto cend() const -> const_iterator { return const_iterator{&sentinel_}; } + + auto empty() const -> bool { + return sentinel_.prev() == &sentinel_ && sentinel_.next() == &sentinel_; + } + auto size() const -> std::size_t { + std::size_t s = 0; + for (auto it = begin(); it != end(); ++it) { + ++s; + } + return s; + } + + void push_front(pointer node) { insert(begin(), node); } + void push_back(pointer node) { insert(end(), node); } + void pop_front() { erase(begin()); } + void pop_back() { erase(--end()); } + void clear() { erase(begin(), end()); } + + auto insert(iterator it, pointer node) -> iterator { + // let s = sentinel + // |0|: s{prev->s,next->s} + // |1|: n0{prev->s,next->s}, s{prev->n0,next->n0} + pointer prev = it->prev(); + prev->next(node); + node->prev(prev); + node->next(it.get()); + it->prev(node); + // |0| (it -> s) : node{prev->s,next->s}, s{prev->n0,next->n0} + // |1| (it -> n0): node{prev->s,next->n0}, n0{prev->node,next->s}, s{prev->n0,next->node} + // |1| (it -> s) : n0{prev->s,next->node}, node{prev->n0,next->s}, s{prev->node,next->n0} + IListCallback::on_insert(node); + return iterator{node}; + } + template auto insert(iterator it, ItT begin, ItT end) -> iterator { + if (begin != end) { + it = insert(it, *begin++); + auto first_it = it; + for (; begin != end; ++begin) { + it = insert(it, *begin); + ++it; // skip over just inserted value + } + return first_it; + } + return it; + } + auto insert_after(iterator it, pointer node) -> iterator { return insert(++it, node); } + + auto erase(iterator it) -> iterator { + // let s = sentinel + // |0|: s{prev->s,next->s} + // |1|: n0{prev->s,next->s}, s{prev->n0,next->n0} + // |2|: n0{prev->s,next->n1}, n1{prev->n0,next->s}, s{prev->n1,next->n0} + pointer prev = it->prev(); + pointer next = it->prev(); + prev->prev(next); + next->prev(prev); + it->prev(nullptr); + it->next(nullptr); + // |0| (it -> s) : s{prev->s,next->s} + // |1| (it -> n0): s{prev->s,next->s} + // |2| (it -> n0): n1{prev->s,next->s}, s{prev->n1,next->n1} + // |2| (it -> n1): n0{prev->s,next->s}, s{prev->n0,next->n0} + IListCallback::on_erase(it.get()); + return iterator{next}; + } + auto erase(iterator begin, iterator end) -> iterator { + while (begin != end) { + begin = erase(begin); + } + return begin; + } + + private: + ilist_node sentinel_; +}; + +} // namespace tinytc + +#endif // ILIST_BASE_20240923_HPP diff --git a/src/support/walk.cpp b/src/support/walk.cpp index b158f36c..a1a1762e 100644 --- a/src/support/walk.cpp +++ b/src/support/walk.cpp @@ -15,7 +15,7 @@ void walk(inst_node &i, std::function void walk(inst_node &i, std::function(*j, callback); + walk(j, callback); } } if constexpr (Order == walk_order::post_order) { @@ -49,7 +49,7 @@ template void walk(inst_node &i, std::function(*j, callback); + walk(j, callback); } if constexpr (Order == walk_order::post_order) { callback(reg); @@ -62,14 +62,14 @@ void walk(inst_node &i, std::function void walk(function_node &fn, std::function callback) { for (auto &i : *fn.body()) { - walk(*i, callback); + walk(i, callback); } } inline void walk(function_node &fn, std::function callback) { for (auto &i : *fn.body()) { - walk(*i, callback); + walk(i, callback); } } From bd9b07a20c89b9d5f5c95489994022df287ed48f Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 23 Sep 2024 15:19:25 +0200 Subject: [PATCH 022/297] Make region and function unique instead of shared Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 33 ++++----- docs/api/builder_capi.yaml | 7 +- include/tinytc/tinytc.h | 72 +++++++++---------- include/tinytc/tinytc.hpp | 111 ++++++++++++++---------------- src/func.cpp | 23 +------ src/inst.cpp | 18 +++-- src/node/function_node.hpp | 5 +- src/node/inst_node.hpp | 24 ++++--- src/node/program_node.hpp | 6 +- src/node/region_node.hpp | 10 ++- src/parser/parse_context.cpp | 16 ----- src/parser/parse_context.hpp | 4 -- src/parser/parser_impl.yy | 6 +- src/pass/convert_to_opencl.cpp | 20 +++--- src/pass/convert_to_opencl.hpp | 6 +- src/pass/dump_cfg.cpp | 2 +- src/pass/dump_ir.cpp | 14 ++-- src/pass/insert_barrier.cpp | 2 +- src/pass/insert_lifetime_stop.cpp | 2 +- src/pass/lower_linalg.cpp | 2 +- src/prog.cpp | 22 +++--- src/recipe/small_gemm_batched.cpp | 19 ++--- src/recipe/tall_and_skinny.cpp | 19 ++--- src/region.cpp | 19 +---- src/support/walk.hpp | 11 +-- 25 files changed, 205 insertions(+), 268 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 4dd4075c..920d105e 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -290,9 +290,7 @@ Function * :ref:`tinytc_function_set_work_group_size` - * :ref:`tinytc_func_release` - - * :ref:`tinytc_func_retain` + * :ref:`tinytc_func_destroy` Function Functions ------------------ @@ -312,15 +310,10 @@ tinytc_function_set_work_group_size .. doxygenfunction:: tinytc_function_set_work_group_size -tinytc_func_release +tinytc_func_destroy ................... -.. doxygenfunction:: tinytc_func_release - -tinytc_func_retain -.................. - -.. doxygenfunction:: tinytc_func_retain +.. doxygenfunction:: tinytc_func_destroy Instruction =========== @@ -554,6 +547,8 @@ Program * :ref:`tinytc_program_create` + * :ref:`tinytc_prog_add_function` + * :ref:`tinytc_prog_dump` * :ref:`tinytc_prog_print_to_file` @@ -572,6 +567,11 @@ tinytc_program_create .. doxygenfunction:: tinytc_program_create +tinytc_prog_add_function +........................ + +.. doxygenfunction:: tinytc_prog_add_function + tinytc_prog_dump ................ @@ -606,9 +606,7 @@ Region * :ref:`tinytc_region_add_instruction` - * :ref:`tinytc_region_release` - - * :ref:`tinytc_region_retain` + * :ref:`tinytc_region_destroy` Region Functions ---------------- @@ -623,15 +621,10 @@ tinytc_region_add_instruction .. doxygenfunction:: tinytc_region_add_instruction -tinytc_region_release +tinytc_region_destroy ..................... -.. doxygenfunction:: tinytc_region_release - -tinytc_region_retain -.................... - -.. doxygenfunction:: tinytc_region_retain +.. doxygenfunction:: tinytc_region_destroy Value ===== diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index dd89d046..e5b44117 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -48,8 +48,7 @@ Builder C-API: - tinytc_function_create - tinytc_function_set_subgroup_size - tinytc_function_set_work_group_size - - tinytc_func_release - - tinytc_func_retain + - tinytc_func_destroy Instruction: function: - tinytc_alloca_inst_create @@ -86,6 +85,7 @@ Builder C-API: Program: function: - tinytc_program_create + - tinytc_prog_add_function - tinytc_prog_dump - tinytc_prog_print_to_file - tinytc_prog_print_to_string @@ -95,8 +95,7 @@ Builder C-API: function: - tinytc_region_create - tinytc_region_add_instruction - - tinytc_region_release - - tinytc_region_retain + - tinytc_region_destroy Value: function: - tinytc_float_imm_create diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index cd3149da..458af722 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -524,12 +524,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *in /** * @brief Create parallel region * + * Takes ownership of region. + * * @code * parallel { %body } * @endcode * * @param instr [out] pointer to the inst object created - * @param body [in] loop body + * @param body [in,pass_ownership] loop body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -662,6 +664,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt /** * @brief Create for loop * + * Takes ownership of region. + * * @code * for %loop_var = %from, %to, %step : type(%loop_var) { %body } * ; type(%loop_var) == type(%from) @@ -674,7 +678,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @param from [in] loop begion * @param to [in] loop bound * @param step [in][optional] loop step; can be nullptr - * @param body [in] loop body + * @param body [in,pass_ownership] loop body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -687,6 +691,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt /** * @brief Create foreach loop * + * Takes ownership of region. + * * @code * foreach %loop_var = %from, %to : type(%loop_var) { %body } * ; type(%loop_var) == type(%from) @@ -697,7 +703,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt * @param loop_var [in] loop variable * @param from [in] loop begion * @param to [in] loop bound - * @param body [in] loop body + * @param body [in,pass_ownership] loop body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -711,14 +717,16 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, /** * @brief Create if condition * + * Takes ownership of if and else region (if given). + * * @code * if %condition { %then } else { %otherwise } * @endcode * * @param instr [out] pointer to the inst object created * @param condition [in] condition - * @param then [in] region taken if condition is true - * @param otherwise [in][optional] region taken if condition is false; can be nullptr + * @param then [in,pass_ownership] region taken if condition is true + * @param otherwise [in,pass_ownership][optional] region taken if condition is false; can be nullptr * @param return_type_list_size [in] length of return type array * @param return_type_list [in][range(0, return_type_list_size)] return type array; can be nullptr * if return_type_list_size is 0 @@ -817,24 +825,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction); /** - * @brief Release region object - * - * Decreases reference count by 1, free memory if reference count is 0. + * @brief Delete region object * * @param reg [inout] region object - * - * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_region_release(tinytc_region_t reg); - -/** - * @brief Increase reference count of region object by 1 - * - * @param reg [inout] region object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_region_retain(tinytc_region_t reg); +TINYTC_EXPORT void tinytc_region_destroy(tinytc_region_t reg); //////////////////////////// /////////// Func /////////// @@ -843,11 +838,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_retain(tinytc_region_t reg); /** * @brief Create function * + * Function takes ownership of region. + * * @param fun [out] pointer to the func object created * @param name [in] function name * @param arg_list_size [in] length of argument array * @param arg_list [in][range(0,arg_list_size)] argument array; can be nullptr if arg_list_size is 0 - * @param body [in] function body + * @param body [in,pass_ownership] function body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -878,24 +875,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t TINYTC_EXPORT tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs); /** - * @brief Release function object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param fun [inout] function object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_func_release(tinytc_func_t fun); - -/** - * @brief Increase reference count of function object by 1 + * @brief Delete function object * * @param fun [inout] function object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_func_retain(tinytc_func_t fun); +TINYTC_EXPORT void tinytc_func_destroy(tinytc_func_t fun); //////////////////////////// /////////// Prog /////////// @@ -905,16 +891,26 @@ TINYTC_EXPORT tinytc_status_t tinytc_func_retain(tinytc_func_t fun); * @brief Create program * * @param prg [out] pointer to the prog object created - * @param fun_list_size [in] length of func array - * @param fun_list [in][range(0, fun_list_size)] func array; can be nullptr if fun_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, +TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, const tinytc_location_t *loc); +/** + * @brief Append function to program + * + * The program takes ownership of the function. + * A function must not be added to multiple programs. + * + * @param prg [inout] program object + * @param fun [in,pass_ownership] function object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun); + /** * @brief Release program object * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 11973c3a..13256bfb 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -675,31 +675,20 @@ class inst : public unique_handle { } }; -namespace internal { -//! Is reinterpret_cast(&i) allowed, where i has type inst -constexpr bool inst_reinterpret_allowed = - std::is_standard_layout_v && sizeof(inst) == sizeof(tinytc_inst_t); -} // namespace internal - //////////////////////////// ////////// Region ////////// //////////////////////////// namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_retain(handle); - } - static auto release(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_release(handle); - } +template <> struct unique_handle_traits { + static void destroy(tinytc_region_t handle) { return tinytc_region_destroy(handle); } }; } // namespace internal //! @brief Reference-counting wrapper for tinytc_region_t -class region : public shared_handle { +class region : public unique_handle { public: - using shared_handle::shared_handle; + using unique_handle::unique_handle; /** * @brief Append instruction to region @@ -1039,9 +1028,9 @@ inline inst make_num_subgroups(location const &loc = {}) { * * @return Instruction */ -inline inst make_parallel(region const &body, location const &loc = {}) { +inline inst make_parallel(region body, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, body.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, body.release(), &loc), loc); return inst(instr); } @@ -1187,10 +1176,10 @@ inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const * @return Instruction */ inline inst make_for(value const &loop_var, value const &from, value const &to, value const &step, - region const &body, location const &loc = {}) { + region body, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var.get(), from.get(), to.get(), - step.get(), body.get(), &loc), + step.get(), body.release(), &loc), loc); return inst(instr); } @@ -1206,12 +1195,12 @@ inline inst make_for(value const &loop_var, value const &from, value const &to, * * @return Instruction */ -inline inst make_foreach(value const &loop_var, value const &from, value const &to, - region const &body, location const &loc = {}) { +inline inst make_foreach(value const &loop_var, value const &from, value const &to, region body, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_foreach_inst_create(&instr, loop_var.get(), from.get(), to.get(), body.get(), &loc), - loc); + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, loop_var.get(), from.get(), to.get(), + body.release(), &loc), + loc); return inst(instr); } @@ -1226,7 +1215,7 @@ inline inst make_foreach(value const &loop_var, value const &from, value const & * * @return Instruction */ -inline inst make_if(value const &condition, region const &then, region const &otherwise = region{}, +inline inst make_if(value const &condition, region then, region otherwise = region{}, std::vector const &return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; @@ -1239,8 +1228,8 @@ inline inst make_if(value const &condition, region const &then, region const &ot for (auto const &rt : return_type_list) { rl_vec.emplace_back(static_cast(rt)); } - CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition.get(), then.get(), otherwise.get(), - len, rl_vec.data(), &loc), + CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition.get(), then.release(), + otherwise.release(), len, rl_vec.data(), &loc), loc); return inst(instr); } @@ -1271,28 +1260,17 @@ inline inst make_yield(std::vector const &yield_list, location const &loc //////////////////////////// namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_retain(handle); - } - static auto release(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_release(handle); - } +template <> struct unique_handle_traits { + static void destroy(tinytc_func_t handle) { return tinytc_func_destroy(handle); } }; } // namespace internal //! @brief Reference-counting wrapper for tinytc_func_t -class func : public shared_handle { +class func : public unique_handle { public: - using shared_handle::shared_handle; + using unique_handle::unique_handle; }; -namespace internal { -//! Is reinterpret_cast(&f) allowed, where f has type func -constexpr bool func_reinterpret_allowed = - std::is_standard_layout_v && sizeof(func) == sizeof(tinytc_func_t); -} // namespace internal - /** * @brief Make function * @@ -1303,7 +1281,7 @@ constexpr bool func_reinterpret_allowed = * * @return Function */ -inline func make_function(char const *name, std::vector &arg_list, region const &body, +inline func make_function(char const *name, std::vector &arg_list, region body, location const &loc = {}) { static_assert(internal::value_reinterpret_allowed); tinytc_func_t fun; @@ -1312,7 +1290,7 @@ inline func make_function(char const *name, std::vector &arg_list, region throw std::out_of_range("argument list too long"); } tinytc_value_t *al = reinterpret_cast(arg_list.data()); - CHECK_STATUS_LOC(tinytc_function_create(&fun, name, len, al, body.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_function_create(&fun, name, len, al, body.release(), &loc), loc); return func(fun); } @@ -1360,6 +1338,15 @@ class prog : public shared_handle { public: using shared_handle::shared_handle; + /** + * @brief Append function to program + * + * @param fun function + */ + inline void add_function(func fun) { + CHECK_STATUS(tinytc_prog_add_function(get(), fun.release())); + } + /** * @brief Dump program to stderr */ @@ -1387,20 +1374,13 @@ class prog : public shared_handle { /** * @brief Make program * - * @param fun_list Vector of functions * @param loc Source code location * * @return Program */ -inline prog make_program(std::vector &fun_list, location const &loc = {}) { +inline prog make_program(location const &loc = {}) { tinytc_prog_t prg; - static_assert(internal::func_reinterpret_allowed); - auto len = fun_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("function list too long"); - } - tinytc_func_t *fl = reinterpret_cast(fun_list.data()); - CHECK_STATUS_LOC(tinytc_program_create(&prg, len, fl, &loc), loc); + CHECK_STATUS_LOC(tinytc_program_create(&prg, &loc), loc); return prog{prg}; } @@ -1591,7 +1571,8 @@ class function_builder { * * @param name Function name */ - inline function_builder(std::string name) : name_(std::move(name)), body_{nullptr} {} + inline function_builder(std::string name, location const &loc = {}) + : name_(std::move(name)), body_{nullptr}, loc_(loc) {} /** * @brief Returns built product @@ -1600,8 +1581,8 @@ class function_builder { * * @return Function */ - inline func get_product(location const &loc = {}) { - auto fun = make_function(name_.c_str(), arguments_, body_, loc); + inline func get_product() && { + auto fun = make_function(name_.c_str(), arguments_, std::move(body_), loc_); if (x_ > 0 && y_ > 0) { set_work_group_size(fun, x_, y_); } @@ -1663,6 +1644,7 @@ class function_builder { private: std::string name_; region body_; + location loc_; std::vector arguments_; std::int32_t x_ = 0, y_ = 0, sgs_ = 0; }; @@ -1670,6 +1652,13 @@ class function_builder { //! Builder for programs class program_builder { public: + /** + * @brief ctor + * + * @param loc Source code location + */ + program_builder(location const &loc = {}) : prg_{make_program(loc)} {} + /** * @brief create function \@name with functor f(function_builder&) -> void * @@ -1679,16 +1668,16 @@ class program_builder { * @param loc Source code location */ template void create(std::string name, F &&f, location const &loc = {}) { - auto fb = function_builder(std::move(name)); + auto fb = function_builder(std::move(name), loc); f(fb); - add(fb.get_product(loc)); + add(std::move(fb).get_product()); } /** * @brief Add function * * @param f function */ - inline void add(func f) { functions_.emplace_back(std::move(f)); } + inline void add(func f) { prg_.add_function(std::move(f)); } /** * @brief Returns built product * @@ -1696,10 +1685,10 @@ class program_builder { * * @return Program */ - inline prog get_product(location const &loc = {}) { return make_program(functions_, loc); } + inline prog get_product() && { return std::move(prg_); } private: - std::vector functions_; + prog prg_; }; //////////////////////////// diff --git a/src/func.cpp b/src/func.cpp index 27013e46..14f04f13 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -31,8 +31,8 @@ tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, uin for (uint32_t i = 0; i < arg_list_size; ++i) { arg_vec.emplace_back(value(arg_list[i], true)); } - *fun = std::make_unique(std::string(name), std::move(arg_vec), - region{body, true}, get_optional(loc)) + *fun = std::make_unique(std::string(name), std::move(arg_vec), region{body}, + get_optional(loc)) .release(); }); } @@ -45,22 +45,5 @@ tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs return exception_to_status_code([&] { fun->subgroup_size(sgs); }); } -tinytc_status_t tinytc_func_release(tinytc_func_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_func_retain(tinytc_func_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} +void tinytc_func_destroy(tinytc_func_t obj) { delete obj; } } diff --git a/src/inst.cpp b/src/inst.cpp index 4e7c0829..334e8616 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -309,7 +309,7 @@ tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, tinytc_region_ return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(region(body, true), get_optional(loc)).release(); + *instr = std::make_unique(region{body}, get_optional(loc)).release(); }); } @@ -415,7 +415,7 @@ tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop return exception_to_status_code([&] { *instr = std::make_unique(value(loop_var, true), value(from, true), value(to, true), - value(step, true), region(body, true), get_optional(loc)) + value(step, true), region{body}, get_optional(loc)) .release(); }); } @@ -428,10 +428,9 @@ tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), - value(to, true), region(body, true), get_optional(loc)) - .release(); + *instr = std::make_unique(value(loop_var, true), value(from, true), + value(to, true), region{body}, get_optional(loc)) + .release(); }); } @@ -450,10 +449,9 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi for (uint32_t i = 0; i < return_type_list_size; ++i) { rt.emplace_back(enum_cast(return_type_list[i])); } - *instr = - std::make_unique(value(condition, true), region(then, true), - region(otherwise, true), std::move(rt), get_optional(loc)) - .release(); + *instr = std::make_unique(value(condition, true), region{then}, region{otherwise}, + std::move(rt), get_optional(loc)) + .release(); }); } diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index 85408955..d344f03a 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -6,7 +6,6 @@ #include "location.hpp" #include "node/region_node.hpp" -#include "reference_counted.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" @@ -21,7 +20,7 @@ using value_range = iterator_range_wrapper; using const_value_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_func final : tinytc::reference_counted { +struct tinytc_func final { public: inline tinytc_func(std::string name, std::vector args, tinytc::region body, tinytc::location const &lc = {}) @@ -52,7 +51,7 @@ struct tinytc_func final : tinytc::reference_counted { } inline auto name() const -> std::string_view { return name_; } - inline auto body() const -> tinytc::region const & { return body_; } + inline auto body() const -> tinytc_region & { return *body_; } inline auto work_group_size() const -> std::array { return work_group_size_; } inline void work_group_size(std::array const &work_group_size) { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 90a510f2..8dc64f35 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -5,7 +5,6 @@ #define INST_NODE_20230327_HPP #include "error.hpp" -#include "reference_counted.hpp" #include "support/ilist.hpp" #include "support/type_list.hpp" #include "support/util.hpp" @@ -86,7 +85,7 @@ using const_region_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_inst : tinytc::ilist_node, tinytc::reference_counted { +struct tinytc_inst : tinytc::ilist_node { public: using leaves = tinytc::inst_nodes; @@ -137,14 +136,14 @@ struct tinytc_inst : tinytc::ilist_node, tinytc::reference_counted inline auto child_regions_begin() -> tinytc::region * { return child_regions_begin_; } inline auto child_regions_end() -> tinytc::region * { return child_regions_end_; } inline auto child_regions() -> tinytc::region_range { - return tinytc::region_range{child_regions_begin_, child_regions_end_}; + return tinytc::region_range{child_regions_begin(), child_regions_end()}; } inline auto child_regions_begin() const -> tinytc::region const * { return child_regions_begin_; } inline auto child_regions_end() const -> tinytc::region const * { return child_regions_end_; } inline auto child_regions() const -> tinytc::const_region_range { - return tinytc::const_region_range{child_regions_begin_, child_regions_end_}; + return tinytc::const_region_range{child_regions_begin(), child_regions_end()}; } inline auto child_region(std::size_t pos) -> tinytc::region & { return child_regions_begin_[pos]; @@ -241,7 +240,7 @@ template class object_container { throw internal_compiler_error(); } } - inline auto get() -> T * { + auto get() -> T * { if constexpr (NumObjects == 0) { return nullptr; } @@ -338,7 +337,8 @@ class loop_inst : public standard_inst<4, 0, 1> { inline auto from() const -> value const & { return op(op_from); } inline auto to() const -> value const & { return op(op_to); } inline auto step() const -> value const & { return op(op_step); } - inline auto body() const -> region const & { return child_region(0); } + inline auto body() -> tinytc_region & { return *child_region(0); } + inline auto body() const -> tinytc_region const & { return *child_region(0); } }; class alloca_inst : public standard_inst<0, 1> { @@ -565,8 +565,13 @@ class if_inst : public standard_inst<1, dynamic, 2> { if_inst(value condition, region then, region otherwise = {}, std::vector const &return_types = {}, location const &lc = {}); inline auto condition() const -> value const & { return op(0); } - inline auto then() const -> region const & { return child_region(child_region_then); } - inline auto otherwise() const -> region const & { return child_region(child_region_otherwise); } + inline auto then() -> tinytc_region & { return *child_region(child_region_then); } + inline auto then() const -> tinytc_region const & { return *child_region(child_region_then); } + inline auto has_otherwise() const -> bool { return bool(child_region(child_region_otherwise)); } + inline auto otherwise() -> tinytc_region & { return *child_region(child_region_otherwise); } + inline auto otherwise() const -> tinytc_region const & { + return *child_region(child_region_otherwise); + } }; class num_subgroups_inst : public standard_inst<0, 1> { @@ -583,7 +588,8 @@ class parallel_inst : public standard_inst<0, 0, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } parallel_inst(region body, location const &lc = {}); - inline auto body() const -> region const & { return child_region(0); } + inline auto body() -> tinytc_region & { return *child_region(0); } + inline auto body() const -> tinytc_region const & { return *child_region(0); } }; class size_inst : public standard_inst<1, 1> { diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 42c93bd6..49fa0029 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -20,10 +20,7 @@ using const_func_range = iterator_range_wrapper; struct tinytc_prog final : tinytc::reference_counted { public: - inline tinytc_prog(std::vector funcs, tinytc::location const &lc = {}) - : funcs_(std::move(funcs)) { - loc(lc); - } + inline tinytc_prog(tinytc::location const &lc = {}) { loc(lc); } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -42,6 +39,7 @@ struct tinytc_prog final : tinytc::reference_counted { inline auto functions() const -> tinytc::const_func_range { return tinytc::const_func_range{begin(), end()}; } + inline void push_back(tinytc::func fun) { funcs_.push_back(std::move(fun)); } private: std::vector funcs_; diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 1ae7a14d..04f8d075 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -5,7 +5,6 @@ #define REGION_NODE_20230908_HPP #include "node/inst_node.hpp" -#include "reference_counted.hpp" #include "support/ilist.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" @@ -21,7 +20,7 @@ enum class region_kind { mixed, collective, spmd }; } // namespace tinytc -struct tinytc_region final : tinytc::reference_counted { +struct tinytc_region final { public: using iterator = tinytc::ilist::iterator; using const_iterator = tinytc::ilist::const_iterator; @@ -59,7 +58,14 @@ struct tinytc_region final : tinytc::reference_counted { }; namespace tinytc { + using region_node = ::tinytc_region; + +template <> struct ilist_traits { + static void on_insert(region_node *) {} + static void on_erase(region_node *node) { tinytc_region_destroy(node); } +}; + } // namespace tinytc #endif // REGION_NODE_20230908_HPP diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 5824dbf7..97978947 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -36,22 +36,6 @@ value parse_context::val(std::string const &id, location const &l) { throw parser::syntax_error(l, "Undefined identifier %" + id); } -void parse_context::add_function(std::string const &id, func fn) { - if (auto other = function_map_.find(id); other != function_map_.end()) { - auto oss = std::ostringstream{}; - oss << "Identifier @" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(fn->loc(), oss.str()); - } - function_map_[id] = std::move(fn); -} - -func parse_context::get_function(std::string const &id, location const &l) { - if (auto j = function_map_.find(id); j != function_map_.end()) { - return j->second; - } - throw parser::syntax_error(l, "Undefined identifier @" + id); -} - void parse_context::add_error(location const &loc, std::string const &what) { errors_.emplace_back(std::make_pair(loc, what)); } diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 1a213aaf..d6f8a25d 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -25,9 +25,6 @@ class parse_context { void val(std::string const &id, value val, location const &l); value val(std::string const &id, location const &l); - void add_function(std::string const &id, func fn); - func get_function(std::string const &id, location const &l); - void add_error(location const &loc, std::string const &what); inline auto errors() const -> std::vector> const & { @@ -36,7 +33,6 @@ class parse_context { private: std::vector> id_map_; - std::unordered_map function_map_; prog program_; std::vector> errors_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 159519c3..444d1bad 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -230,9 +230,12 @@ %% prog: func_list { - auto p = prog { std::make_unique(std::move($func_list), @prog).release() }; + auto p = prog { std::make_unique(@prog).release() }; ctx.program(p); $$ = std::move(p); + for (auto& f : $func_list) { + $$.add_function(std::move(f)); + } } ; @@ -253,7 +256,6 @@ func: attr(*func_node); } $func = func{func_node}; - ctx.add_function($GLOBAL_IDENTIFIER, $func); ctx.pop_scope(); } ; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index dd01a219..0da096a1 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -782,7 +782,7 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { auto start = clir::declaration_assignment(std::move(lv_ty), lv, visit(*this, *p.from())); auto condition = lv < visit(*this, *p.to()); auto step = p.step() ? clir::add_into(lv, visit(*this, *p.step())) : ++lv; - auto body = run_on_region(*p.body()); + auto body = run_on_region(p.body()); clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); @@ -803,7 +803,7 @@ std::vector convert_to_opencl_pass::operator()(foreach_inst const &p bb, trip_count, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), std::move(sg), [&](clir::block_builder &bb, clir::expr block, bool, clir::expr) { bb.add(clir::declaration_assignment(lv_ty, lv, std::move(block) + m + from)); - bb.add(run_on_region(*p.body())); + bb.add(run_on_region(p.body())); }); return {bb.get_product()}; } @@ -866,9 +866,9 @@ std::vector convert_to_opencl_pass::operator()(if_inst const &in) { yielded_vars_.back().emplace_back(std::move(v)); } auto ib = clir::if_selection_builder(visit(*this, *in.condition())); - ib.set_then(run_on_region(*in.then())); - if (in.otherwise()) { - ib.set_otherwise(run_on_region(*in.otherwise())); + ib.set_then(run_on_region(in.then())); + if (in.has_otherwise()) { + ib.set_otherwise(run_on_region(in.otherwise())); } yielded_vars_.pop_back(); clinst.emplace_back(ib.get_product()); @@ -883,7 +883,7 @@ std::vector convert_to_opencl_pass::operator()(num_subgroups_inst co } std::vector convert_to_opencl_pass::operator()(parallel_inst const &p) { - return {run_on_region(*p.body())}; + return {run_on_region(p.body())}; } std::vector convert_to_opencl_pass::operator()(size_inst const &s) { @@ -1088,7 +1088,7 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } /* Region nodes */ -clir::stmt convert_to_opencl_pass::run_on_region(region_node ®) { +clir::stmt convert_to_opencl_pass::run_on_region(region_node const ®) { declared_vars_.push_back({}); auto bb = clir::block_builder{}; for (auto &s : reg.insts()) { @@ -1101,7 +1101,7 @@ clir::stmt convert_to_opencl_pass::run_on_region(region_node ®) { } /* Function nodes */ -auto convert_to_opencl_pass::run_on_function(function_node &fn) -> clir::func { +auto convert_to_opencl_pass::run_on_function(function_node const &fn) -> clir::func { stack_high_water_mark_ = 0; auto const subgroup_size = fn.subgroup_size(); try { @@ -1140,7 +1140,7 @@ auto convert_to_opencl_pass::run_on_function(function_node &fn) -> clir::func { fb.attribute(clir::reqd_work_group_size(work_group_size[0], work_group_size[1], 1)); fb.attribute(clir::intel_reqd_sub_group_size(subgroup_size)); - auto body = run_on_region(*fn.body()); + auto body = run_on_region(fn.body()); if (stack_high_water_mark_ > 0) { auto bb = dynamic_cast(body.get()); @@ -1158,7 +1158,7 @@ auto convert_to_opencl_pass::run_on_function(function_node &fn) -> clir::func { } /* Program nodes */ -auto convert_to_opencl_pass::run_on_program(program_node &p) -> clir::prog { +auto convert_to_opencl_pass::run_on_program(program_node const &p) -> clir::prog { reserved_names_.clear(); for (auto const &fn : p.functions()) { reserved_names_.insert(std::string(fn->name())); diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 82899d75..755987f3 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -102,11 +102,11 @@ class convert_to_opencl_pass { std::vector operator()(sum_inst const &s); std::vector operator()(yield_inst const &in); - auto run_on_program(program_node &p) -> clir::prog; + auto run_on_program(program_node const &p) -> clir::prog; private: - auto run_on_region(region_node ®) -> clir::stmt; - auto run_on_function(function_node &fn) -> clir::func; + auto run_on_region(region_node const ®) -> clir::stmt; + auto run_on_function(function_node const &fn) -> clir::func; auto get_dope_vector(value_node *v) -> dope_vector &; void set_dope_vector(value_node *v, dope_vector dv); diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp index 269629cd..4d193129 100644 --- a/src/pass/dump_cfg.cpp +++ b/src/pass/dump_cfg.cpp @@ -18,7 +18,7 @@ void dump_cfg_pass::run_on_function(function_node const &fn) { *os_ << "digraph " << fn.name() << " {" << std::endl; - auto cfg = get_control_flow_graph(*fn.body()); + auto cfg = get_control_flow_graph(fn.body()); auto q = cfg.node_queue(); for (; !q.empty(); q.pop()) { auto &node = q.front(); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index e9edbe45..a9993d22 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -246,7 +246,7 @@ void dump_ir_pass::operator()(for_inst const &p) { *os_ << " : "; visit(*this, *p.loop_var()->ty()); *os_ << " "; - dump_region(*p.body()); + dump_region(p.body()); } void dump_ir_pass::operator()(foreach_inst const &p) { @@ -259,7 +259,7 @@ void dump_ir_pass::operator()(foreach_inst const &p) { *os_ << " : "; visit(*this, *p.loop_var()->ty()); *os_ << " "; - dump_region(*p.body()); + dump_region(p.body()); } void dump_ir_pass::operator()(hadamard_inst const &g) { @@ -271,10 +271,10 @@ void dump_ir_pass::operator()(if_inst const &in) { *os_ << "if "; visit(*this, *in.condition()); *os_ << " "; - dump_region(*in.then()); - if (in.otherwise()) { + dump_region(in.then()); + if (in.has_otherwise()) { *os_ << " else "; - dump_region(*in.otherwise()); + dump_region(in.otherwise()); } } @@ -285,7 +285,7 @@ void dump_ir_pass::operator()(num_subgroups_inst const &sg) { void dump_ir_pass::operator()(parallel_inst const &p) { *os_ << "parallel "; - dump_region(*p.body()); + dump_region(p.body()); } void dump_ir_pass::operator()(size_inst const &s) { @@ -401,7 +401,7 @@ void dump_ir_pass::run_on_function(function_node const &fn) { if (wgs[0] != 0 && wgs[1] != 0) { *os_ << "work_group_size(" << wgs[0] << "," << wgs[1] << ") "; } - dump_region(*fn.body()); + dump_region(fn.body()); *os_ << std::endl; } diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 945bdd7a..3fdb96d3 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -234,7 +234,7 @@ auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, /* Function nodes */ void insert_barrier_pass::run_on_function(function_node &fn) { auto aa = alias_analysis{}.run_on_function(fn); - run_on_region(*fn.body(), aa); + run_on_region(fn.body(), aa); } } // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 494e374b..60cddadf 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -63,7 +63,7 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const void insert_lifetime_stop_pass::run_on_function(function_node &fn) { auto aa = alias_analysis{}.run_on_function(fn); - run_on_region(*fn.body(), aa); + run_on_region(fn.body(), aa); } } // namespace tinytc diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index e1c59181..e8eb357d 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -52,7 +52,7 @@ inst lower_linalg_pass::operator()(loop_inst &p) { inst lower_linalg_pass::operator()(if_inst &in) { visit(*this, *in.then()); - if (in.otherwise()) { + if (in.has_otherwise()) { visit(*this, *in.otherwise()); } return inst{nullptr}; diff --git a/src/prog.cpp b/src/prog.cpp index 0f34e076..1b4de169 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -25,19 +25,19 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, const tinytc_location_t *loc) { - if (prg == nullptr || (fun_list_size > 0 && fun_list == nullptr)) { +tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, const tinytc_location_t *loc) { + if (prg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - auto fun_vec = std::vector(); - fun_vec.reserve(fun_list_size); - for (uint32_t i = 0; i < fun_list_size; ++i) { - fun_vec.emplace_back(func(fun_list[i], true)); - } - *prg = std::make_unique(std::move(fun_vec), get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *prg = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { + if (prg == nullptr || fun == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { prg->push_back(func{fun}); }); } tinytc_status_t tinytc_prog_release(tinytc_prog_t obj) { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index a2c4895b..b65414b6 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -118,14 +118,17 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co }, my_loc()); }; - auto pb = program_builder{}; - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - auto p = pb.get_product(my_loc()); + auto p = [&] { + auto pb = program_builder{my_loc()}; + pb.create( + small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), + [&](function_builder &fb) { kernel(fb, true); }, my_loc()); + pb.create( + small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), + [&](function_builder &fb) { kernel(fb, false); }, my_loc()); + + return std::move(pb).get_product(); + }(); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); *recipe = std::make_unique(std::move(p), source(src), ty_) diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 5ca0bf3b..43464360 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -157,15 +157,16 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( fb.body([&](region_builder &bb) { body(bb, alpha, A, B, beta, C); }, my_loc()); }; - auto pb = program_builder{}; - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - - auto p = pb.get_product(my_loc()); + auto p = [&] { + auto pb = program_builder{my_loc()}; + pb.create( + tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), + [&](function_builder &fb) { kernel(fb, true); }, my_loc()); + pb.create( + tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), + [&](function_builder &fb) { kernel(fb, false); }, my_loc()); + return std::move(pb).get_product(); + }(); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); *recipe = std::make_unique(std::move(p), source(src), ty_, M, diff --git a/src/region.cpp b/src/region.cpp index 26869adc..043a5bb3 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -32,22 +32,5 @@ tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t return exception_to_status_code([&] { reg->push_back(instruction); }); } -tinytc_status_t tinytc_region_release(tinytc_region_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_region_retain(tinytc_region_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} +void tinytc_region_destroy(tinytc_region_t obj) { delete obj; } } diff --git a/src/support/walk.hpp b/src/support/walk.hpp index 765dab47..c959a380 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -43,16 +43,17 @@ template void walk(inst_node &i, std::function void walk(inst_node &i, std::function callback) { +template +void walk(inst_node &i, std::function callback) { for (auto ® : i.child_regions()) { if constexpr (Order == walk_order::pre_order) { - callback(reg); + callback(*reg); } for (auto &j : *reg) { walk(j, callback); } if constexpr (Order == walk_order::post_order) { - callback(reg); + callback(*reg); } } } @@ -61,14 +62,14 @@ void walk(inst_node &i, std::function void walk(function_node &fn, std::function callback) { - for (auto &i : *fn.body()) { + for (auto &i : fn.body()) { walk(i, callback); } } inline void walk(function_node &fn, std::function callback) { - for (auto &i : *fn.body()) { + for (auto &i : fn.body()) { walk(i, callback); } } From c28326326b3dd673381f1a57d05d11f1b5210f89 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 24 Sep 2024 10:21:43 +0200 Subject: [PATCH 023/297] insert barrier test passes Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 2 + src/analysis/cfg.cpp | 65 ++++++++++------- src/analysis/cfg.hpp | 12 ++- src/func.cpp | 2 +- src/node/function_node.hpp | 12 +-- src/node/inst_node.cpp | 13 ---- src/node/inst_node.hpp | 7 +- src/node/program_node.cpp | 12 +++ src/node/program_node.hpp | 31 ++++---- src/node/region_node.cpp | 14 ++++ src/node/region_node.hpp | 32 ++++---- src/parser/parser_impl.yy | 2 +- src/pass/insert_barrier.cpp | 117 +++++++++++++----------------- src/pass/insert_barrier.hpp | 7 +- src/pass/insert_lifetime_stop.cpp | 2 +- src/passes.hpp | 10 ++- src/prog.cpp | 2 +- src/region.cpp | 2 +- src/support/ilist_base.hpp | 42 +++++++---- test/opt/insert-barrier.ir | 2 +- test/opt/insert-lifetime-stop.ir | 2 - 21 files changed, 209 insertions(+), 181 deletions(-) create mode 100644 src/node/program_node.cpp create mode 100644 src/node/region_node.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 39061b60..87a4f4bb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -33,6 +33,8 @@ set(SOURCES location.cpp node/data_type_node.cpp node/inst_node.cpp + node/region_node.cpp + node/program_node.cpp parser/parse_context.cpp parser.cpp pass/check_ir.cpp diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp index ae9bde46..c335d655 100644 --- a/src/analysis/cfg.cpp +++ b/src/analysis/cfg.cpp @@ -8,6 +8,12 @@ namespace tinytc { +void control_flow_graph::insert_before(inst_node *before_inst, inst_node *new_inst) { + add_node(new_inst, adj_[before_inst].kind_max); + adj_[new_inst].pred = std::move(adj_[before_inst].pred); + add_edge(new_inst, before_inst); +} + auto control_flow_graph::node_queue() const -> std::queue { auto q = std::queue{}; for (auto &[key, neighbors] : adj_) { @@ -19,52 +25,55 @@ auto control_flow_graph::node_queue() const -> std::queue { auto get_control_flow_graph(region_node &topreg) -> control_flow_graph { auto cfg = control_flow_graph{}; - const auto add_region = [&cfg](region_node ®, - auto &add_region_ref) -> std::pair { + const auto add_region = + [&cfg](region_node ®, region_kind kind_max, + auto &add_region_ref) -> std::pair> { if (reg.empty()) { return {}; } - auto start = reg.begin().get(); - cfg.add_node(start); - auto pred_nodes = std::queue{}; - pred_nodes.push(start); - - for (auto it = ++reg.begin(); it != reg.end(); ++it) { - inst_node *node = it.get(); - cfg.add_node(node); - - for (; !pred_nodes.empty(); pred_nodes.pop()) { - cfg.add_edge(pred_nodes.front(), node); - } - - if (it->num_child_regions() > 0) { - for (auto &subreg : it->child_regions()) { - auto [substart, subexit] = add_region_ref(*subreg, add_region_ref); + const auto visit_inst = [&](inst_node *node) { + if (node->num_child_regions() > 0) { + for (auto &subreg : node->child_regions()) { + auto [substart, subexits] = + add_region_ref(*subreg, std::max(kind_max, subreg->kind()), add_region_ref); cfg.add_edge(node, substart); - if (isa(*it)) { - cfg.add_edge(subexit, node); + if (isa(*node)) { + for (; !subexits.empty(); subexits.pop()) { + cfg.add_edge(subexits.front(), node); + } pred_nodes.push(node); } else { - pred_nodes.push(subexit); + for (; !subexits.empty(); subexits.pop()) { + pred_nodes.push(subexits.front()); + } } } } else { pred_nodes.push(node); } - } + }; + + auto start = reg.begin().get(); + cfg.add_node(start, kind_max); + visit_inst(start); + + for (auto it = ++reg.begin(); it != reg.end(); ++it) { + inst_node *node = it.get(); + cfg.add_node(node, kind_max); + + for (; !pred_nodes.empty(); pred_nodes.pop()) { + cfg.add_edge(pred_nodes.front(), node); + } - // every region must have exactly one exit node and the exit node must be last - // @todo: NOT guaranteed for parallel_inst and function yet! - if (pred_nodes.size() != 1) { - throw internal_compiler_error{}; + visit_inst(node); } - return std::make_pair(std::move(start), std::move(pred_nodes.front())); + return std::make_pair(std::move(start), std::move(pred_nodes)); }; - add_region(topreg, add_region); + add_region(topreg, topreg.kind(), add_region); return cfg; } diff --git a/src/analysis/cfg.hpp b/src/analysis/cfg.hpp index 1ac7221b..bbd31ca0 100644 --- a/src/analysis/cfg.hpp +++ b/src/analysis/cfg.hpp @@ -8,6 +8,7 @@ #include "node/region_node.hpp" #include "support/util.hpp" +#include #include #include #include @@ -17,14 +18,20 @@ namespace tinytc { class control_flow_graph { public: - inline void add_node(inst_node *a) { adj_[a] = adjacency_list{}; } + inline void add_node(inst_node *a, region_kind kind_max) { + adj_[a] = adjacency_list{}; + adj_[a].kind_max = kind_max; + } inline void add_edge(inst_node *a, inst_node *b) { adj_[a].succ.push_back(b); - adj_[b].pred.push_back(b); + adj_[b].pred.push_back(a); } + void insert_before(inst_node *before_inst, inst_node *new_inst); auto node_queue() const -> std::queue; + inline auto kind_max(inst_node *a) -> region_kind { return adj_[a].kind_max; } + inline auto pred_begin(inst_node *a) { return adj_[a].pred.begin(); } inline auto pred_end(inst_node *a) { return adj_[a].pred.end(); } inline auto @@ -41,6 +48,7 @@ class control_flow_graph { private: struct adjacency_list { + region_kind kind_max = region_kind::mixed; std::vector pred; std::vector succ; }; diff --git a/src/func.cpp b/src/func.cpp index 14f04f13..20c4f98c 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -31,7 +31,7 @@ tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, uin for (uint32_t i = 0; i < arg_list_size; ++i) { arg_vec.emplace_back(value(arg_list[i], true)); } - *fun = std::make_unique(std::string(name), std::move(arg_vec), region{body}, + *fun = std::make_unique(std::string(name), std::move(arg_vec), body, get_optional(loc)) .release(); }); diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index d344f03a..02c5db38 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -22,16 +22,16 @@ using const_value_range = iterator_range_wrapper; struct tinytc_func final { public: - inline tinytc_func(std::string name, std::vector args, tinytc::region body, - tinytc::location const &lc = {}) - : name_(std::move(name)), args_(std::move(args)), body_(std::move(body)), + inline tinytc_func(std::string name, std::vector args, tinytc_region_t body, + tinytc_location const &lc = {}) + : name_(std::move(name)), args_(std::move(args)), body_(tinytc::region{body}), work_group_size_{0, 0}, subgroup_size_{0} { loc(lc); body_->kind(tinytc::region_kind::collective); } - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } inline auto arg_begin() -> tinytc::value * { return args_.size() > 0 ? args_.data() : nullptr; } inline auto arg_end() -> tinytc::value * { @@ -66,7 +66,7 @@ struct tinytc_func final { tinytc::region body_; std::array work_group_size_; std::int32_t subgroup_size_; - tinytc::location loc_; + tinytc_location loc_; }; namespace tinytc { diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index dd8659a2..b90dc4c0 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -75,11 +75,6 @@ loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step if (lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || !step_ok) { throw compilation_error(loc(), status::ir_scalar_mismatch); } - - region_node &body = *child_region(0); - if (body.empty() || !isa(*(--body.end()))) { - body.insert(body.end(), std::make_unique(std::vector{}, lc).release()); - } } alloca_inst::alloca_inst(data_type ty, location const &lc) @@ -488,14 +483,6 @@ if_inst::if_inst(value condition, region then0, region otherwise0, for (std::size_t i = 0; i < return_types.size(); ++i) { result(i) = make_value(return_types[i]); } - - for (std::int64_t i = 0; i < num_child_regions(); ++i) { - region_node &body = *child_region(i); - if (body.empty() || !isa(*(--body.end()))) { - body.insert(body.end(), - std::make_unique(std::vector{}, lc).release()); - } - } } parallel_inst::parallel_inst(region body, location const &lc) : standard_inst{IK::parallel} { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 8dc64f35..56fa5fe5 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -85,7 +85,7 @@ using const_region_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_inst : tinytc::ilist_node { +struct tinytc_inst : tinytc::ilist_node_with_parent { public: using leaves = tinytc::inst_nodes; @@ -226,11 +226,6 @@ namespace tinytc { using inst_node = ::tinytc_inst; -template <> struct ilist_traits { - static void on_insert(inst_node *) {} - static void on_erase(inst_node *node) { tinytc_inst_destroy(node); } -}; - template class object_container { public: object_container(std::int64_t num_objects) { diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp new file mode 100644 index 00000000..a8479b31 --- /dev/null +++ b/src/node/program_node.cpp @@ -0,0 +1,12 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/program_node.hpp" +#include "node/function_node.hpp" + +tinytc_prog::~tinytc_prog() { + for (auto &f : functions()) { + tinytc_func_destroy(f); + } +} + diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 49fa0029..aebb3480 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -7,43 +7,44 @@ #include "location.hpp" #include "reference_counted.hpp" #include "support/util.hpp" -#include "tinytc/tinytc.hpp" #include #include #include namespace tinytc { -using func_range = iterator_range_wrapper; -using const_func_range = iterator_range_wrapper; +using func_range = iterator_range_wrapper; +using const_func_range = iterator_range_wrapper; } // namespace tinytc struct tinytc_prog final : tinytc::reference_counted { public: - inline tinytc_prog(tinytc::location const &lc = {}) { loc(lc); } + inline tinytc_prog(tinytc_location const &lc = {}) { loc(lc); } + ~tinytc_prog(); - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } - inline auto begin() -> tinytc::func * { return funcs_.size() > 0 ? funcs_.data() : nullptr; } - inline auto end() -> tinytc::func * { + inline auto begin() -> tinytc_func_t * { return funcs_.size() > 0 ? funcs_.data() : nullptr; } + inline auto end() -> tinytc_func_t * { return funcs_.size() > 0 ? funcs_.data() + funcs_.size() : nullptr; } inline auto functions() -> tinytc::func_range { return tinytc::func_range{begin(), end()}; } - inline auto begin() const -> tinytc::func const * { - return funcs_.size() > 0 ? funcs_.data() : nullptr; + inline auto begin() const -> const_tinytc_func_t * { + return funcs_.size() > 0 ? const_cast(funcs_.data()) : nullptr; } - inline auto end() const -> tinytc::func const * { - return funcs_.size() > 0 ? funcs_.data() + funcs_.size() : nullptr; + inline auto end() const -> const_tinytc_func_t * { + return funcs_.size() > 0 ? const_cast(funcs_.data()) + funcs_.size() + : nullptr; } inline auto functions() const -> tinytc::const_func_range { return tinytc::const_func_range{begin(), end()}; } - inline void push_back(tinytc::func fun) { funcs_.push_back(std::move(fun)); } + inline void push_back(tinytc_func_t fun) { funcs_.push_back(fun); } private: - std::vector funcs_; - tinytc::location loc_; + std::vector funcs_; + tinytc_location loc_; }; namespace tinytc { diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp new file mode 100644 index 00000000..5fd92be8 --- /dev/null +++ b/src/node/region_node.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/region_node.hpp" + +namespace tinytc { + +auto ilist_traits::get_parent_region() -> tinytc_region * { + return reinterpret_cast(reinterpret_cast(this) - + tinytc_region::inst_list_offset()); +} + +} // namespace tinytc + diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 04f8d075..7b35c22e 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -10,13 +10,20 @@ #include "tinytc/tinytc.hpp" #include +#include #include #include namespace tinytc { //! Instruction classification -enum class region_kind { mixed, collective, spmd }; +enum class region_kind { mixed = 0x0, collective = 0x1, spmd = 0x2 }; + +template <> struct ilist_traits { + auto get_parent_region() -> tinytc_region *; + void node_added(inst_node *node) { node->parent(get_parent_region()); } + void node_removed(inst_node *node) { tinytc_inst_destroy(node); } +}; } // namespace tinytc @@ -37,21 +44,19 @@ struct tinytc_region final { inline auto begin() -> iterator { return insts_.begin(); } inline auto end() -> iterator { return insts_.end(); } - inline auto insts() -> tinytc::iterator_range_wrapper { return {begin(), end()}; } + inline auto insts() -> tinytc::ilist & { return insts_; } inline auto begin() const -> const_iterator { return insts_.cbegin(); } inline auto end() const -> const_iterator { return insts_.cend(); } - inline auto insts() const -> tinytc::iterator_range_wrapper { - return {begin(), end()}; - } - inline void push_back(tinytc_inst_t i) { insts_.push_back(i); } - inline auto erase(iterator pos) -> iterator { return insts_.erase(pos); } - inline auto insert(iterator pos, tinytc_inst_t i) -> iterator { return insts_.insert(pos, i); } - inline auto insert_after(iterator pos, tinytc_inst_t i) -> iterator { - return insts_.insert_after(pos, i); - } + inline auto insts() const -> tinytc::ilist const & { return insts_; } inline auto empty() const -> bool { return insts_.empty(); } private: + static auto inst_list_offset() -> std::size_t { + static_assert(std::is_standard_layout_v, "offsetof not guaranteed to work"); + return offsetof(tinytc_region, insts_); + } + friend struct tinytc::ilist_traits; + tinytc::region_kind kind_; tinytc::ilist insts_; tinytc::location loc_; @@ -61,11 +66,6 @@ namespace tinytc { using region_node = ::tinytc_region; -template <> struct ilist_traits { - static void on_insert(region_node *) {} - static void on_erase(region_node *node) { tinytc_region_destroy(node); } -}; - } // namespace tinytc #endif // REGION_NODE_20230908_HPP diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 444d1bad..475fe571 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -250,7 +250,7 @@ func: auto loc = @FUNC; loc.end = @RPAREN.end; auto func_node = std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), - std::move($region), loc) + $region.release(), loc) .release(); for (auto &attr : $attributes) { attr(*func_node); diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 3fdb96d3..67b1886b 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -31,6 +31,15 @@ auto intersects(std::unordered_set<::tinytc_value const *> const &a, return false; } +void insert_barrier_pass::reads_writes::clear() { + for (auto &rd : reads) { + rd.clear(); + } + for (auto &wr : writes) { + wr.clear(); + } +} + void insert_barrier_pass::reads_writes::clear(address_space as) { const auto space = address_space_to_index(as); reads[space].clear(); @@ -55,6 +64,12 @@ void insert_barrier_pass::reads_writes::merge(reads_writes &&other) { } } +void insert_barrier_pass::reads_writes::merge(address_space as, reads_writes const &other) { + const auto space = address_space_to_index(as); + reads[space].insert(other.reads[space].begin(), other.reads[space].end()); + writes[space].insert(other.writes[space].begin(), other.writes[space].end()); +} + void insert_barrier_pass::reads_writes::emplace_read(address_space as, ::tinytc_value const *val) { const auto space = address_space_to_index(as); reads[space].emplace(val); @@ -92,8 +107,7 @@ bool insert_barrier_pass::reads_writes::raw_war_or_waw(address_space as, reads_w return raw(as, rw, aa) || war(as, rw, aa) || waw(as, rw, aa); } -auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) const - -> std::size_t { +auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) -> std::size_t { for (std::size_t i = 0; i < address_spaces.size(); ++i) { if (as == address_spaces[i]) { return i; @@ -102,8 +116,7 @@ auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) throw internal_compiler_error{}; } -auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, - const bool insert_barriers) -> reads_writes { +void insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa) { // irw = reads and writes invisible to other threads auto irw_in = std::unordered_map{}; auto irw_out = std::unordered_map{}; @@ -148,88 +161,56 @@ auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, auto n = q.front(); q.pop(); + const bool insert_barriers = cfg.kind_max(n) < region_kind::spmd; + auto &in = irw_in[n]; auto &out = irw_out[n]; + + in.clear(); for (auto &p : cfg.predecessors(n)) { in.merge(irw_out[p]); } auto out_size_before_update = get_cardinal(out); - out = get_rw(*n); - out.merge(in); - // out has changed, need to enqueue successors again - if (out_size_before_update != get_cardinal(out)) { - for (auto &s : cfg.successors(n)) { - q.push(s); - } - } - } -} -/*auto insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa, - const bool insert_barriers) -> reads_writes { - auto invisible_rw = reads_writes{}; - for (auto it = reg.begin(); it != reg.end(); ++it) { - if (auto *barrier = dyn_cast(it->get()); insert_barriers && barrier) { + if (auto *barrier = dyn_cast(n); insert_barriers && barrier) { for (auto &as : reads_writes::address_spaces) { - if (barrier->has_fence(as)) { - invisible_rw.clear(as); + if (!barrier->has_fence(as)) { + out.merge(as, in); } } } else { - auto rw = reads_writes{}; - - for (auto &subreg : (*it)->child_regions()) { - const bool insert_barriers_sub = - insert_barriers && subreg->kind() != region_kind::spmd; - rw.merge(run_on_region(*subreg, aa, insert_barriers_sub)); - } + out = get_rw(*n); - auto const emplace_read = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty().get()); m) { - rw.emplace_read(m->addrspace(), v.get()); - } - }; - auto const emplace_write = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty().get()); m) { - rw.emplace_write(m->addrspace(), v.get()); - } - }; - visit(overloaded{[&](blas_a2_inst &in) { - emplace_read(in.A()); - emplace_write(in.B()); - }, - [&](blas_a3_inst &in) { - emplace_read(in.A()); - emplace_read(in.B()); - emplace_write(in.C()); - }, - [&](load_inst &in) { emplace_read(in.operand()); }, - [&](store_inst &in) { emplace_write(in.operand()); }, - [](inst_node &) {}}, - **it); - - if (insert_barriers) { - std::int32_t fence_flags = 0; - for (auto &as : reads_writes::address_spaces) { - if (invisible_rw.raw_war_or_waw(as, rw, aa)) { - fence_flags |= static_cast(as); - invisible_rw.clear(as); - } - } - if (fence_flags != 0) { - it = - reg.insert(it, inst{std::make_unique(fence_flags).release()}); - ++it; // skip over barrier + std::int32_t fence_flags = 0; + for (auto &as : reads_writes::address_spaces) { + if (insert_barriers && in.raw_war_or_waw(as, out, aa)) { + fence_flags |= static_cast(as); + } else { + out.merge(as, in); } } + if (fence_flags != 0) { + tinytc_region *subreg = n->parent(); + auto new_barrier = + subreg->insts() + .insert(n->iterator(), + std::make_unique(fence_flags).release()) + .get(); + // update cfg + cfg.insert_before(n, new_barrier); + q.push(new_barrier); + } + } - invisible_rw.merge(std::move(rw)); + // out has changed, need to enqueue successors again + if (out_size_before_update != get_cardinal(out)) { + for (auto &s : cfg.successors(n)) { + q.push(s); + } } } - - return invisible_rw; -}*/ +} /* Function nodes */ void insert_barrier_pass::run_on_function(function_node &fn) { diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index 3b028be0..60af109e 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -23,9 +23,11 @@ class insert_barrier_pass { constexpr static std::array address_spaces = {address_space::global, address_space::local}; + void clear(); void clear(address_space as); void merge(reads_writes const &other); void merge(reads_writes &&other); + void merge(address_space as, reads_writes const &other); void emplace_read(address_space as, ::tinytc_value const *val); void emplace_write(address_space as, ::tinytc_value const *val); auto read_cardinal(address_space as) const -> std::size_t; @@ -37,13 +39,12 @@ class insert_barrier_pass { bool raw_war_or_waw(address_space as, reads_writes const &rw, aa_results const &aa) const; private: - auto address_space_to_index(address_space as) const -> std::size_t; + static auto address_space_to_index(address_space as) -> std::size_t; std::array, address_spaces.size()> reads, writes; }; - auto run_on_region(region_node ®, aa_results const &aa, - const bool insert_barriers = true) -> reads_writes; + void run_on_region(region_node ®, aa_results const &aa); }; } // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 60cddadf..aef570b5 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -49,7 +49,7 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const auto alloca_it = allocas.begin(); while (alloca_it != allocas.end()) { if (rgn_ops.contains(alloca_it->get())) { - prev_it = reg.insert_after( + prev_it = reg.insts().insert_after( prev_it, std::make_unique(*alloca_it).release()); --prev_it; alloca_it = allocas.erase(alloca_it); diff --git a/src/passes.hpp b/src/passes.hpp index d8961911..3f3cfd7e 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -9,9 +9,15 @@ namespace tinytc { +template void run_function_pass(FunctionPass &&pass, tinytc_prog &p) { + for (auto &fun : p.functions()) { + pass.run_on_function(*fun); + } +} + template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { - for (auto const &func : p.functions()) { - pass.run_on_function(*func); + for (auto const &fun : p.functions()) { + pass.run_on_function(*fun); } } diff --git a/src/prog.cpp b/src/prog.cpp index 1b4de169..11c05e98 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -37,7 +37,7 @@ tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { if (prg == nullptr || fun == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { prg->push_back(func{fun}); }); + return exception_to_status_code([&] { prg->push_back(fun); }); } tinytc_status_t tinytc_prog_release(tinytc_prog_t obj) { diff --git a/src/region.cpp b/src/region.cpp index 043a5bb3..2912c91f 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -29,7 +29,7 @@ tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t if (reg == nullptr || instruction == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { reg->push_back(instruction); }); + return exception_to_status_code([&] { reg->insts().push_back(instruction); }); } void tinytc_region_destroy(tinytc_region_t obj) { delete obj; } diff --git a/src/support/ilist_base.hpp b/src/support/ilist_base.hpp index 6dfd0806..b2d9cc4b 100644 --- a/src/support/ilist_base.hpp +++ b/src/support/ilist_base.hpp @@ -11,21 +11,35 @@ namespace tinytc { -template class ilist_node { +template class ilist_iterator; + +template class ilist_node { public: - auto prev() const -> T * { return prev_; } - void prev(T *prev) { prev_ = prev; } - auto next() const -> T * { return next_; } - void next(T *next) { next_ = next; } + auto prev() const -> NodeT * { return prev_; } + void prev(NodeT *prev) { prev_ = prev; } + auto next() const -> NodeT * { return next_; } + void next(NodeT *next) { next_ = next; } auto sentinel() const -> bool { return sentinel_; } void set_sentinel() { sentinel_ = true; } + auto iterator() -> ilist_iterator { return {this}; } + private: - T *prev_ = nullptr, *next_ = nullptr; + NodeT *prev_ = nullptr, *next_ = nullptr; bool sentinel_ = false; }; +template +class ilist_node_with_parent : public ilist_node { + public: + auto parent() const -> ParentT * { return parent_; } + void parent(ParentT *parent) { parent_ = parent; } + + private: + ParentT *parent_ = nullptr; +}; + template class ilist_iterator { public: using base_type = std::conditional_t, ilist_node>; @@ -67,17 +81,17 @@ template class ilist_iterator { }; template struct ilist_dummy_callback { - static void on_insert(NodeT *) {} - static void on_erase(NodeT *) {} + void node_added(NodeT *) {} + void node_removed(NodeT *) {} }; template > -requires requires(NodeT *node) { +requires requires(IListCallback &cb, NodeT *node) { std::is_base_of_v, NodeT>; - IListCallback::on_insert(node); - IListCallback::on_erase(node); + cb.node_added(node); + cb.node_removed(node); } -class ilist_base { +class ilist_base : protected IListCallback { public: using value_type = NodeT; using size_type = std::size_t; @@ -133,7 +147,7 @@ class ilist_base { // |0| (it -> s) : node{prev->s,next->s}, s{prev->n0,next->n0} // |1| (it -> n0): node{prev->s,next->n0}, n0{prev->node,next->s}, s{prev->n0,next->node} // |1| (it -> s) : n0{prev->s,next->node}, node{prev->n0,next->s}, s{prev->node,next->n0} - IListCallback::on_insert(node); + this->node_added(node); return iterator{node}; } template auto insert(iterator it, ItT begin, ItT end) -> iterator { @@ -165,7 +179,7 @@ class ilist_base { // |1| (it -> n0): s{prev->s,next->s} // |2| (it -> n0): n1{prev->s,next->s}, s{prev->n1,next->n1} // |2| (it -> n1): n0{prev->s,next->s}, s{prev->n0,next->n0} - IListCallback::on_erase(it.get()); + this->node_removed(it.get()); return iterator{next}; } auto erase(iterator begin, iterator end) -> iterator { diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 263571e7..4a514e3b 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -110,7 +110,7 @@ func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref axpby.n %a, %C, %b, %D : f32, memref, f32, memref } axpby.n %a, %A, %b, %B : f32, memref, f32, memref -; CHECK-LABEL: func @if({{.*}} +; CHECK-LABEL: func @if2({{.*}} ; CHECK: if %0 { ; CHECK-NEXT: barrier.global ; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 9f926eff..b2595e05 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -54,10 +54,8 @@ func @region1() { ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %2 ; CHECK-NEXT: axpby.n{{.*}} -; CHECK-NEXT: yield : ; CHECK-NEXT: } ; CHECK-NEXT: lifetime_stop %1 -; CHECK-NEXT: yield : ; CHECK-NEXT: } ; CHECK-NEXT: lifetime_stop %0 } From 1986f1947612d1965adcb3a92ff09fb6760d96cc Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 24 Sep 2024 15:28:04 +0200 Subject: [PATCH 024/297] Source context -> compiler context Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 7 + docs/api/builder_capi.yaml | 1 + docs/api/core_capi.rst | 123 +++++++++--------- docs/api/core_capi.yaml | 21 +-- docs/api/core_cxxapi.rst | 66 ++++++---- docs/api/core_cxxapi.yaml | 12 +- include/tinytc/tinytc.h | 106 +++++++-------- include/tinytc/tinytc.hpp | 206 ++++++++++++++++-------------- include/tinytc/types.h | 28 +++- include/tinytc/types.hpp | 2 + src/CMakeLists.txt | 1 + src/compiler.cpp | 13 +- src/compiler_context.cpp | 101 +++++++++++++++ src/compiler_context.hpp | 56 ++++++++ src/error.cpp | 2 +- src/error.hpp | 5 +- src/node/program_node.cpp | 10 ++ src/node/program_node.hpp | 6 +- src/parser.cpp | 132 +++---------------- src/parser.hpp | 36 ------ src/parser/parse_context.cpp | 4 +- src/parser/parse_context.hpp | 12 +- src/parser/parser_impl.yy | 4 +- src/prog.cpp | 18 ++- src/recipe/small_gemm_batched.cpp | 24 ++-- src/recipe/tall_and_skinny.cpp | 17 ++- tools/offline_compiler/main.cpp | 7 +- tools/opt/main.cpp | 7 +- 28 files changed, 567 insertions(+), 460 deletions(-) create mode 100644 src/compiler_context.cpp create mode 100644 src/compiler_context.hpp diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 920d105e..ebd186cd 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -551,6 +551,8 @@ Program * :ref:`tinytc_prog_dump` + * :ref:`tinytc_prog_get_compiler_context` + * :ref:`tinytc_prog_print_to_file` * :ref:`tinytc_prog_print_to_string` @@ -577,6 +579,11 @@ tinytc_prog_dump .. doxygenfunction:: tinytc_prog_dump +tinytc_prog_get_compiler_context +................................ + +.. doxygenfunction:: tinytc_prog_get_compiler_context + tinytc_prog_print_to_file ......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index e5b44117..e97c45ae 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -87,6 +87,7 @@ Builder C-API: - tinytc_program_create - tinytc_prog_add_function - tinytc_prog_dump + - tinytc_prog_get_compiler_context - tinytc_prog_print_to_file - tinytc_prog_print_to_string - tinytc_prog_release diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 45052d7b..8188e74f 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -48,7 +48,7 @@ Common * :ref:`tinytc_source_t` - * :ref:`tinytc_source_context_t` + * :ref:`tinytc_compiler_context_t` * :ref:`const_tinytc_binary_t` @@ -60,7 +60,9 @@ Common * :ref:`const_tinytc_source_t` - * :ref:`const_tinytc_source_context_t` + * :ref:`const_tinytc_compiler_context_t` + + * :ref:`tinytc_error_reporter_t` Common Enumerations ------------------- @@ -154,10 +156,10 @@ tinytc_source_t .. doxygentypedef:: tinytc_source_t -tinytc_source_context_t -....................... +tinytc_compiler_context_t +......................... -.. doxygentypedef:: tinytc_source_context_t +.. doxygentypedef:: tinytc_compiler_context_t const_tinytc_binary_t ..................... @@ -184,10 +186,15 @@ const_tinytc_source_t .. doxygentypedef:: const_tinytc_source_t -const_tinytc_source_context_t -............................. +const_tinytc_compiler_context_t +............................... -.. doxygentypedef:: const_tinytc_source_context_t +.. doxygentypedef:: const_tinytc_compiler_context_t + +tinytc_error_reporter_t +....................... + +.. doxygentypedef:: tinytc_error_reporter_t Binary ====== @@ -273,6 +280,56 @@ tinytc_prog_compile_to_opencl .. doxygenfunction:: tinytc_prog_compile_to_opencl +Compiler Context +================ + +* Functions + + * :ref:`tinytc_compiler_context_create` + + * :ref:`tinytc_compiler_context_add_source` + + * :ref:`tinytc_compiler_context_set_error_reporter` + + * :ref:`tinytc_compiler_context_report_error` + + * :ref:`tinytc_compiler_context_release` + + * :ref:`tinytc_compiler_context_retain` + +Compiler Context Functions +-------------------------- + +tinytc_compiler_context_create +.............................. + +.. doxygenfunction:: tinytc_compiler_context_create + +tinytc_compiler_context_add_source +.................................. + +.. doxygenfunction:: tinytc_compiler_context_add_source + +tinytc_compiler_context_set_error_reporter +.......................................... + +.. doxygenfunction:: tinytc_compiler_context_set_error_reporter + +tinytc_compiler_context_report_error +.................................... + +.. doxygenfunction:: tinytc_compiler_context_report_error + +tinytc_compiler_context_release +............................... + +.. doxygenfunction:: tinytc_compiler_context_release + +tinytc_compiler_context_retain +.............................. + +.. doxygenfunction:: tinytc_compiler_context_retain + Device Info =========== @@ -565,53 +622,3 @@ tinytc_source_retain .. doxygenfunction:: tinytc_source_retain -Source Context -============== - -* Functions - - * :ref:`tinytc_source_context_create` - - * :ref:`tinytc_source_context_add_source` - - * :ref:`tinytc_source_context_get_error_log` - - * :ref:`tinytc_source_context_report_error` - - * :ref:`tinytc_source_context_release` - - * :ref:`tinytc_source_context_retain` - -Source Context Functions ------------------------- - -tinytc_source_context_create -............................ - -.. doxygenfunction:: tinytc_source_context_create - -tinytc_source_context_add_source -................................ - -.. doxygenfunction:: tinytc_source_context_add_source - -tinytc_source_context_get_error_log -................................... - -.. doxygenfunction:: tinytc_source_context_get_error_log - -tinytc_source_context_report_error -.................................. - -.. doxygenfunction:: tinytc_source_context_report_error - -tinytc_source_context_release -............................. - -.. doxygenfunction:: tinytc_source_context_release - -tinytc_source_context_retain -............................ - -.. doxygenfunction:: tinytc_source_context_retain - diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 7f8a4f9d..d921060b 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -22,13 +22,14 @@ Core C-API: - tinytc_recipe_t - tinytc_recipe_handler_t - tinytc_source_t - - tinytc_source_context_t + - tinytc_compiler_context_t - const_tinytc_binary_t - const_tinytc_core_info_t - const_tinytc_recipe_t - const_tinytc_recipe_handler_t - const_tinytc_source_t - - const_tinytc_source_context_t + - const_tinytc_compiler_context_t + - tinytc_error_reporter_t Binary: function: - tinytc_binary_create @@ -43,6 +44,14 @@ Core C-API: - tinytc_run_function_pass - tinytc_list_function_passes - tinytc_prog_compile_to_opencl + Compiler Context: + function: + - tinytc_compiler_context_create + - tinytc_compiler_context_add_source + - tinytc_compiler_context_set_error_reporter + - tinytc_compiler_context_report_error + - tinytc_compiler_context_release + - tinytc_compiler_context_retain Device Info: enum: - tinytc_core_feature_flag_t @@ -89,11 +98,3 @@ Core C-API: - tinytc_source_get_extensions - tinytc_source_release - tinytc_source_retain - Source Context: - function: - - tinytc_source_context_create - - tinytc_source_context_add_source - - tinytc_source_context_get_error_log - - tinytc_source_context_report_error - - tinytc_source_context_release - - tinytc_source_context_retain diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index d1919835..f7c60e0c 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -28,6 +28,10 @@ Common * :ref:`unique_handle` +* Typedefs + + * :ref:`error_reporter_t` + Common Enumerations ------------------- @@ -72,6 +76,14 @@ unique_handle .. doxygenclass:: tinytc::unique_handle +Common Typedefs +--------------- + +error_reporter_t +................ + +.. doxygentypedef:: tinytc::error_reporter_t + Binary ====== @@ -140,6 +152,33 @@ compile_to_opencl .. doxygenfunction:: tinytc::compile_to_opencl +Compiler Context +================ + +* Functions + + * :ref:`make_compiler_context` + +* Classes + + * :ref:`compiler_context` + +Compiler Context Functions +-------------------------- + +make_compiler_context +..................... + +.. doxygenfunction:: tinytc::make_compiler_context + +Compiler Context Classes +------------------------ + +compiler_context +................ + +.. doxygenclass:: tinytc::compiler_context + Device Info =========== @@ -363,30 +402,3 @@ source .. doxygenclass:: tinytc::source -Source Context -============== - -* Functions - - * :ref:`make_source_context` - -* Classes - - * :ref:`source_context` - -Source Context Functions ------------------------- - -make_source_context -................... - -.. doxygenfunction:: tinytc::make_source_context - -Source Context Classes ----------------------- - -source_context -.............. - -.. doxygenclass:: tinytc::source_context - diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 9c0caa41..13355dd6 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -12,6 +12,8 @@ Core C++-API: class: - tinytc::shared_handle - tinytc::unique_handle + typedef: + - tinytc::error_reporter_t Binary: enum: - tinytc::bundle_format @@ -24,6 +26,11 @@ Core C++-API: - tinytc::run_function_pass - tinytc::list_function_passes - tinytc::compile_to_opencl + Compiler Context: + function: + - tinytc::make_compiler_context + class: + - tinytc::compiler_context Device Info: enum: - tinytc::core_feature_flag @@ -61,8 +68,3 @@ Core C++-API: Source: class: - tinytc::source - Source Context: - function: - - tinytc::make_source_context - class: - - tinytc::source_context diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 458af722..63762240 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -891,11 +891,13 @@ TINYTC_EXPORT void tinytc_func_destroy(tinytc_func_t fun); * @brief Create program * * @param prg [out] pointer to the prog object created + * @param ctx [in] compiler context object * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -911,6 +913,18 @@ TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, */ TINYTC_EXPORT tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun); +/** + * @brief Get context object from program object + * + * @param prg [in] program object + * @param ctx [out] pointer to context object; reference count is increased so the user needs to + * call tinytc_compiler_context_release to clean up + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx); + /** * @brief Release program object * @@ -1100,22 +1114,22 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj); * * @param prg [out] pointer to prog object created * @param filename [in] path to source file - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t ctx); + tinytc_compiler_context_t ctx); /** * @brief Parser tensor language source from stdin and create prog * * @param prg [out] pointer to prog object created - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t ctx); +TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx); /** * @brief Parser tensor language source from string and create prog @@ -1123,90 +1137,88 @@ TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_sour * @param prg [out] pointer to prog object created * @param source_size [in] length of source string * @param source [in] source string - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, - char const *source, tinytc_source_context_t ctx); + char const *source, + tinytc_compiler_context_t ctx); /** - * @brief Create source context + * @brief Create context * - * The source context stores the tensor language source and enhaces error messages with - * source code context. + * The context stores the tensor language source and reports enhaces error messages with + * source code context. Moreover, the context caches data such as types and constants. * - * @param ctx [out] pointer to the source context object created + * @param ctx [out] pointer to the context object created * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx); /** * @brief Add source context * - * Manually add a source file to the source context that can be referenced in a tinytc_location. + * Manually add a source file to the context that can be referenced in a tinytc_location. * Useful to enhance error messages when using the builder methods and classes. * - * @param ctx [in] source context object + * @param ctx [in] context object * @param name [in] source name * @param text [in] source text * @param source_id [out] pointer to source id * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, - char const *name, char const *text, - int32_t *source_id); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, + char const *name, char const *text, + int32_t *source_id); /** - * @brief Get error log + * @brief Set error reporter * - * The string's memory is owned by source context. - * Note that the pointer may invalidated by any function call involving the source context object, - * so the string should be copied or printed right after a call to this function. + * Error reporting function that is called whenever an error occurs in the parser or the builder. * - * @param ctx [in] source context object - * @param log [out] pointer to string + * @param ctx [in] context object + * @param reporter [in] error reporting callback; set to nullptr to disable reporting + * @param user_data [in][optional] pointer to user data that is passed to the callback; can be + * nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_error_reporter( + tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, void *user_data); /** * @brief Report an error and augment the error with source context * - * @param ctx [in] source context object + * @param ctx [in] context object * @param location [in] source location * @param what [in] error description - * @param append [in] true: append to error log, false: clear error log * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, - tinytc_bool_t append); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_report_error( + tinytc_compiler_context_t ctx, const tinytc_location_t *location, char const *what); /** - * @brief Release source context object + * @brief Release context object * * Decreases reference count by 1, free memory if reference count is 0. * - * @param obj [inout] source context object + * @param obj [inout] context object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj); /** - * @brief Increase reference count of source context object by 1 + * @brief Increase reference count of context object by 1 * - * @param obj [inout] source context object + * @param obj [inout] context object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj); //////////////////////////// ///////// Compiler ///////// @@ -1219,14 +1231,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_context_retain(tinytc_source_context * @param prg [inout] tensor program; modified as compiler pass is run * @param info [in][optional] core info object; might be nullptr if core info is not required for * pass - * @param ctx [inout][optional] source context object to save extended error messages that are - * enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx); + const_tinytc_core_info_t info); /** * @brief List function passes @@ -1245,14 +1254,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, * @param src [out] pointer to the source object created * @param prg [inout] tensor program; modified as compiler passes are run * @param info [in] core info object - * @param ctx [inout][optional] source context object to save extended error messages that are - * enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx); + const_tinytc_core_info_t info); /** * @brief Get source text @@ -1424,7 +1430,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); * @param strideB [in] Number of elements between B-matrices * @param ldC [in] Leading dimension of C * @param strideC [in] Number of elements between C-matrices - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ @@ -1432,7 +1438,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_create( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC, - tinytc_source_context_t ctx); + tinytc_compiler_context_t ctx); /** * @brief Set kernel arguments for small GEMM batched recipe @@ -1488,13 +1494,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( * @param K [in] Number columns of A, number of rows of B * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t N, - int64_t K, int32_t M_block_size, tinytc_source_context_t ctx); + int64_t K, int32_t M_block_size, tinytc_compiler_context_t ctx); /** * @brief Returns a tall and skinny recipe with additional specialization constants @@ -1526,14 +1532,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( * @param ldC [in] Leading dimension of C; can be TINYTC_DYNAMIC * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return */ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx); + tinytc_compiler_context_t ctx); /** * @brief Suggest an M block size for tall and skinny recipe diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 13256bfb..74178cb1 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -281,6 +281,75 @@ template class unique_handle { T obj_; }; +//////////////////////////// +///// Compiler context ///// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_retain(handle); + } + static auto release(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_release(handle); + } +}; +} // namespace internal + +//! @brief Reference-counting wrapper for tinytc_compiler_context_t +class compiler_context : public shared_handle { + public: + using shared_handle::shared_handle; + + /** + * @brief Add compiler to context + * + * @param name File name + * @param text Source text + * + * @return Source id (should be set in position.source_id) + */ + inline auto add_source(char const *name, char const *text) -> std::int32_t { + std::int32_t source_id; + CHECK_STATUS(tinytc_compiler_context_add_source(obj_, name, text, &source_id)); + return source_id; + } + /** + * @brief Set error reporter + * + * Error reporting function that is called whenever an error occurs in the parser or the + * builder. + * + * @param reporter error reporting callback + * @param user_data pointer to user data that is passed to the callback + * + * @return tinytc_status_success on success and error otherwise + */ + inline void set_error_reporter(error_reporter_t reporter, void *user_data) { + CHECK_STATUS(tinytc_compiler_context_set_error_reporter(obj_, reporter, user_data)); + } + /** + * @brief Enhance error message with compiler context; useful when builder is used + * + * @param loc Source location + * @param what Error description + */ + inline void report_error(location const &loc, char const *what) { + CHECK_STATUS(tinytc_compiler_context_report_error(obj_, &loc, what)); + } +}; + +/** + * @brief Create compiler context + * + * @return Compiler context + */ +inline auto make_compiler_context() -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_compiler_context_create(&ctx)); + return compiler_context{ctx}; +} + //////////////////////////// ///////// Data type //////// //////////////////////////// @@ -1351,6 +1420,16 @@ class prog : public shared_handle { * @brief Dump program to stderr */ void dump() const { CHECK_STATUS(tinytc_prog_dump(obj_)); } + /** + * @brief Get context + * + * @return Compiler context + */ + auto get_compiler_context() const -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_prog_get_compiler_context(obj_, &ctx)); + return compiler_context{ctx}; + } /** * @brief Dump program to file * @@ -1374,13 +1453,14 @@ class prog : public shared_handle { /** * @brief Make program * + * @param ctx Compiler context * @param loc Source code location * * @return Program */ -inline prog make_program(location const &loc = {}) { +inline prog make_program(compiler_context const &ctx, location const &loc = {}) { tinytc_prog_t prg; - CHECK_STATUS_LOC(tinytc_program_create(&prg, &loc), loc); + CHECK_STATUS_LOC(tinytc_program_create(&prg, ctx.get(), &loc), loc); return prog{prg}; } @@ -1570,6 +1650,8 @@ class function_builder { * @brief creates function \@name * * @param name Function name + * @param loc Source code location + * */ inline function_builder(std::string name, location const &loc = {}) : name_(std::move(name)), body_{nullptr}, loc_(loc) {} @@ -1577,8 +1659,6 @@ class function_builder { /** * @brief Returns built product * - * @param loc Source code location - * * @return Function */ inline func get_product() && { @@ -1655,9 +1735,12 @@ class program_builder { /** * @brief ctor * + * @param ctx Compiler context * @param loc Source code location + * */ - program_builder(location const &loc = {}) : prg_{make_program(loc)} {} + program_builder(compiler_context const &ctx, location const &loc = {}) + : prg_{make_program(ctx, loc)} {} /** * @brief create function \@name with functor f(function_builder&) -> void @@ -1681,8 +1764,6 @@ class program_builder { /** * @brief Returns built product * - * @param loc Source code location - * * @return Program */ inline prog get_product() && { return std::move(prg_); } @@ -1809,112 +1890,43 @@ inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_ ////////// Parser ////////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_retain(handle); - } - static auto release(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_source_context_t -class source_context : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Add source to context - * - * @param name File name - * @param text Source text - * - * @return Source id (should be set in position.source_id) - */ - inline auto add_source(char const *name, char const *text) -> std::int32_t { - std::int32_t source_id; - CHECK_STATUS(tinytc_source_context_add_source(obj_, name, text, &source_id)); - return source_id; - } - /** - * @brief Get error log - * - * @return C-string that is valid as long as source_context is not modified; empty string if - * source_context is empty - */ - inline auto get_error_log() const noexcept -> char const * { - if (obj_) { - char const *log; - // No need to call CHECK_STATUS, as the only possible error code is - // tinytc_status_invalid_arguments but we only pass valid arguments - tinytc_source_context_get_error_log(obj_, &log); - return log; - } - return ""; - } - /** - * @brief Enhance error message with source context; useful when builder is used - * - * @param loc Source location - * @param what Error description - * @param append True: append to error log; false: clear error log - */ - inline void report_error(location const &loc, char const *what, bool append = true) { - CHECK_STATUS(tinytc_source_context_report_error(obj_, &loc, what, - static_cast(append))); - } -}; - -/** - * @brief Create source context - * - * @return Source context - */ -inline auto make_source_context() -> source_context { - tinytc_source_context_t ctx; - CHECK_STATUS(tinytc_source_context_create(&ctx)); - return source_context{ctx}; -} - /** * @brief Parse source text from file * * @param filename Filename - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Program */ -inline auto parse_file(char const *filename, source_context source_ctx = {}) -> prog { +inline auto parse_file(char const *filename, compiler_context ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_file(&prg, filename, source_ctx.get())); + CHECK_STATUS(tinytc_parse_file(&prg, filename, ctx.get())); return prog(prg); } /** * @brief Parse source text from stdin * - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Program */ -inline auto parse_stdin(source_context source_ctx = {}) -> prog { +inline auto parse_stdin(compiler_context ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_stdin(&prg, source_ctx.get())); + CHECK_STATUS(tinytc_parse_stdin(&prg, ctx.get())); return prog(prg); } /** * @brief Parse source text from string * * @param src Source text - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Porgram */ -inline auto parse_string(std::string const &src, source_context source_ctx = {}) -> prog { +inline auto parse_string(std::string const &src, compiler_context ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), source_ctx.get())); + CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), ctx.get())); return prog(prg); } @@ -2045,12 +2057,9 @@ inline auto make_binary(bundle_format format, std::size_t data_size, std::uint8_ * @param pass_name name of function pass; cf. list_function_passes * @param prg tensor program; modified as compiler pass is run * @param info core info object; might be nullptr if core info is not required for pass - * @param ctx source context object to save extended error messages that are - * enhanced with source code context */ -inline void run_function_pass(char const *pass_name, prog prg, core_info info = {}, - source_context ctx = {}) { - CHECK_STATUS(tinytc_run_function_pass(pass_name, prg.get(), info.get(), ctx.get())); +inline void run_function_pass(char const *pass_name, prog prg, core_info info = {}) { + CHECK_STATUS(tinytc_run_function_pass(pass_name, prg.get(), info.get())); } /** @@ -2068,13 +2077,12 @@ inline void list_function_passes(std::uint32_t &names_size, char const *const *& * * @param prg Program * @param info Core info - * @param ctx Source context for improved error reporting * * @return Source */ -inline auto compile_to_opencl(prog prg, core_info const &info, source_context ctx = {}) -> source { +inline auto compile_to_opencl(prog prg, core_info const &info) -> source { tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, prg.get(), info.get(), ctx.get())); + CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, prg.get(), info.get())); return source{src}; } @@ -2245,7 +2253,7 @@ class small_gemm_batched : public recipe { * @param strideB Stride of B-matrices * @param ldC Leading dimension of an C matrix * @param strideC Stride of C-matrices - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Small GEMM batched recipe */ @@ -2253,7 +2261,7 @@ inline auto make_small_gemm_batched(core_info const &info, scalar_type ty, trans transpose tB, std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t strideA, std::int64_t ldB, std::int64_t strideB, std::int64_t ldC, std::int64_t strideC, - source_context ctx = {}) -> small_gemm_batched { + compiler_context ctx = {}) -> small_gemm_batched { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_small_gemm_batched_create( &rec, info.get(), static_cast(ty), @@ -2302,13 +2310,13 @@ class tall_and_skinny : public recipe { * @param N Number of columns of B and C * @param K Number of columns of A, number of rows of B * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Tall and skinny recipe */ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int64_t N, std::int64_t K, std::int32_t M_block_size = 0, - source_context ctx = {}) -> tall_and_skinny { + compiler_context ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create( &rec, info.get(), static_cast(ty), N, K, M_block_size, ctx.get())); @@ -2329,7 +2337,7 @@ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int * @param ldB Leading dimension of B; can be dynamic * @param ldC Leading dimension of C; can be dynamic * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Tall and skinny recipe */ @@ -2337,7 +2345,7 @@ inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size = 0, - source_context ctx = {}) -> tall_and_skinny { + compiler_context ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, diff --git a/include/tinytc/types.h b/include/tinytc/types.h index ff1ff8b8..812e3346 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -394,13 +394,13 @@ typedef struct tinytc_source *tinytc_source_t; //! @brief const source handle typedef const struct tinytc_source *const_tinytc_source_t; -//! @struct tintyc_source_context -//! @brief Opaque struct for source context -struct tinytc_source_context; -//! @brief source_context handle -typedef struct tinytc_source_context *tinytc_source_context_t; -//! @brief const source_context handle -typedef const struct tinytc_source_context *const_tinytc_source_context_t; +//! @struct tintyc_compiler_context +//! @brief Opaque struct for compiler context +struct tinytc_compiler_context; +//! @brief compiler_context handle +typedef struct tinytc_compiler_context *tinytc_compiler_context_t; +//! @brief const compiler_context handle +typedef const struct tinytc_compiler_context *const_tinytc_compiler_context_t; //! @struct tinytc_binary; //! @brief Opaque struct for a binary @@ -443,6 +443,20 @@ typedef struct tinytc_location { tinytc_position_t end; ///< End position } tinytc_location_t; +//////////////////////////// +///////// Callbacks //////// +//////////////////////////// + +/** + * @brief Signature for error reporting callback + * + * @param what Error description + * @param location Source code location + * @param user_data user data that is passed on to callback + */ +typedef void (*tinytc_error_reporter_t)(char const *what, const tinytc_location_t *location, + void *user_data); + #ifdef __cplusplus } #endif diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index cfcf5b04..7cb16de9 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -294,6 +294,8 @@ enum class support_level { using position = ::tinytc_position; //! @brief Alias for tinytc_location in namespace tinytc using location = ::tinytc_location; +//! @brief Alias for tinytc_error_reporter_t in namespace tinytc +using error_reporter_t = ::tinytc_error_reporter_t; } // namespace tinytc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 87a4f4bb..9b7737e5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ set(SOURCES binary.cpp codegen_tools.cpp compiler.cpp + compiler_context.cpp data_type.cpp device_info.cpp error.cpp diff --git a/src/compiler.cpp b/src/compiler.cpp index 74cea869..c3c83d84 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -25,20 +25,18 @@ #include #include +#include #include #include #include #include -#include - using namespace tinytc; extern "C" { tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx) { + const_tinytc_core_info_t info) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } @@ -57,7 +55,7 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr #undef FUNCTION_PASS_WITH_INFO throw status::unknown_pass_name; }, - ctx); + prg->get_context()); } tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *const **names) { @@ -78,8 +76,7 @@ tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *co } tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx) { + const_tinytc_core_info_t info) { if (src == nullptr || prg == nullptr || info == nullptr) { return tinytc_status_invalid_arguments; } @@ -111,6 +108,6 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ info->core_features()) .release(); }, - ctx); + prg->get_context()); } } diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp new file mode 100644 index 00000000..9ed9fd34 --- /dev/null +++ b/src/compiler_context.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context.hpp" + +#include "compiler_context.hpp" +#include "error.hpp" +#include "location.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { +void default_error_reporter(char const *what, const tinytc_location_t *, void *) { + std::cerr << what << std::endl; +} + +} // namespace tinytc + +using namespace tinytc; + +extern "C" { + +auto tinytc_compiler_context::source_name(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.name.c_str(), si.name.size()}; + } + return {unavailable_source_name, sizeof(unavailable_source_name) / sizeof(char) - 1}; +} +auto tinytc_compiler_context::source_text(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.text.c_str(), si.text.size()}; + } + return {"", 0}; +} +void tinytc_compiler_context::report_error(location const &l, char const *what) { + auto [name, name_size] = source_name(l.begin.source_id); + auto [text, text_size] = source_text(l.begin.source_id); + auto err = report_error_with_context(text, text_size, name, l, what); + reporter_(err.c_str(), &l, user_data_); +} + +tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *ctx = std::make_unique().release(); }); +} + +tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, char const *name, + char const *text, int32_t *source_id) { + if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); +} + +tinytc_status_t tinytc_compiler_context_set_error_reporter(tinytc_compiler_context_t ctx, + tinytc_error_reporter_t reporter, + void *user_data) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->set_error_reporter(reporter, user_data); }); +} + +tinytc_status_t tinytc_compiler_context_report_error(tinytc_compiler_context_t ctx, + const tinytc_location_t *location, + char const *what) { + if (ctx == nullptr || location == nullptr || what == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->report_error(*location, what); }); +} + +tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} + +tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp new file mode 100644 index 00000000..01db18ed --- /dev/null +++ b/src/compiler_context.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_20240924_HPP +#define COMPILER_CONTEXT_20240924_HPP + +#include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +#include +#include +#include +#include + +namespace tinytc { +void default_error_reporter(char const *what, const tinytc_location_t *location, void *user_data); +} + +struct tinytc_compiler_context : tinytc::reference_counted { + public: + constexpr static const char unavailable_source_name[] = "Source name unavailable"; + + inline void set_error_reporter(tinytc::error_reporter_t reporter, void *user_data) { + reporter_ = reporter; + user_data_ = user_data; + } + + // source / error handling + inline auto add_source(std::string name, std::string text) -> std::int32_t { + sources_.emplace_back(source_input{std::move(name), std::move(text)}); + return static_cast(sources_.size()); + } + inline auto add_source(char const *name, char const *text) -> std::int32_t { + sources_.emplace_back(source_input{std::string(name), std::string(text)}); + return static_cast(sources_.size()); + } + auto source_name(std::int32_t source_id) -> std::pair; + auto source_text(std::int32_t source_id) -> std::pair; + void report_error(tinytc_location const &l, char const *what); + + private: + struct source_input { + std::string name, text; + }; + + inline bool has_source_id(std::int32_t source_id) const { + return source_id >= 1 && static_cast(source_id) <= sources_.size(); + } + + tinytc::error_reporter_t reporter_ = &tinytc::default_error_reporter; + void *user_data_ = nullptr; + std::vector sources_; +}; + +#endif // COMPILER_CONTEXT_20240924_HPP diff --git a/src/error.cpp b/src/error.cpp index 93f43f03..fce5d9b0 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -74,7 +74,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri oerr << file_name << ":"; print_range(oerr, l.begin, l.end); oerr << ": " << what; - return oerr.str(); + return std::move(oerr).str(); } } // namespace tinytc diff --git a/src/error.hpp b/src/error.hpp index 3b48d27c..6784d180 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -4,7 +4,7 @@ #ifndef ERROR_20240410_HPP #define ERROR_20240410_HPP -#include "parser.hpp" +#include "compiler_context.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -47,7 +47,8 @@ class internal_compiler_error : public std::exception { }; template -auto exception_to_status_code(F &&f, tinytc_source_context_t context = nullptr) -> tinytc_status_t { +auto exception_to_status_code(F &&f, + tinytc_compiler_context_t context = nullptr) -> tinytc_status_t { try { f(); } catch (internal_compiler_error const &e) { diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp index a8479b31..03ca8492 100644 --- a/src/node/program_node.cpp +++ b/src/node/program_node.cpp @@ -4,9 +4,19 @@ #include "node/program_node.hpp" #include "node/function_node.hpp" +using namespace tinytc; + +extern "C" { + +tinytc_prog::tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc) + : ctx_{std::move(ctx)} { + loc(lc); +} + tinytc_prog::~tinytc_prog() { for (auto &f : functions()) { tinytc_func_destroy(f); } } +} diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index aebb3480..9511cb0e 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -7,6 +7,7 @@ #include "location.hpp" #include "reference_counted.hpp" #include "support/util.hpp" +#include "tinytc/tinytc.hpp" #include #include @@ -19,9 +20,11 @@ using const_func_range = iterator_range_wrapper; struct tinytc_prog final : tinytc::reference_counted { public: - inline tinytc_prog(tinytc_location const &lc = {}) { loc(lc); } + tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc = {}); ~tinytc_prog(); + inline auto get_context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } @@ -43,6 +46,7 @@ struct tinytc_prog final : tinytc::reference_counted { inline void push_back(tinytc_func_t fun) { funcs_.push_back(fun); } private: + tinytc::compiler_context ctx_; std::vector funcs_; tinytc_location loc_; }; diff --git a/src/parser.cpp b/src/parser.cpp index 9abe3fb7..b6ce6bf4 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -18,68 +18,30 @@ #include #include #include -#include #include namespace tinytc { -auto parse(std::uint64_t size, char const *input) -> prog { - auto const initial_loc = location{position{0, 1, 1}, position{0, 1, 1}}; - auto lex = lexer(size, input, initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); - if (p() == 0) { - return ctx.program(); - } - return prog{}; -} -} // namespace tinytc -using namespace tinytc; - -extern "C" { - -tinytc_source_context::tinytc_source_context() {} - -auto tinytc_source_context::parse(std::string name, std::string text) -> prog { - sources_.emplace_back(source_input{std::move(name), std::move(text)}); - std::int32_t source_id = static_cast(sources_.size()); +auto parse(std::string name, std::string text, compiler_context const &compiler_ctx) -> prog { + std::int32_t source_id = compiler_ctx->add_source(std::move(name), std::move(text)); auto const initial_loc = location{position{source_id, 1, 1}, position{source_id, 1, 1}}; - auto const &input = sources_.back(); - auto lex = lexer(input.text.size(), input.text.c_str(), initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); + auto [ir, ir_size] = compiler_ctx->source_text(source_id); + auto lex = lexer(ir_size, ir, initial_loc); + auto parse_ctx = parse_context{compiler_ctx}; + auto p = parser(lex, parse_ctx); if (p() == 0) { - return ctx.program(); - } - last_error_log_.clear(); - for (auto const &err : ctx.errors()) { - last_error_log_ = report_error_with_context(input.text.c_str(), input.text.size(), - input.name, err.first, err.second); + return parse_ctx.program(); } return prog{}; } -void tinytc_source_context::report_error(location const &l, char const *what, bool append) { - auto err = std::string{}; - if (l.begin.source_id >= 1 && static_cast(l.begin.source_id) <= sources_.size()) { - auto const &src = sources_[l.begin.source_id - 1]; - err = report_error_with_context(src.text.c_str(), src.text.size(), src.name, l, what); - } else { - err = (std::ostringstream{} << "\n" - << l << ": " << what) - .str(); - } - if (append) { - last_error_log_ += std::move(err); - } else { - last_error_log_ = std::move(err); - } -} +} // namespace tinytc + +using namespace tinytc; tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || filename == nullptr) { return tinytc_status_invalid_arguments; } @@ -89,9 +51,8 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, throw status::file_io_error; } auto ir = std::string(std::istreambuf_iterator{ir_stream}, {}); - - auto prog = source_ctx ? source_ctx->parse(std::string(filename), std::move(ir)) - : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? compiler_context{ctx} : make_compiler_context(); + auto prog = parse(std::string(filename), std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -99,14 +60,14 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, }); } -tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t source_ctx) { +tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { auto ir = std::string(std::istreambuf_iterator{std::cin}, {}); - auto prog = - source_ctx ? source_ctx->parse("", std::move(ir)) : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? compiler_context{ctx} : make_compiler_context(); + auto prog = parse("", std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -115,71 +76,16 @@ tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t s } tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, char const *source, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || source_size == 0 || source == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto prog = source_ctx - ? source_ctx->parse("", std::string(source, source + source_size)) - : parse(source_size, source); + auto ctx_ = ctx ? compiler_context{ctx} : make_compiler_context(); + auto prog = parse("", std::string(source, source + source_size), ctx_); if (!prog) { throw status::parse_error; } *prg = prog.release(); }); } - -tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx) { - if (ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *ctx = std::make_unique().release(); }); -} - -tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, char const *name, - char const *text, int32_t *source_id) { - if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); -} - -tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log) { - if (ctx == nullptr || log == nullptr) { - return tinytc_status_invalid_arguments; - } - *log = ctx->last_error_log().c_str(); // last_error_log and c_str are noexcept - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, tinytc_bool_t append) { - if (ctx == nullptr || location == nullptr || what == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { ctx->report_error(*location, what, bool(append)); }); -} - -tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/parser.hpp b/src/parser.hpp index 0d1cd1a9..f613edf8 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -4,48 +4,12 @@ #ifndef PARSER_20230614_HPP #define PARSER_20230614_HPP -#include "reference_counted.hpp" #include "tinytc/tinytc.hpp" -#include "tinytc/types.h" #include -#include -#include namespace tinytc { auto parse(std::uint64_t size, char const *input) -> prog; } -/** - * @brief Source manager - * - * The source manager can parse tensor programs from files, stdin, or memory. - * Source code is stored in the manager such that error messages can be enhanced - * with code context. - */ -struct tinytc_source_context : tinytc::reference_counted { - public: - //! @brief ctor - tinytc_source_context(); - - auto parse(std::string name, std::string text) -> tinytc::prog; - - inline auto add_source(char const *name, char const *text) -> std::int32_t { - sources_.emplace_back(source_input{std::string(name), std::string(text)}); - return static_cast(sources_.size()); - } - - //! Annotate context to error message - void report_error(tinytc_location const &l, char const *what, bool append = false); - //! Return error log of last parse call - inline auto last_error_log() const noexcept -> std::string const & { return last_error_log_; } - - private: - struct source_input { - std::string name, text; - }; - std::vector sources_; - std::string last_error_log_; -}; - #endif // PARSER_20230614_HPP diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 97978947..0de4d002 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -36,8 +36,8 @@ value parse_context::val(std::string const &id, location const &l) { throw parser::syntax_error(l, "Undefined identifier %" + id); } -void parse_context::add_error(location const &loc, std::string const &what) { - errors_.emplace_back(std::make_pair(loc, what)); +void parse_context::report_error(location const &loc, std::string const &what) { + compiler_ctx_->report_error(loc, what.c_str()); } } // namespace tinytc diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index d6f8a25d..251286cb 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -16,7 +16,9 @@ namespace tinytc { class parse_context { public: - inline parse_context() { id_map_.push_back({}); } + inline parse_context(compiler_context compiler_ctx) : compiler_ctx_(compiler_ctx) { + id_map_.push_back({}); + } inline auto program() { return program_; } inline void program(prog p) { program_ = std::move(p); } @@ -25,16 +27,14 @@ class parse_context { void val(std::string const &id, value val, location const &l); value val(std::string const &id, location const &l); - void add_error(location const &loc, std::string const &what); + void report_error(location const &loc, std::string const &what); - inline auto errors() const -> std::vector> const & { - return errors_; - } + auto get_compiler_context() -> compiler_context const & { return compiler_ctx_; } private: + compiler_context compiler_ctx_; std::vector> id_map_; prog program_; - std::vector> errors_; }; } // namespace tinytc diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 475fe571..b3f79cc0 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -230,7 +230,7 @@ %% prog: func_list { - auto p = prog { std::make_unique(@prog).release() }; + auto p = prog { std::make_unique(ctx.get_compiler_context(), @prog).release() }; ctx.program(p); $$ = std::move(p); for (auto& f : $func_list) { @@ -1052,6 +1052,6 @@ slice_size: namespace tinytc { void parser::error(location_type const& l, std::string const& m) { - ctx.add_error(l, m); + ctx.report_error(l, m); } } diff --git a/src/prog.cpp b/src/prog.cpp index 11c05e98..609907b9 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -25,12 +25,15 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, const tinytc_location_t *loc) { +tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *prg = std::make_unique(get_optional(loc)).release(); }); + return exception_to_status_code([&] { + *prg = std::make_unique(compiler_context{ctx, true}, get_optional(loc)) + .release(); + }); } tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { @@ -66,6 +69,15 @@ tinytc_status_t tinytc_prog_dump(const_tinytc_prog_t prg) { return exception_to_status_code([&] { run_function_pass(dump_ir_pass{std::cerr}, *prg); }); } +tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx) { + if (prg == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return tinytc::exception_to_status_code( + [&] { *ctx = tinytc::compiler_context(prg->get_context()).release(); }); +} + tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { if (prg == nullptr || filename == nullptr) { return tinytc_status_invalid_arguments; diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index b65414b6..ff15111f 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -46,12 +46,11 @@ auto small_gemm_batched_recipe::kernel_name(int kernel_num) const -> char const using namespace tinytc; extern "C" { -tinytc_status_t -tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, - tinytc_scalar_type_t ty, tinytc_transpose_t tA, - tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, - int64_t ldA, int64_t strideA, int64_t ldB, int64_t strideB, - int64_t ldC, int64_t strideC, tinytc_source_context_t ctx) { +tinytc_status_t tinytc_recipe_small_gemm_batched_create( + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, + tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, + int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC, + tinytc_compiler_context_t ctx) { if (recipe == nullptr || info == nullptr || M == TINYTC_DYNAMIC || N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC || ldA == TINYTC_DYNAMIC || strideA == TINYTC_DYNAMIC || ldB == TINYTC_DYNAMIC || strideB == TINYTC_DYNAMIC || ldC == TINYTC_DYNAMIC || @@ -59,11 +58,10 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co return tinytc_status_invalid_arguments; } + auto ctx_ = ctx ? compiler_context{ctx} : make_compiler_context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "recipe/small_gemm_batched.cpp", "", &source_id)); - } + TINYTC_CHECK_STATUS(tinytc_compiler_context_add_source( + ctx_.get(), "recipe/small_gemm_batched.cpp", "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -119,7 +117,7 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co my_loc()); }; auto p = [&] { - auto pb = program_builder{my_loc()}; + auto pb = program_builder{ctx_, my_loc()}; pb.create( small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), [&](function_builder &fb) { kernel(fb, true); }, my_loc()); @@ -130,11 +128,11 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co return std::move(pb).get_product(); }(); tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); + CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); *recipe = std::make_unique(std::move(p), source(src), ty_) .release(); }, - ctx); + ctx_.get()); } tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 43464360..8d016479 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -56,7 +56,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t N, int64_t K, int32_t M_block_size, - tinytc_source_context_t ctx) { + tinytc_compiler_context_t ctx) { return tinytc_recipe_tall_and_skinny_create_specialized(recipe, info, ty, TINYTC_DYNAMIC, N, K, TINYTC_DYNAMIC, TINYTC_DYNAMIC, TINYTC_DYNAMIC, M_block_size, ctx); @@ -65,16 +65,15 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create(tinytc_recipe_t *recipe, tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx) { + tinytc_compiler_context_t ctx) { if (recipe == nullptr || info == nullptr || N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC) { return tinytc_status_invalid_arguments; } + auto ctx_ = ctx ? compiler_context{ctx} : make_compiler_context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "recipe/tall_and_skinny.cpp", "", &source_id)); - } + TINYTC_CHECK_STATUS(tinytc_compiler_context_add_source(ctx_.get(), "recipe/tall_and_skinny.cpp", + "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -158,7 +157,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( }; auto p = [&] { - auto pb = program_builder{my_loc()}; + auto pb = program_builder{ctx_, my_loc()}; pb.create( tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), [&](function_builder &fb) { kernel(fb, true); }, my_loc()); @@ -168,12 +167,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return std::move(pb).get_product(); }(); tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); + CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); *recipe = std::make_unique(std::move(p), source(src), ty_, M, ldA, ldB, ldC, M_block_size) .release(); }, - ctx); + ctx_.get()); } tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size(const_tinytc_core_info_t info, diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 83945726..7662a0f4 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -29,9 +29,9 @@ int main(int argc, char **argv) { return 0; } - auto ctx = source_context{}; + auto ctx = compiler_context{}; try { - ctx = make_source_context(); + ctx = make_compiler_context(); auto p = prog{}; if (!a.filename) { p = parse_stdin(ctx); @@ -39,11 +39,10 @@ int main(int argc, char **argv) { p = parse_file(a.filename, ctx); } - auto src = compile_to_opencl(std::move(p), a.info, ctx); + auto src = compile_to_opencl(std::move(p), a.info); std::cout << src.get_code(); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; - std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; return 1; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index 662714b9..8b7e9b1c 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -29,9 +29,9 @@ int main(int argc, char **argv) { return 0; } - auto ctx = source_context{}; + auto ctx = compiler_context{}; try { - ctx = make_source_context(); + ctx = make_compiler_context(); auto p = prog{}; if (!a.filename) { p = parse_stdin(ctx); @@ -40,11 +40,10 @@ int main(int argc, char **argv) { } for (auto const &pass_name : a.pass_names) { - run_function_pass(pass_name.c_str(), p, a.info, ctx); + run_function_pass(pass_name.c_str(), p, a.info); } } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; - std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; return 1; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; From 24604fef48e83b4635532885a56ba743229b56e1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 24 Sep 2024 15:50:05 +0200 Subject: [PATCH 025/297] Fix includes Signed-off-by: Carsten Uphoff --- src/analysis/aa_results.hpp | 2 +- src/analysis/alias.cpp | 1 - src/analysis/cfg.cpp | 6 +++++- src/analysis/cfg.hpp | 2 -- src/compiler.cpp | 3 +-- src/compiler_context.cpp | 4 +--- src/compiler_context.hpp | 3 ++- src/func.cpp | 2 +- src/node/data_type_node.hpp | 1 - src/node/program_node.cpp | 3 +++ src/node/program_node.hpp | 6 +++--- src/node/region_node.hpp | 9 ++++----- src/parser.cpp | 4 ++-- src/parser/parse_context.cpp | 2 +- src/pass/check_ir.cpp | 5 ----- src/pass/dump_cfg.cpp | 5 ++++- src/pass/insert_barrier.cpp | 9 +++++++-- src/pass/insert_barrier.hpp | 3 +++ src/pass/insert_lifetime_stop.cpp | 7 ++++--- src/pass/insert_lifetime_stop.hpp | 3 +-- src/pass/slot_tracker.cpp | 3 +-- src/pass/slot_tracker.hpp | 3 --- src/pass/stack.cpp | 2 +- src/prog.cpp | 4 ++-- src/recipe.cpp | 1 + src/recipe/small_gemm_batched.cpp | 2 +- src/recipe/tall_and_skinny.cpp | 2 +- src/region.cpp | 6 ++---- src/support/walk.hpp | 3 ++- tools/opt/args.cpp | 1 + tools/opt/main.cpp | 4 ++-- 31 files changed, 57 insertions(+), 54 deletions(-) diff --git a/src/analysis/aa_results.hpp b/src/analysis/aa_results.hpp index 3c553e64..e2f593cb 100644 --- a/src/analysis/aa_results.hpp +++ b/src/analysis/aa_results.hpp @@ -4,7 +4,7 @@ #ifndef AA_RESULTS_20240314_HPP #define AA_RESULTS_20240314_HPP -#include "tinytc/types.h" +#include "node/value_node.hpp" #include #include diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp index 812f3496..28a388c7 100644 --- a/src/analysis/alias.cpp +++ b/src/analysis/alias.cpp @@ -5,7 +5,6 @@ #include "error.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" -#include "node/region_node.hpp" #include "node/value_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp index c335d655..3a8aa98b 100644 --- a/src/analysis/cfg.cpp +++ b/src/analysis/cfg.cpp @@ -4,7 +4,11 @@ #include "analysis/cfg.hpp" #include "node/inst_node.hpp" #include "support/casting.hpp" -#include "support/visit.hpp" +#include "support/ilist_base.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include namespace tinytc { diff --git a/src/analysis/cfg.hpp b/src/analysis/cfg.hpp index bbd31ca0..ce2c7b07 100644 --- a/src/analysis/cfg.hpp +++ b/src/analysis/cfg.hpp @@ -8,10 +8,8 @@ #include "node/region_node.hpp" #include "support/util.hpp" -#include #include #include -#include #include namespace tinytc { diff --git a/src/compiler.cpp b/src/compiler.cpp index c3c83d84..59961081 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -4,7 +4,6 @@ #include "device_info.hpp" #include "error.hpp" #include "node/program_node.hpp" -#include "parser.hpp" #include "pass/check_ir.hpp" #include "pass/convert_to_opencl.hpp" #include "pass/dump_cfg.hpp" @@ -24,10 +23,10 @@ #include #include +#include #include #include #include -#include #include #include diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp index 9ed9fd34..4d241dd1 100644 --- a/src/compiler_context.cpp +++ b/src/compiler_context.cpp @@ -1,16 +1,14 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "compiler_context.hpp" - #include "compiler_context.hpp" #include "error.hpp" -#include "location.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" #include +#include namespace tinytc { void default_error_reporter(char const *what, const tinytc_location_t *, void *) { diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp index 01db18ed..21cd4e36 100644 --- a/src/compiler_context.hpp +++ b/src/compiler_context.hpp @@ -5,9 +5,10 @@ #define COMPILER_CONTEXT_20240924_HPP #include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" +#include #include #include #include diff --git a/src/func.cpp b/src/func.cpp index 20c4f98c..ee94efae 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -4,7 +4,7 @@ #include "error.hpp" #include "location.hpp" #include "node/function_node.hpp" -#include "support/casting.hpp" +#include "node/region_node.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 1b989fba..611e5c06 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -5,7 +5,6 @@ #define DATA_TYPE_NODE_20230309_HPP #include "reference_counted.hpp" -#include "scalar_type.hpp" #include "support/type_list.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp index 03ca8492..66b56f61 100644 --- a/src/node/program_node.cpp +++ b/src/node/program_node.cpp @@ -3,6 +3,9 @@ #include "node/program_node.hpp" #include "node/function_node.hpp" +#include "tinytc/tinytc.h" + +#include using namespace tinytc; diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 9511cb0e..56944fff 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -4,13 +4,13 @@ #ifndef PROGRAM_NODE_20240208_HPP #define PROGRAM_NODE_20240208_HPP -#include "location.hpp" +#include "compiler_context.hpp" +#include "node/function_node.hpp" #include "reference_counted.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" -#include -#include #include namespace tinytc { diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 7b35c22e..279e6ced 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -6,13 +6,12 @@ #include "node/inst_node.hpp" #include "support/ilist.hpp" -#include "support/util.hpp" -#include "tinytc/tinytc.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/types.h" +#include "tinytc/types.hpp" -#include +#include #include -#include -#include namespace tinytc { diff --git a/src/parser.cpp b/src/parser.cpp index b6ce6bf4..921d053b 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -3,8 +3,8 @@ #include "parser.hpp" +#include "compiler_context.hpp" #include "error.hpp" -#include "location.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" #include "parser/parser_impl.hpp" @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include namespace tinytc { diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 0de4d002..a52fb739 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: BSD-3-Clause #include "parse_context.hpp" +#include "compiler_context.hpp" #include "location.hpp" -#include "node/function_node.hpp" #include "node/value_node.hpp" #include "parser/parser_impl.hpp" diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index b4885ea5..c310ad07 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -5,15 +5,10 @@ #include "error.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" -#include "support/casting.hpp" -#include "support/visit.hpp" #include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include -#include - namespace tinytc { void check_ir_pass::run_on_function(function_node &fn) { diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp index 4d193129..59593876 100644 --- a/src/pass/dump_cfg.cpp +++ b/src/pass/dump_cfg.cpp @@ -4,10 +4,13 @@ #include "pass/dump_cfg.hpp" #include "analysis/cfg.hpp" #include "pass/dump_ir.hpp" -#include "tinytc/tinytc.hpp" +#include "support/util.hpp" #include #include +#include +#include +#include namespace tinytc { diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 67b1886b..7556edd2 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -4,16 +4,21 @@ #include "pass/insert_barrier.hpp" #include "analysis/alias.hpp" #include "analysis/cfg.hpp" +#include "error.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" #include "node/value_node.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include - +#include +#include #include +#include +#include #include #include diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp index 60af109e..dec38ac7 100644 --- a/src/pass/insert_barrier.hpp +++ b/src/pass/insert_barrier.hpp @@ -7,7 +7,10 @@ #include "analysis/aa_results.hpp" #include "node/function_node.hpp" #include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "tinytc/types.hpp" +#include #include #include diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index aef570b5..6953ca0c 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -6,10 +6,11 @@ #include "node/data_type_node.hpp" #include "node/inst_node.hpp" #include "node/value_node.hpp" -#include "support/visit.hpp" +#include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" +#include "tinytc/tinytc.hpp" -#include -#include #include #include diff --git a/src/pass/insert_lifetime_stop.hpp b/src/pass/insert_lifetime_stop.hpp index ec827fa3..f6c6751a 100644 --- a/src/pass/insert_lifetime_stop.hpp +++ b/src/pass/insert_lifetime_stop.hpp @@ -7,8 +7,7 @@ #include "analysis/aa_results.hpp" #include "node/function_node.hpp" #include "node/region_node.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" +#include "node/value_node.hpp" #include diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index 209a00bc..749ded13 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -2,12 +2,11 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/slot_tracker.hpp" -#include "support/visit.hpp" +#include "node/inst_node.hpp" #include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include -#include namespace tinytc { diff --git a/src/pass/slot_tracker.hpp b/src/pass/slot_tracker.hpp index 86abaca2..e4d11eaa 100644 --- a/src/pass/slot_tracker.hpp +++ b/src/pass/slot_tracker.hpp @@ -5,9 +5,6 @@ #define SLOT_TRACKER_20240418_HPP #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" #include "node/value_node.hpp" #include diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 4e8bfbb9..9b5d939f 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -7,7 +7,6 @@ #include "node/inst_node.hpp" #include "node/value_node.hpp" #include "support/casting.hpp" -#include "support/util.hpp" #include "support/visit.hpp" #include "support/walk.hpp" #include "tinytc/tinytc.hpp" @@ -15,6 +14,7 @@ #include #include +#include namespace tinytc { diff --git a/src/prog.cpp b/src/prog.cpp index 609907b9..d0ad3815 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -1,8 +1,10 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "compiler_context.hpp" #include "error.hpp" #include "location.hpp" +#include "node/function_node.hpp" #include "node/program_node.hpp" #include "pass/dump_ir.hpp" #include "passes.hpp" @@ -11,7 +13,6 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" -#include #include #include #include @@ -19,7 +20,6 @@ #include #include #include -#include using namespace tinytc; diff --git a/src/recipe.cpp b/src/recipe.cpp index ef98462c..989745f2 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -7,6 +7,7 @@ #include "tinytc/types.hpp" #include +#include #include namespace tinytc { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index ff15111f..be3630ef 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: BSD-3-Clause #include "small_gemm_batched.hpp" +#include "compiler_context.hpp" #include "error.hpp" -#include "parser.hpp" #include "recipe.hpp" #include "reference_counted.hpp" #include "support/util.hpp" diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 8d016479..1bc11ad8 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -2,9 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause #include "tall_and_skinny.hpp" +#include "compiler_context.hpp" #include "device_info.hpp" #include "error.hpp" -#include "parser.hpp" #include "recipe.hpp" #include "reference_counted.hpp" #include "support/util.hpp" diff --git a/src/region.cpp b/src/region.cpp index 2912c91f..b8e22f16 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -3,15 +3,13 @@ #include "error.hpp" #include "location.hpp" +#include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "support/ilist.hpp" #include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include #include -#include -#include using namespace tinytc; diff --git a/src/support/walk.hpp b/src/support/walk.hpp index c959a380..5eba2345 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -7,9 +7,10 @@ #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "support/ilist_base.hpp" +#include "tinytc/tinytc.hpp" #include -#include namespace tinytc { diff --git a/tools/opt/args.cpp b/tools/opt/args.cpp index 01b4bff2..65cbb4ac 100644 --- a/tools/opt/args.cpp +++ b/tools/opt/args.cpp @@ -4,6 +4,7 @@ #include "args.hpp" #include "tinytc/types.hpp" +#include #include #include #include diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index 8b7e9b1c..f0112144 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -8,8 +8,8 @@ #include #include #include -#include -#include +#include +#include using namespace tinytc; From 0ee3cac345d534efe657c4682cfc11d419e52306 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 24 Sep 2024 16:44:04 +0200 Subject: [PATCH 026/297] Remove location information from type Signed-off-by: Carsten Uphoff --- include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/data_type.cpp | 6 ++--- src/error.cpp | 2 ++ src/node/data_type_node.cpp | 18 ++++++++++---- src/node/data_type_node.hpp | 14 ++--------- src/parser/parser_impl.yy | 5 ++-- src/pass/convert_to_opencl.cpp | 44 ++++++++++++++++------------------ src/pass/convert_to_opencl.hpp | 2 +- 9 files changed, 46 insertions(+), 47 deletions(-) diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 812e3346..3278a0fd 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -64,6 +64,7 @@ typedef enum { tinytc_status_ir_spmd_called_from_collective = 0x113, ///< SPMD instruction from collective tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space + tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 7cb16de9..8092e86c 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -74,6 +74,7 @@ enum class status { ir_spmd_called_from_collective = tinytc_status_ir_spmd_called_from_collective, ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, + ir_invalid_offset = tinytc_status_ir_invalid_offset, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/data_type.cpp b/src/data_type.cpp index 82b55d74..bb10c800 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -24,10 +24,8 @@ tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, tinytc_scalar_ return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *dt = std::make_unique(enum_cast(type), get_optional(loc)) - .release(); - }); + return exception_to_status_code( + [&] { *dt = std::make_unique(enum_cast(type)).release(); }); } tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, diff --git a/src/error.cpp b/src/error.cpp index fce5d9b0..4be98f2f 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -162,6 +162,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "A memref with local address space is expected"; case tinytc_status_ir_expected_global_address_space: return "A memref with global address space is expected"; + case tinytc_status_ir_invalid_offset: + return "Offset must be non-negative or dynamic"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index 483c057b..cc5e54ef 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -3,6 +3,7 @@ #include "node/data_type_node.hpp" #include "error.hpp" +#include "support/casting.hpp" #include "tinytc/types.hpp" #include @@ -10,15 +11,24 @@ namespace tinytc { +group_data_type::group_data_type(data_type ty, std::int64_t offset, location const &lc) + : data_type_node(DTK::group), ty_(std::move(ty)), offset_(offset) { + if (!isa(*ty_)) { + throw compilation_error(lc, status::ir_expected_memref); + } + if (offset < 0 && !is_dynamic_value(offset)) { + throw compilation_error(lc, status::ir_invalid_offset); + } +} + memref_data_type::memref_data_type(scalar_type type, std::vector shape, std::vector stride, address_space addrspace, location const &lc) : data_type_node(DTK::memref), element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)), addrspace_(addrspace) { - loc(lc); for (auto const &s : shape_) { if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); + throw compilation_error(lc, status::ir_invalid_shape); } } if (stride_.empty()) { @@ -26,12 +36,12 @@ memref_data_type::memref_data_type(scalar_type type, std::vector s } else { for (auto const &s : stride_) { if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); + throw compilation_error(lc, status::ir_invalid_shape); } } } if (stride_.size() != shape_.size()) { - throw compilation_error(loc(), status::ir_shape_stride_mismatch); + throw compilation_error(lc, status::ir_shape_stride_mismatch); } } diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 611e5c06..22c2950f 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -28,12 +28,8 @@ struct tinytc_data_type : tinytc::reference_counted { virtual ~tinytc_data_type() = default; inline auto type_id() const -> tinytc::DTK { return tid_; } - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - private: tinytc::DTK tid_; - tinytc::location loc_; }; namespace tinytc { @@ -43,10 +39,7 @@ using data_type_node = ::tinytc_data_type; class group_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } - inline group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}) - : data_type_node(DTK::group), ty_(std::move(ty)), offset_(offset) { - loc(lc); - } + group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}); inline auto ty() const -> data_type const & { return ty_; } inline auto offset() const -> std::int64_t { return offset_; } @@ -95,10 +88,7 @@ class memref_data_type : public data_type_node { class scalar_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::scalar; } - inline scalar_data_type(scalar_type type, location const &lc) - : data_type_node(DTK::scalar), ty_(type) { - loc(lc); - } + inline scalar_data_type(scalar_type type) : data_type_node(DTK::scalar), ty_(type) {} inline scalar_type ty() const { return ty_; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index b3f79cc0..15d1bb8b 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -303,7 +303,7 @@ attribute: data_type: - scalar_type { $$ = make_scalar($scalar_type); $$->loc(@scalar_type); } + scalar_type { $$ = make_scalar($scalar_type); } | memref_type | group_type ; @@ -375,8 +375,7 @@ constant_or_dynamic: group_type: GROUP LCHEV memref_type group_offset RCHEV { - $$ = make_group(std::move($memref_type), $group_offset); - $$->loc(@group_type); + $$ = make_group(std::move($memref_type), $group_offset, @group_type); } ; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 0da096a1..febe6448 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -151,15 +151,15 @@ auto convert_to_opencl_pass::get_memref_type(value_node const &v) const return t; } -auto convert_to_opencl_pass::get_scalar_type(data_type_node const &ty) -> scalar_type { +auto convert_to_opencl_pass::get_scalar_type(value_node const &v) -> scalar_type { return visit(overloaded{[](scalar_data_type const &i) -> scalar_type { return i.ty(); }, [](memref_data_type const &i) -> scalar_type { return i.element_ty(); }, - [&](auto const &i) -> scalar_type { - throw compilation_error(i.loc(), + [&](auto const &) -> scalar_type { + throw compilation_error(v.loc(), status::ir_expected_memref_or_scalar); return scalar_type{}; }}, - ty); + *v.ty()); }; /* Data type nodes */ @@ -175,7 +175,7 @@ clir::data_type convert_to_opencl_pass::operator()(group_data_type const &g) { [](auto &) { return clir::data_type{}; }}, *ptr_ty); if (!ptr_ty) { - throw compilation_error(g.loc(), status::internal_compiler_error, + throw compilation_error(location{}, status::internal_compiler_error, "Could not determine OpenCL type of group type"); } return ptr_ty; @@ -189,11 +189,11 @@ clir::data_type convert_to_opencl_pass::operator()(scalar_data_type const &s) { /* Value nodes */ clir::expr convert_to_opencl_pass::operator()(float_imm const &v) { - auto ty = get_scalar_type(*v.ty()); + auto ty = get_scalar_type(v); return clir::expr(v.value(), static_cast(size(ty) * 8)); } clir::expr convert_to_opencl_pass::operator()(int_imm const &v) { - auto ty = get_scalar_type(*v.ty()); + auto ty = get_scalar_type(v); return clir::expr(v.value(), static_cast(size(ty) * 8)); } clir::expr convert_to_opencl_pass::operator()(val const &v) { @@ -235,8 +235,8 @@ std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) std::vector convert_to_opencl_pass::operator()(axpby_inst const &inst) { auto at = get_memref_type(*inst.A()); auto bt = get_memref_type(*inst.B()); - auto alpha_ty = get_scalar_type(*inst.alpha()->ty()); - auto beta_ty = get_scalar_type(*inst.beta()->ty()); + auto alpha_ty = get_scalar_type(*inst.alpha()); + auto beta_ty = get_scalar_type(*inst.beta()); auto &adv = get_dope_vector(inst.A().get()); auto &bdv = get_dope_vector(inst.B().get()); @@ -359,7 +359,7 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) } return {}; }; - auto sty = get_scalar_type(*a.a()->ty()); + auto sty = get_scalar_type(*a.a()); auto v = declare(*a.result()); return {declaration_assignment( visit(*this, *a.result()->ty()), std::move(v), @@ -379,7 +379,7 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons } return {}; }; - auto sty = get_scalar_type(*a.a()->ty()); + auto sty = get_scalar_type(*a.a()); auto v = declare(*a.result()); return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), make(a.operation(), visit(*this, *a.a()), sty))}; @@ -617,9 +617,8 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { *v); }; - auto gemm_ty = - gemm_scalar_type{get_scalar_type(*g.alpha()->ty()), a->element_ty(), b->element_ty(), - get_scalar_type(*g.beta()->ty()), c->element_ty()}; + auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), + get_scalar_type(*g.beta()), c->element_ty()}; auto cfg = gemm_configuration{std::move(gemm_ty), g.tA(), g.tB(), @@ -671,9 +670,8 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { *v); }; - auto gemm_ty = - gemm_scalar_type{get_scalar_type(*g.alpha()->ty()), a->element_ty(), b->element_ty(), - get_scalar_type(*g.beta()->ty()), c->element_ty()}; + auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), + get_scalar_type(*g.beta()), c->element_ty()}; auto cfg = gemm_configuration{std::move(gemm_ty), g.tA(), transpose::N, @@ -715,8 +713,8 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { auto alpha = visit(*this, *g.alpha()); auto beta = visit(*this, *g.beta()); - auto alpha_ty = get_scalar_type(*g.alpha()->ty()); - auto beta_ty = get_scalar_type(*g.beta()->ty()); + auto alpha_ty = get_scalar_type(*g.alpha()); + auto beta_ty = get_scalar_type(*g.beta()); auto A = visit(*this, *g.A()); auto B = visit(*this, *g.B()); @@ -818,8 +816,8 @@ std::vector convert_to_opencl_pass::operator()(hadamard_inst const & auto alpha = visit(*this, *g.alpha()); auto beta = visit(*this, *g.beta()); - auto alpha_ty = get_scalar_type(*g.alpha()->ty()); - auto beta_ty = get_scalar_type(*g.beta()->ty()); + auto alpha_ty = get_scalar_type(*g.alpha()); + auto beta_ty = get_scalar_type(*g.beta()); auto A = visit(*this, *g.A()); auto B = visit(*this, *g.B()); @@ -995,8 +993,8 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) auto alpha = visit(*this, *inst.alpha()); auto beta = visit(*this, *inst.beta()); - auto alpha_ty = get_scalar_type(*inst.alpha()->ty()); - auto beta_ty = get_scalar_type(*inst.beta()->ty()); + auto alpha_ty = get_scalar_type(*inst.alpha()); + auto beta_ty = get_scalar_type(*inst.beta()); auto zero = clir::expr(0.0, static_cast(size(at->element_ty()) * 8)); diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 755987f3..6de99374 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -112,7 +112,7 @@ class convert_to_opencl_pass { void set_dope_vector(value_node *v, dope_vector dv); clir::var declare(value_node const &v); auto get_memref_type(value_node const &v) const -> const memref_data_type *; - static auto get_scalar_type(data_type_node const &ty) -> scalar_type; + static auto get_scalar_type(value_node const &v) -> scalar_type; ::tinytc_core_info const *info_; clir::program_builder prog_builder_; From c0135bcd8355f2c11058d138a6cbd53b80fcebd0 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 26 Sep 2024 10:02:49 +0200 Subject: [PATCH 027/297] Store unique data types in context Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 38 ++-- docs/api/builder_capi.yaml | 8 +- docs/api/builder_cxxapi.rst | 114 +++--------- docs/api/builder_cxxapi.yaml | 22 +-- include/tinytc/tinytc.h | 71 ++++--- include/tinytc/tinytc.hpp | 299 +++++++++--------------------- include/tinytc/types.h | 1 + src/CMakeLists.txt | 1 + src/analysis/alias.cpp | 2 +- src/analysis/equal.cpp | 5 +- src/codegen_tools.cpp | 88 +++++---- src/compiler.cpp | 2 +- src/compiler_context.cpp | 6 +- src/compiler_context.hpp | 11 +- src/compiler_context_cache.cpp | 26 +++ src/compiler_context_cache.hpp | 42 +++++ src/data_type.cpp | 75 ++++---- src/inst.cpp | 50 ++--- src/node/data_type_node.cpp | 98 ++++++++-- src/node/data_type_node.hpp | 56 ++++-- src/node/inst_node.cpp | 43 +++-- src/node/inst_node.hpp | 35 ++-- src/node/value_node.hpp | 30 ++- src/parser/parse_context.hpp | 2 +- src/parser/parser_impl.yy | 197 +++++++++++--------- src/pass/convert_to_opencl.cpp | 10 +- src/pass/insert_barrier.cpp | 4 +- src/pass/stack.cpp | 2 +- src/pass/work_group_size.cpp | 2 +- src/recipe/small_gemm_batched.cpp | 44 +++-- src/recipe/tall_and_skinny.cpp | 65 ++++--- src/support/util.hpp | 25 +++ src/value.cpp | 10 +- test/codegen/load.ir | 2 +- test/opt/insert-lifetime-stop.ir | 19 +- 35 files changed, 752 insertions(+), 753 deletions(-) create mode 100644 src/compiler_context_cache.cpp create mode 100644 src/compiler_context_cache.hpp diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index ebd186cd..45f5ef17 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -241,43 +241,29 @@ Data Type * Functions - * :ref:`tinytc_group_type_create` + * :ref:`tinytc_group_type_get` - * :ref:`tinytc_memref_type_create` + * :ref:`tinytc_memref_type_get` - * :ref:`tinytc_scalar_type_create` - - * :ref:`tinytc_data_type_release` - - * :ref:`tinytc_data_type_retain` + * :ref:`tinytc_scalar_type_get` Data Type Functions ------------------- -tinytc_group_type_create -........................ - -.. doxygenfunction:: tinytc_group_type_create - -tinytc_memref_type_create -......................... - -.. doxygenfunction:: tinytc_memref_type_create - -tinytc_scalar_type_create -......................... +tinytc_group_type_get +..................... -.. doxygenfunction:: tinytc_scalar_type_create +.. doxygenfunction:: tinytc_group_type_get -tinytc_data_type_release -........................ +tinytc_memref_type_get +...................... -.. doxygenfunction:: tinytc_data_type_release +.. doxygenfunction:: tinytc_memref_type_get -tinytc_data_type_retain -....................... +tinytc_scalar_type_get +...................... -.. doxygenfunction:: tinytc_data_type_retain +.. doxygenfunction:: tinytc_scalar_type_get Function ======== diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index e97c45ae..de132721 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -38,11 +38,9 @@ Builder C-API: - const_tinytc_region_t Data Type: function: - - tinytc_group_type_create - - tinytc_memref_type_create - - tinytc_scalar_type_create - - tinytc_data_type_release - - tinytc_data_type_retain + - tinytc_group_type_get + - tinytc_memref_type_get + - tinytc_scalar_type_get Function: function: - tinytc_function_create diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index d59b5992..3888ccd7 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -164,15 +164,11 @@ Data Type * Functions - * :ref:`make_memref` + * :ref:`get_memref` - * :ref:`make_group` + * :ref:`get_group` - * :ref:`make_scalar` - -* Classes - - * :ref:`data_type` + * :ref:`get_scalar` * Structures @@ -185,28 +181,20 @@ Data Type Data Type Functions ------------------- -make_memref -........... - -.. doxygenfunction:: tinytc::make_memref - -make_group +get_memref .......... -.. doxygenfunction:: tinytc::make_group - -make_scalar -........... +.. doxygenfunction:: tinytc::get_memref -.. doxygenfunction:: tinytc::make_scalar +get_group +......... -Data Type Classes ------------------ +.. doxygenfunction:: tinytc::get_group -data_type -......... +get_scalar +.......... -.. doxygenclass:: tinytc::data_type +.. doxygenfunction:: tinytc::get_scalar Data Type Structures -------------------- @@ -561,27 +549,11 @@ Value * Functions - * :ref:`make_dynamic(location const&)` - - * :ref:`make_imm(float,location const&)` - - * :ref:`make_imm(double,scalar_type,location const&)` - - * :ref:`make_imm(std::int8_t,location const&)` - - * :ref:`make_imm(std::int16_t,location const&)` - - * :ref:`make_imm(std::int32_t,location const&)` - - * :ref:`make_imm(std::int64_t,scalar_type,location const&)` - - * :ref:`make_index(std::int32_t,location const&)` - - * :ref:`make_index(std::int64_t,location const&)` + * :ref:`make_fimm` - * :ref:`make_value(data_type const&,location const&)` + * :ref:`make_imm` - * :ref:`make_value(scalar_type,location const&)` + * :ref:`make_value` * Classes @@ -590,60 +562,20 @@ Value Value Functions --------------- -make_dynamic(location const&) -............................. - -.. doxygenfunction:: tinytc::make_dynamic(location const&) - -make_imm(float,location const&) -............................... - -.. doxygenfunction:: tinytc::make_imm(float,location const&) - -make_imm(double,scalar_type,location const&) -............................................ - -.. doxygenfunction:: tinytc::make_imm(double,scalar_type,location const&) - -make_imm(std::int8_t,location const&) -..................................... - -.. doxygenfunction:: tinytc::make_imm(std::int8_t,location const&) - -make_imm(std::int16_t,location const&) -...................................... - -.. doxygenfunction:: tinytc::make_imm(std::int16_t,location const&) - -make_imm(std::int32_t,location const&) -...................................... - -.. doxygenfunction:: tinytc::make_imm(std::int32_t,location const&) - -make_imm(std::int64_t,scalar_type,location const&) -.................................................. - -.. doxygenfunction:: tinytc::make_imm(std::int64_t,scalar_type,location const&) - -make_index(std::int32_t,location const&) -........................................ - -.. doxygenfunction:: tinytc::make_index(std::int32_t,location const&) - -make_index(std::int64_t,location const&) -........................................ +make_fimm +......... -.. doxygenfunction:: tinytc::make_index(std::int64_t,location const&) +.. doxygenfunction:: tinytc::make_fimm -make_value(data_type const&,location const&) -............................................ +make_imm +........ -.. doxygenfunction:: tinytc::make_value(data_type const&,location const&) +.. doxygenfunction:: tinytc::make_imm -make_value(scalar_type,location const&) -....................................... +make_value +.......... -.. doxygenfunction:: tinytc::make_value(scalar_type,location const&) +.. doxygenfunction:: tinytc::make_value Value Classes ------------- diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 1092232c..6994f28b 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -27,11 +27,9 @@ Builder C++-API: - tinytc::dynamic Data Type: function: - - tinytc::make_memref - - tinytc::make_group - - tinytc::make_scalar - class: - - tinytc::data_type + - tinytc::get_memref + - tinytc::get_group + - tinytc::get_scalar struct: - tinytc::to_scalar_type variable: @@ -90,16 +88,8 @@ Builder C++-API: - tinytc::region_builder Value: function: - - tinytc::make_dynamic(location const&) - - tinytc::make_imm(float,location const&) - - tinytc::make_imm(double,scalar_type,location const&) - - tinytc::make_imm(std::int8_t,location const&) - - tinytc::make_imm(std::int16_t,location const&) - - tinytc::make_imm(std::int32_t,location const&) - - tinytc::make_imm(std::int64_t,scalar_type,location const&) - - tinytc::make_index(std::int32_t,location const&) - - tinytc::make_index(std::int64_t,location const&) - - tinytc::make_value(data_type const&,location const&) - - tinytc::make_value(scalar_type,location const&) + - tinytc::make_fimm + - tinytc::make_imm + - tinytc::make_value class: - tinytc::value diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 63762240..edb1cdc2 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -50,22 +50,23 @@ TINYTC_EXPORT size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty); //////////////////////////// /** - * @brief Create scalar data type + * @brief Get scalar data type * * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context * @param type [in] scalar type - * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx, + tinytc_scalar_type_t type); /** - * @brief Create memref data type + * @brief Get memref data type * * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context * @param scalar_ty [in] element type * @param shape_size [in] tensor order; number of elements in shape array, must be 0 if shape == * nullptr @@ -78,46 +79,26 @@ TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_address_space_t addrspace, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get( + tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t scalar_ty, + uint32_t shape_size, const int64_t *shape, uint32_t stride_size, const int64_t *stride, + const tinytc_address_space_t addrspace, const tinytc_location_t *loc); /** - * @brief Create group data type + * @brief Get group data type * * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context * @param memref_ty [in] memref data type object * @param offset [in][optional] offset parameter; pass 0 for default * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_group_type_create(tinytc_data_type_t *dt, - tinytc_data_type_t memref_ty, int64_t offset, - const tinytc_location_t *loc); - -/** - * @brief Release data type object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param dt [inout] data type object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_release(tinytc_data_type_t dt); - -/** - * @brief Increase reference count of data type object by 1 - * - * @param dt [inout] data type object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t dt); +TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx, + tinytc_data_type_t memref_ty, int64_t offset, + const tinytc_location_t *loc); //////////////////////////// /////////// Value ////////// @@ -146,7 +127,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_dat * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, - tinytc_scalar_type_t type, + tinytc_data_type_t type, const tinytc_location_t *loc); /** * @brief Create integer immediate value @@ -159,7 +140,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, - tinytc_scalar_type_t type, + tinytc_data_type_t type, const tinytc_location_t *loc); /** @@ -386,11 +367,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tiny * @code %value = group_id @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -399,11 +382,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, * @code %value = group_size @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -514,11 +499,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( * @code %value = num_subgroups @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -561,11 +548,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tiny * @code %value = subgroup_id @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -574,11 +563,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *inst * @code %value = subgroup_local_id @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -587,11 +578,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t * @code %value = subgroup_size @endcode * * @param instr [out] pointer to the inst object created + * @param ctx [in] compiler context * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc); /** @@ -737,7 +730,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, + tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 74178cb1..45e5a8f9 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -357,45 +357,29 @@ inline auto make_compiler_context() -> compiler_context { //! Check if mode i is dynamic ('?') inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_retain(handle); - } - static auto release(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_data_type_t -class data_type : public shared_handle { - public: - using shared_handle::shared_handle; -}; - /** - * @brief Make a scalar data type + * @brief Get a scalar data type * - * Cf. \ref tinytc_scalar_type_create + * Cf. \ref tinytc_scalar_type_get * + * @param ctx Compiler context * @param scalar_ty Scalar type - * @param loc Source code location * * @return Data type */ -inline data_type make_scalar(scalar_type scalar_ty, location const &loc = {}) { +inline tinytc_data_type_t get_scalar(compiler_context const &ctx, scalar_type scalar_ty) { tinytc_data_type_t st; - CHECK_STATUS_LOC( - tinytc_scalar_type_create(&st, static_cast(scalar_ty), &loc), loc); - return data_type{st}; + CHECK_STATUS( + tinytc_scalar_type_get(&st, ctx.get(), static_cast(scalar_ty))); + return st; } /** - * @brief Make a memref data type + * @brief Get a memref data type * - * Cf. \ref tinytc_memref_type_create + * Cf. \ref tinytc_memref_type_get * + * @param ctx Compiler context * @param scalar_ty Element type * @param shape Tensor shape * @param stride Tensor stride @@ -404,33 +388,35 @@ inline data_type make_scalar(scalar_type scalar_ty, location const &loc = {}) { * * @return Data type */ -inline data_type make_memref(scalar_type scalar_ty, std::vector const &shape, - std::vector const &stride = {}, - const address_space addrspace = address_space::global, - location const &loc = {}) { +inline tinytc_data_type_t get_memref(compiler_context const &ctx, scalar_type scalar_ty, + std::vector const &shape, + std::vector const &stride = {}, + const address_space addrspace = address_space::global, + location const &loc = {}) { tinytc_data_type_t mt; CHECK_STATUS_LOC( - tinytc_memref_type_create(&mt, static_cast(scalar_ty), shape.size(), - shape.data(), stride.size(), stride.data(), - static_cast(addrspace), &loc), + tinytc_memref_type_get(&mt, ctx.get(), static_cast(scalar_ty), + shape.size(), shape.data(), stride.size(), stride.data(), + static_cast(addrspace), &loc), loc); - return data_type{mt}; + return mt; } /** - * @brief Make a group data type + * @brief Get a group data type * + * @param ctx Compiler context * @param memref_ty Memref data type * @param offset Offset parameter * @param loc Source code location * * @return Data type */ -inline data_type make_group(data_type const &memref_ty, std::int64_t offset = 0, - location const &loc = {}) { +inline tinytc_data_type_t get_group(compiler_context const &ctx, tinytc_data_type_t memref_ty, + std::int64_t offset = 0, location const &loc = {}) { tinytc_data_type_t gt; - CHECK_STATUS_LOC(tinytc_group_type_create(>, memref_ty.get(), offset, &loc), loc); - return data_type{gt}; + CHECK_STATUS_LOC(tinytc_group_type_get(>, ctx.get(), memref_ty, offset, &loc), loc); + return gt; } //////////////////////////// @@ -487,40 +473,9 @@ constexpr bool value_reinterpret_allowed = * * @return Value */ -inline auto make_value(data_type const &ty, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; -} - -/** - * @brief Make value - * - * @param scalar_ty Scalar type - * @param loc Source code location - * - * @return Value - */ -inline auto make_value(scalar_type scalar_ty, location const &loc = {}) -> value { - tinytc_value_t val; - auto ty = make_scalar(scalar_ty, loc); - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is f32. - * - * @param imm Float value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(float imm, location const &loc = {}) -> value { +inline auto make_value(tinytc_data_type_t ty, location const &loc = {}) -> value { tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_float_imm_create(&val, imm, tinytc_scalar_type_f32, &loc), loc); + CHECK_STATUS_LOC(tinytc_value_create(&val, ty, &loc), loc); return value{val}; } @@ -533,59 +488,9 @@ inline auto make_imm(float imm, location const &loc = {}) -> value { * * @return Value */ -inline auto make_imm(double imm, scalar_type type = scalar_type::f64, - location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_float_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i8. - * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int8_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i8, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i16. - * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int16_t imm, location const &loc = {}) -> value { +inline auto make_fimm(double imm, tinytc_data_type_t type, location const &loc = {}) -> value { tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i16, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i32. - * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int32_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i32, &loc), loc); + CHECK_STATUS_LOC(tinytc_float_imm_create(&val, imm, type, &loc), loc); return value{val}; } @@ -598,52 +503,9 @@ inline auto make_imm(std::int32_t imm, location const &loc = {}) -> value { * * @return Value */ -inline auto make_imm(std::int64_t imm, scalar_type type = scalar_type::i64, - location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_int_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate index value - * - * @param imm index value - * @param loc Source code location - * - * @return Value - */ -inline auto make_index(std::int32_t imm, location const &loc = {}) -> value { +inline auto make_imm(std::int64_t imm, tinytc_data_type_t type, location const &loc = {}) -> value { tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate index value - * - * @param imm index value - * @param loc Source code location - * - * @return Value - */ -inline auto make_index(std::int64_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; -} - -/** - * @brief Make dynamic ('?') - * - * @param loc Source code location - * - * @return Value - */ -inline auto make_dynamic(location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, dynamic, tinytc_scalar_type_i64, &loc), loc); + CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, type, &loc), loc); return value{val}; } @@ -864,9 +726,9 @@ inline inst make_cmp(cmp_condition cond, value const &a, value const &b, locatio * * @return Instruction */ -inline inst make_alloca(data_type const &ty, location const &loc = {}) { +inline inst make_alloca(tinytc_data_type_t ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty, &loc), loc); return inst(instr); } @@ -959,26 +821,28 @@ inline inst make_load(value const &a, std::vector const &index_list, /** * @brief Make group id instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_group_id(location const &loc = {}) { +inline inst make_group_id(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_id_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_group_id_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } /** * @brief Make group size instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_group_size(location const &loc = {}) { +inline inst make_group_size(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_size_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_group_size_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } @@ -1079,13 +943,14 @@ inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value /** * @brief Make num_subgroups instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_num_subgroups(location const &loc = {}) { +inline inst make_num_subgroups(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_num_subgroups_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_num_subgroups_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } @@ -1121,39 +986,42 @@ inline inst make_size(value const &a, std::int64_t mode, location const &loc = { /** * @brief Make subgroup_id instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_subgroup_id(location const &loc = {}) { +inline inst make_subgroup_id(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_id_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_subgroup_id_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } /** * @brief Make subgroup_local_id instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_subgroup_local_id(location const &loc = {}) { +inline inst make_subgroup_local_id(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_local_id_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_subgroup_local_id_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } /** * @brief Make subgroup_size instruction * + * @param ctx compiler context * @param loc Source code location * * @return Instruction */ -inline inst make_subgroup_size(location const &loc = {}) { +inline inst make_subgroup_size(compiler_context const &ctx, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_size_inst_create(&instr, &loc), loc); + CHECK_STATUS_LOC(tinytc_subgroup_size_inst_create(&instr, ctx.get(), &loc), loc); return inst(instr); } @@ -1285,21 +1153,17 @@ inline inst make_foreach(value const &loop_var, value const &from, value const & * @return Instruction */ inline inst make_if(value const &condition, region then, region otherwise = region{}, - std::vector const &return_type_list = {}, + std::vector const &return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; auto len = return_type_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("return type list too long"); } - auto rl_vec = std::vector(); - rl_vec.resize(len); - for (auto const &rt : return_type_list) { - rl_vec.emplace_back(static_cast(rt)); - } - CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition.get(), then.release(), - otherwise.release(), len, rl_vec.data(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_if_inst_create(&instr, condition.get(), then.release(), otherwise.release(), len, + const_cast(return_type_list.data()), &loc), + loc); return inst(instr); } @@ -1428,7 +1292,7 @@ class prog : public shared_handle { auto get_compiler_context() const -> compiler_context { tinytc_compiler_context_t ctx; CHECK_STATUS(tinytc_prog_get_compiler_context(obj_, &ctx)); - return compiler_context{ctx}; + return compiler_context{ctx, true}; } /** * @brief Dump program to file @@ -1474,9 +1338,11 @@ class region_builder { /** * @brief ctor * + * @param ctx compiler context * @param loc Source code location */ - region_builder(location const &loc = {}) : reg_{make_region(loc)} {} + region_builder(compiler_context const &ctx, location const &loc = {}) + : ctx_(ctx), reg_{make_region(loc)} {} /** * @brief Returns built product @@ -1559,11 +1425,11 @@ class region_builder { template void for_loop(scalar_type loop_var_ty, value const &from, value const &to, value const &step, F &&f, std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); + auto loop_var = make_value(get_scalar(ctx_, loop_var_ty)); if (name.size() > 0) { loop_var.name(name); } - auto bb = region_builder{}; + auto bb = region_builder{ctx_}; f(bb, loop_var); add(::tinytc::make_for(std::move(loop_var), from, to, step, std::move(bb).get_product(), loc)); @@ -1580,13 +1446,13 @@ class region_builder { * @param loc Source code location */ template - void foreach (data_type const &loop_var_ty, value const &from, value const &to, F && f, + void foreach (scalar_type loop_var_ty, value const &from, value const &to, F && f, std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); + auto loop_var = make_value(get_scalar(ctx_, loop_var_ty)); if (name.size() > 0) { loop_var.name(name); } - auto bb = region_builder{}; + auto bb = region_builder{ctx_}; f(bb); add(::tinytc::make_foreach(std::move(loop_var), from, to, std::move(bb).get_product(), loc)); @@ -1605,9 +1471,9 @@ class region_builder { */ template auto if_condition(value const &condition, F &&then, - std::vector const &return_type_list = {}, + std::vector const &return_type_list = {}, location const &loc = {}) -> std::vector { - auto bb = region_builder{}; + auto bb = region_builder{ctx_}; then(bb); return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb).get_product(), region{}, return_type_list, loc)); @@ -1628,18 +1494,21 @@ class region_builder { */ template auto ifelse(value const &condition, F &&then, G &&otherwise, - std::vector const &return_type_list = {}, + std::vector const &return_type_list = {}, location const &loc = {}) -> std::vector { - auto bb1 = region_builder{}; + auto bb1 = region_builder{ctx_}; then(bb1); - auto bb2 = region_builder{}; + auto bb2 = region_builder{ctx_}; otherwise(bb2); return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb1).get_product(), std::move(bb2).get_product(), return_type_list, loc)); } + inline auto context() -> compiler_context const & { return ctx_; } + private: + compiler_context ctx_; region reg_; }; @@ -1649,12 +1518,13 @@ class function_builder { /** * @brief creates function \@name * + * @param ctx compiler context * @param name Function name * @param loc Source code location * */ - inline function_builder(std::string name, location const &loc = {}) - : name_(std::move(name)), body_{nullptr}, loc_(loc) {} + inline function_builder(compiler_context const &ctx, std::string name, location const &loc = {}) + : ctx_(ctx), name_(std::move(name)), body_{nullptr}, loc_(loc) {} /** * @brief Returns built product @@ -1681,7 +1551,7 @@ class function_builder { * * @return Value */ - inline value argument(data_type const &ty, std::string const &name = "", + inline value argument(tinytc_data_type_t ty, std::string const &name = "", location const &loc = {}) { auto v = make_value(ty, loc); if (name.size() > 0) { @@ -1716,12 +1586,13 @@ class function_builder { * @param loc Source code location */ template void body(F &&f, location const &loc = {}) { - auto bb = region_builder{loc}; + auto bb = region_builder{ctx_, loc}; f(bb); body_ = std::move(bb).get_product(); } private: + compiler_context ctx_; std::string name_; region body_; location loc_; @@ -1751,7 +1622,7 @@ class program_builder { * @param loc Source code location */ template void create(std::string name, F &&f, location const &loc = {}) { - auto fb = function_builder(std::move(name), loc); + auto fb = function_builder(prg_.get_compiler_context(), std::move(name), loc); f(fb); add(std::move(fb).get_product()); } @@ -1911,7 +1782,7 @@ inline auto parse_file(char const *filename, compiler_context ctx = {}) -> prog * * @return Program */ -inline auto parse_stdin(compiler_context ctx = {}) -> prog { +inline auto parse_stdin(compiler_context const &ctx = {}) -> prog { tinytc_prog_t prg; CHECK_STATUS(tinytc_parse_stdin(&prg, ctx.get())); return prog(prg); @@ -1924,7 +1795,7 @@ inline auto parse_stdin(compiler_context ctx = {}) -> prog { * * @return Porgram */ -inline auto parse_string(std::string const &src, compiler_context ctx = {}) -> prog { +inline auto parse_string(std::string const &src, compiler_context const &ctx = {}) -> prog { tinytc_prog_t prg; CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), ctx.get())); return prog(prg); @@ -2261,7 +2132,7 @@ inline auto make_small_gemm_batched(core_info const &info, scalar_type ty, trans transpose tB, std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t strideA, std::int64_t ldB, std::int64_t strideB, std::int64_t ldC, std::int64_t strideC, - compiler_context ctx = {}) -> small_gemm_batched { + compiler_context const &ctx = {}) -> small_gemm_batched { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_small_gemm_batched_create( &rec, info.get(), static_cast(ty), @@ -2316,7 +2187,7 @@ class tall_and_skinny : public recipe { */ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int64_t N, std::int64_t K, std::int32_t M_block_size = 0, - compiler_context ctx = {}) -> tall_and_skinny { + compiler_context const &ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create( &rec, info.get(), static_cast(ty), N, K, M_block_size, ctx.get())); @@ -2345,7 +2216,7 @@ inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size = 0, - compiler_context ctx = {}) -> tall_and_skinny { + compiler_context const &ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 3278a0fd..422f30d7 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -229,6 +229,7 @@ typedef enum { tinytc_scalar_type_c32 = 8, ///< Single precision complex (2x32 bit) tinytc_scalar_type_c64 = 9 ///< Double precision complex (2x64 bit) } tinytc_scalar_type_t; +#define TINYTC_NUMBER_OF_SCALAR_TYPES 10 // @todo Keep up to date with tinytc_scalar_type_t //! Arithmetic operations typedef enum { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9b7737e5..ec1f1dc6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ set(SOURCES codegen_tools.cpp compiler.cpp compiler_context.cpp + compiler_context_cache.cpp data_type.cpp device_info.cpp error.cpp diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp index 28a388c7..39097e63 100644 --- a/src/analysis/alias.cpp +++ b/src/analysis/alias.cpp @@ -35,7 +35,7 @@ class alias_analysis_visitor { void alias_analysis_visitor::operator()(inst_node const &) {} void alias_analysis_visitor::operator()(alloca_inst const &a) { if (a.stack_ptr() >= 0) { - auto t = dyn_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } diff --git a/src/analysis/equal.cpp b/src/analysis/equal.cpp index 02ae23c5..a2377bf7 100644 --- a/src/analysis/equal.cpp +++ b/src/analysis/equal.cpp @@ -12,10 +12,11 @@ namespace tinytc { bool equal::operator()(data_type_node const &, data_type_node const &) { return false; } bool equal::operator()(void_data_type const &, void_data_type const &) { return true; } bool equal::operator()(group_data_type const &a, group_data_type const &b) { - return visit(*this, *a.ty(), *b.ty()); + return visit(*this, *a.ty(), *b.ty()) && a.offset() == b.offset(); } bool equal::operator()(memref_data_type const &a, memref_data_type const &b) { - return a.element_ty() == b.element_ty() && a.shape() == b.shape() && a.stride() == b.stride(); + return a.element_ty() == b.element_ty() && a.shape() == b.shape() && a.stride() == b.stride() && + a.addrspace() == b.addrspace(); } bool equal::operator()(scalar_data_type const &a, scalar_data_type const &b) { return a.ty() == b.ty(); diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 3c89c7d0..762be241 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -448,27 +448,29 @@ void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { + auto index_ty = get_scalar(bb.context(), scalar_type::index); std::int64_t blocks = loop_trip_count / sgs; std::int64_t rem = loop_trip_count % sgs; auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); if (blocks > 0) { - auto block_start = bb.add(make_arith(arithmetic::mul, make_index(sgs), sg_id_index)); - auto block_end = make_index(sgs * blocks); - auto step = make_index(sgs * num_tiles); + auto block_start = + bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), sg_id_index)); + auto block_end = make_imm(sgs * blocks, index_ty); + auto step = make_imm(sgs * num_tiles, index_ty); bb.for_loop( scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), [&](region_builder &bb, value const &block) { - body(bb, block, false, make_index(sgs)); + body(bb, block, false, make_imm(sgs, index_ty)); }, "block"); } if (rem > 0) { auto condition = - bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_index(num_tiles - 1))); + bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_imm(num_tiles - 1, index_ty))); bb.if_condition(condition, [&](region_builder &bb) { - body(bb, make_index(blocks * sgs), true, make_index(rem)); + body(bb, make_imm(blocks * sgs, index_ty), true, make_imm(rem, index_ty)); }); } } @@ -476,24 +478,27 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { - auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, make_index(sgs))); - auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, make_index(sgs))); + auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, make_imm(sgs, index_ty))); + auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, make_imm(sgs, index_ty))); auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); - auto block_start = bb.add(make_arith(arithmetic::mul, make_index(sgs), sg_id_index)); - auto block_end = bb.add(make_arith(arithmetic::mul, make_index(sgs), blocks)); - auto step = make_index(sgs * num_tiles); + auto block_start = bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), sg_id_index)); + auto block_end = bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), blocks)); + auto step = make_imm(sgs * num_tiles, index_ty); bb.for_loop( scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { body(bb, block, false, make_index(sgs)); }, + [&](region_builder &bb, value const &block) { + body(bb, block, false, make_imm(sgs, index_ty)); + }, "block"); - auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, make_index(0))); + auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, make_imm(0, index_ty))); bb.if_condition(condition0, [&](region_builder &bb) { auto condition1 = - bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_index(num_tiles - 1))); + bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_imm(num_tiles - 1, index_ty))); bb.if_condition(condition1, [&](region_builder &bb) { - auto block = bb.add(make_arith(arithmetic::mul, blocks, make_index(sgs))); + auto block = bb.add(make_arith(arithmetic::mul, blocks, make_imm(sgs, index_ty))); body(bb, block, true, rem); }); }); @@ -519,6 +524,7 @@ void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, i void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { + auto index_ty = get_scalar(bb.context(), scalar_type::index); // Find minimum number of blocks such that the block sizes are smaller or equal block_size std::int64_t blocks = 1 + (loop_trip_count - 1) / block_size; // Increase the number of blocks if such that the number of blocks is a multiple @@ -530,40 +536,46 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); if (rem > 0) { - auto block_start = bb.add(make_arith(arithmetic::mul, make_index(bs_1), sg_id_index)); - auto block_end = make_index(bs_1 * rem); - auto step = make_index(bs_1 * num_tiles); + auto block_start = + bb.add(make_arith(arithmetic::mul, make_imm(bs_1, index_ty), sg_id_index)); + auto block_end = make_imm(bs_1 * rem, index_ty); + auto step = make_imm(bs_1 * num_tiles, index_ty); bb.for_loop( scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { body(bb, block, make_index(bs_1)); }, + [&](region_builder &bb, value const &block) { + body(bb, block, make_imm(bs_1, index_ty)); + }, "block"); } - auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, make_index(rem % num_tiles))); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, make_index(num_tiles))); - auto tmp2 = bb.add(make_arith(arithmetic::mul, make_index(bs), sg_id_1)); - auto block_start = bb.add(make_arith(arithmetic::add, make_index(bs_1 * rem), tmp2)); - auto block_end = make_index(loop_trip_count); - auto step = make_index(bs * num_tiles); + auto tmp = + bb.add(make_arith(arithmetic::add, sg_id_index, make_imm(rem % num_tiles, index_ty))); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, make_imm(num_tiles, index_ty))); + auto tmp2 = bb.add(make_arith(arithmetic::mul, make_imm(bs, index_ty), sg_id_1)); + auto block_start = bb.add(make_arith(arithmetic::add, make_imm(bs_1 * rem, index_ty), tmp2)); + auto block_end = make_imm(loop_trip_count, index_ty); + auto step = make_imm(bs * num_tiles, index_ty); bb.for_loop( scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { body(bb, block, make_index(bs)); }, "block"); + [&](region_builder &bb, value const &block) { body(bb, block, make_imm(bs, index_ty)); }, + "block"); } void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { - auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, make_index(1))); - auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, make_index(block_size))); - auto blocks2 = bb.add(make_arith(arithmetic::add, make_index(1), blocks1)); - auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, make_index(1))); - auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, make_index(num_tiles))); - auto blocks5 = bb.add(make_arith(arithmetic::add, make_index(1), blocks4)); - auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, make_index(num_tiles))); + auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, make_imm(1, index_ty))); + auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, make_imm(block_size, index_ty))); + auto blocks2 = bb.add(make_arith(arithmetic::add, make_imm(1, index_ty), blocks1)); + auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, make_imm(1, index_ty))); + auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, make_imm(num_tiles, index_ty))); + auto blocks5 = bb.add(make_arith(arithmetic::add, make_imm(1, index_ty), blocks4)); + auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, make_imm(num_tiles, index_ty))); blocks->name("blocks"); auto bs = bb.add(make_arith(arithmetic::div, loop_trip_count, blocks)); bs->name("bs"); - auto bs_1 = bb.add(make_arith(arithmetic::add, bs, make_index(1))); + auto bs_1 = bb.add(make_arith(arithmetic::add, bs, make_imm(1, index_ty))); bs_1->name("bs_1"); auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); rem->name("rem"); @@ -571,18 +583,18 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); - auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, make_index(num_tiles))); + auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, make_imm(num_tiles, index_ty))); bb.for_loop( scalar_type::index, std::move(block_start_1), std::move(block_end_1), std::move(step_1), [&](region_builder &bb, value const &block) { body(bb, block, bs_1); }, "block"); - auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, make_index(num_tiles))); + auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, make_imm(num_tiles, index_ty))); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, make_index(num_tiles))); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, make_imm(num_tiles, index_ty))); auto tmp2 = bb.add(make_arith(arithmetic::mul, bs, sg_id_1)); auto tmp3 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); - auto step = bb.add(make_arith(arithmetic::mul, bs, make_index(num_tiles))); + auto step = bb.add(make_arith(arithmetic::mul, bs, make_imm(num_tiles, index_ty))); bb.for_loop( scalar_type::index, std::move(block_start), loop_trip_count, std::move(step), [&](region_builder &bb, value const &block) { body(bb, block, bs); }, "block"); diff --git a/src/compiler.cpp b/src/compiler.cpp index 59961081..64d9a082 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -88,7 +88,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ // insert_barriers(*prg); run_function_pass(work_group_size_pass{info}, *prg); // lower_linalg(*prg, *info); - run_function_pass(dump_ir_pass{std::cout}, *prg); + //run_function_pass(dump_ir_pass{std::cout}, *prg); // propagate_constants(*prg); // dump_ir(std::cout, *prg); // opencl diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp index 4d241dd1..be8f54c9 100644 --- a/src/compiler_context.cpp +++ b/src/compiler_context.cpp @@ -2,25 +2,27 @@ // SPDX-License-Identifier: BSD-3-Clause #include "compiler_context.hpp" +#include "compiler_context_cache.hpp" #include "error.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" #include -#include namespace tinytc { void default_error_reporter(char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; } - } // namespace tinytc using namespace tinytc; extern "C" { +tinytc_compiler_context::tinytc_compiler_context() + : cache_{std::make_unique(this)} {} + auto tinytc_compiler_context::source_name(std::int32_t source_id) -> std::pair { if (has_source_id(source_id)) { diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp index 21cd4e36..5a19a258 100644 --- a/src/compiler_context.hpp +++ b/src/compiler_context.hpp @@ -10,18 +10,26 @@ #include #include +#include #include #include #include namespace tinytc { void default_error_reporter(char const *what, const tinytc_location_t *location, void *user_data); -} + +class compiler_context_cache; + +} // namespace tinytc struct tinytc_compiler_context : tinytc::reference_counted { public: constexpr static const char unavailable_source_name[] = "Source name unavailable"; + tinytc_compiler_context(); + + inline auto cache() -> tinytc::compiler_context_cache * { return cache_.get(); } + inline void set_error_reporter(tinytc::error_reporter_t reporter, void *user_data) { reporter_ = reporter; user_data_ = user_data; @@ -49,6 +57,7 @@ struct tinytc_compiler_context : tinytc::reference_counted { return source_id >= 1 && static_cast(source_id) <= sources_.size(); } + std::unique_ptr cache_; tinytc::error_reporter_t reporter_ = &tinytc::default_error_reporter; void *user_data_ = nullptr; std::vector sources_; diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp new file mode 100644 index 00000000..e415be77 --- /dev/null +++ b/src/compiler_context_cache.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context_cache.hpp" +#include "support/util.hpp" + +namespace tinytc { + +compiler_context_cache::compiler_context_cache(tinytc_compiler_context_t ctx) { + for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { + scalar_tys[i] = + std::unique_ptr(new scalar_data_type(ctx, enum_cast(i))); + } +} + +compiler_context_cache::~compiler_context_cache() { + for (auto &mt : memref_tys) { + delete mt.second; + } + for (auto > : group_tys) { + delete gt.second; + } +} + +} // namespace tinytc + diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp new file mode 100644 index 00000000..78318d38 --- /dev/null +++ b/src/compiler_context_cache.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_CACHE_20240925_HPP +#define COMPILER_CONTEXT_CACHE_20240925_HPP + +#include "node/data_type_node.hpp" +#include "support/util.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include +#include + +namespace std { +template <> class hash> { + public: + auto operator()(std::pair const &key) const -> std::size_t { + return tinytc::fnv1a(key.first, key.second); + } +}; +} // namespace std + +namespace tinytc { + +class compiler_context_cache { + public: + compiler_context_cache(tinytc_compiler_context_t ctx); + ~compiler_context_cache(); + + compiler_context_cache(compiler_context_cache const &) = delete; + compiler_context_cache &operator=(compiler_context_cache const &) = delete; + + std::array, TINYTC_NUMBER_OF_SCALAR_TYPES> scalar_tys; + std::unordered_multimap memref_tys; + std::unordered_map, tinytc_data_type_t> group_tys; +}; + +} // namespace tinytc + +#endif // COMPILER_CONTEXT_CACHE_20240925_HPP diff --git a/src/data_type.cpp b/src/data_type.cpp index bb10c800..ac033ab1 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" #include "error.hpp" #include "location.hpp" #include "node/data_type_node.hpp" @@ -18,67 +20,52 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t type, - const tinytc_location_t *loc) { - if (dt == nullptr) { +tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, + tinytc_scalar_type_t type) { + if (dt == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *dt = std::make_unique(enum_cast(type)).release(); }); + [&] { *dt = scalar_data_type::get(ctx, enum_cast(type)); }); } -tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_address_space_t addrspace, - const tinytc_location_t *loc) { - if (dt == nullptr) { +tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, + tinytc_scalar_type_t scalar_ty, uint32_t shape_size, + const int64_t *shape, uint32_t stride_size, + const int64_t *stride, + const tinytc_address_space_t addrspace, + const tinytc_location_t *loc) { + if (dt == nullptr || ctx == nullptr || (shape_size != 0 && shape == nullptr) || + (stride_size != 0 && stride == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto shape_vec = std::vector(shape, shape + shape_size); - auto stride_vec = std::vector(); - if (stride_size > 0) { - stride_vec.insert(stride_vec.end(), stride, stride + stride_size); + auto shape_span = std::span{}; + if (shape != nullptr) { + shape_span = std::span(shape, static_cast(shape_size)); + } + auto stride_span = std::span{}; + if (stride != nullptr) { + stride_span = + std::span(stride, static_cast(stride_size)); } - *dt = std::make_unique( - enum_cast(scalar_ty), std::move(shape_vec), std::move(stride_vec), - enum_cast(addrspace), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_group_type_create(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, - int64_t offset, const tinytc_location_t *loc) { - if (dt == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *dt = - std::make_unique(data_type(memref_ty, true), offset, get_optional(loc)) - .release(); + *dt = memref_data_type::get(ctx, enum_cast(scalar_ty), std::move(shape_span), + std::move(stride_span), enum_cast(addrspace), + get_optional(loc)); }); } -tinytc_status_t tinytc_data_type_release(tinytc_data_type_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, + tinytc_data_type_t memref_ty, int64_t offset, + const tinytc_location_t *loc) { + if (dt == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} -tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; + return exception_to_status_code( + [&] { *dt = group_data_type::get(ctx, memref_ty, offset, get_optional(loc)); }); } } diff --git a/src/inst.cpp b/src/inst.cpp index 334e8616..ac458d51 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -151,9 +151,8 @@ tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *instr = std::make_unique(data_type(ty, true), get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *instr = std::make_unique(ty, get_optional(loc)).release(); }); } tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, @@ -216,20 +215,22 @@ tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, }); } -tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { +tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } -tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { +tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, @@ -295,12 +296,13 @@ tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t } tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc) { - if (instr == nullptr) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, tinytc_region_t body, @@ -323,30 +325,34 @@ tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, }); } -tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { +tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc) { - if (instr == nullptr) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + return exception_to_status_code([&] { + *instr = std::make_unique(ctx, get_optional(loc)).release(); + }); } tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, + tinytc_compiler_context_t ctx, const tinytc_location_t *loc) { - if (instr == nullptr) { + if (instr == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, @@ -437,17 +443,17 @@ tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, + tinytc_data_type_t *return_type_list, const tinytc_location_t *loc) { if (instr == nullptr || condition == nullptr || then == nullptr || (return_type_list_size > 0 && return_type_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto rt = std::vector(); + auto rt = std::vector(); rt.reserve(return_type_list_size); for (uint32_t i = 0; i < return_type_list_size; ++i) { - rt.emplace_back(enum_cast(return_type_list[i])); + rt.emplace_back(return_type_list[i]); } *instr = std::make_unique(value(condition, true), region{then}, region{otherwise}, std::move(rt), get_optional(loc)) diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index cc5e54ef..caca6ffe 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -2,17 +2,33 @@ // SPDX-License-Identifier: BSD-3-Clause #include "node/data_type_node.hpp" +#include "compiler_context_cache.hpp" #include "error.hpp" #include "support/casting.hpp" +#include "support/util.hpp" #include "tinytc/types.hpp" +#include #include #include namespace tinytc { -group_data_type::group_data_type(data_type ty, std::int64_t offset, location const &lc) - : data_type_node(DTK::group), ty_(std::move(ty)), offset_(offset) { +auto group_data_type::get(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t offset, + location const &lc) -> tinytc_data_type_t { + auto &value = ctx->cache()->group_tys[std::make_pair(ty, offset)]; + + if (value == nullptr) { + value = + std::unique_ptr(new group_data_type(ctx, ty, offset, lc)).release(); + } + + return value; +} + +group_data_type::group_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, + std::int64_t offset, location const &lc) + : data_type_node(DTK::group, ctx), ty_(std::move(ty)), offset_(offset) { if (!isa(*ty_)) { throw compilation_error(lc, status::ir_expected_memref); } @@ -21,40 +37,86 @@ group_data_type::group_data_type(data_type ty, std::int64_t offset, location con } } -memref_data_type::memref_data_type(scalar_type type, std::vector shape, +memref_data_type::memref_data_type(tinytc_compiler_context_t ctx, scalar_type type, + std::vector shape, std::vector stride, address_space addrspace, location const &lc) - : data_type_node(DTK::memref), element_ty_(std::move(type)), shape_(std::move(shape)), + : data_type_node(DTK::memref, ctx), element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)), addrspace_(addrspace) { + if (stride_.size() != shape_.size()) { + throw compilation_error(lc, status::ir_shape_stride_mismatch); + } for (auto const &s : shape_) { if (s < 0 && !is_dynamic_value(s)) { throw compilation_error(lc, status::ir_invalid_shape); } } - if (stride_.empty()) { - stride_ = canonical_stride(); - } else { - for (auto const &s : stride_) { - if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(lc, status::ir_invalid_shape); - } + for (auto const &s : stride_) { + if (s < 0 && !is_dynamic_value(s)) { + throw compilation_error(lc, status::ir_invalid_shape); } } - if (stride_.size() != shape_.size()) { - throw compilation_error(lc, status::ir_shape_stride_mismatch); +} + +auto memref_data_type::get(tinytc_compiler_context_t ctx, scalar_type element_ty, + std::span shape, + std::span stride, address_space addrspace, + location const &lc) -> tinytc_data_type_t { + auto stride_buffer = std::vector{}; + if (stride.empty()) { + stride_buffer = canonical_stride(shape); + stride = std::span{stride_buffer}; } + + auto key = memref_data_type_key(element_ty, shape, stride, addrspace); + std::uint64_t map_key = key.hash(); + + auto &tys = ctx->cache()->memref_tys; + auto range = tys.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (key == *dyn_cast(it->second)) { + return it->second; + } + } + auto new_mt = std::unique_ptr(new memref_data_type( + ctx, key.element_ty, std::vector(shape.begin(), shape.end()), + std::vector(stride.begin(), stride.end()), key.addrspace, lc)); + return tys.emplace(map_key, new_mt.release())->second; } -auto memref_data_type::canonical_stride() const -> std::vector { - if (shape_.empty()) { +auto memref_data_type::canonical_stride(std::span shape) + -> std::vector { + if (shape.empty()) { return {}; } - auto stride = std::vector(shape_.size(), dynamic); + auto stride = std::vector(shape.size(), dynamic); stride[0] = 1; - for (std::size_t i = 0; i < shape_.size() - 1 && !is_dynamic_value(shape_[i]); ++i) { - stride[i + 1] = stride[i] * shape_[i]; + for (std::size_t i = 0; i < shape.size() - 1 && !is_dynamic_value(shape[i]); ++i) { + stride[i + 1] = stride[i] * shape[i]; } return stride; } +auto memref_data_type_key::hash() -> std::uint64_t { + std::uint64_t hash = fnv1a0(); + hash = fnv1a_step(hash, element_ty); + for (auto &s : shape) { + hash = fnv1a_step(hash, s); + } + for (auto &s : stride) { + hash = fnv1a_step(hash, s); + } + return fnv1a_step(hash, addrspace); +} + +auto memref_data_type_key::operator==(memref_data_type const &mt) -> bool { + return element_ty == mt.element_ty() && addrspace == mt.addrspace() && + std::equal(shape.begin(), shape.end(), mt.shape().begin(), mt.shape().end()) && + std::equal(stride.begin(), stride.end(), mt.stride().begin(), mt.stride().end()); +} + +auto scalar_data_type::get(tinytc_compiler_context_t ctx, scalar_type ty) -> tinytc_data_type_t { + return ctx->cache()->scalar_tys[static_cast(ty)].get(); +} + } // namespace tinytc diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 22c2950f..fc030aea 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -4,13 +4,13 @@ #ifndef DATA_TYPE_NODE_20230309_HPP #define DATA_TYPE_NODE_20230309_HPP -#include "reference_counted.hpp" #include "support/type_list.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include #include +#include #include #include @@ -20,16 +20,19 @@ using data_type_nodes = type_list; } // namespace tinytc -struct tinytc_data_type : tinytc::reference_counted { +struct tinytc_data_type { public: using leaves = tinytc::data_type_nodes; - inline tinytc_data_type(tinytc::DTK tid) : tid_(tid) {} + inline tinytc_data_type(tinytc::DTK tid, tinytc_compiler_context_t ctx) + : tid_(tid), ctx_(ctx) {} virtual ~tinytc_data_type() = default; inline auto type_id() const -> tinytc::DTK { return tid_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_; } private: tinytc::DTK tid_; + tinytc_compiler_context_t ctx_; }; namespace tinytc { @@ -39,22 +42,29 @@ using data_type_node = ::tinytc_data_type; class group_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } - group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}); + static auto get(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t offset, + location const &lc = {}) -> tinytc_data_type_t; - inline auto ty() const -> data_type const & { return ty_; } + inline auto ty() const -> tinytc_data_type_t { return ty_; } inline auto offset() const -> std::int64_t { return offset_; } + protected: + group_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t offset = 0, + location const &lc = {}); + private: - data_type ty_; + tinytc_data_type_t ty_; std::int64_t offset_; }; class memref_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } - memref_data_type(scalar_type type, std::vector shape, - std::vector stride = {}, - address_space addrspace = address_space::global, location const &lc = {}); + static auto canonical_stride(std::span shape) -> std::vector; + static auto get(tinytc_compiler_context_t ctx, scalar_type element_ty, + std::span shape, std::span stride, + address_space addrspace = address_space::global, + location const &lc = {}) -> tinytc_data_type_t; inline scalar_type element_ty() const { return element_ty_; } inline std::int64_t dim() const { return shape_.size(); } @@ -75,23 +85,39 @@ class memref_data_type : public data_type_node { return std::any_of(stride_.begin(), stride_.end(), is_dynamic_value); } inline bool is_dynamic() const { return is_dynamic_shape() || is_dynamic_stride(); } - inline bool is_canonical_stride() const { return stride_ == canonical_stride(); } + inline bool is_canonical_stride() const { return stride_ == canonical_stride(shape_); } - private: - auto canonical_stride() const -> std::vector; + protected: + memref_data_type(tinytc_compiler_context_t ctx, scalar_type type, + std::vector shape, std::vector stride, + address_space addrspace = address_space::global, location const &lc = {}); scalar_type element_ty_; std::vector shape_, stride_; address_space addrspace_ = address_space::global; }; +struct memref_data_type_key { + scalar_type element_ty; + std::span shape, stride; + address_space addrspace; + + auto hash() -> std::uint64_t; + auto operator==(memref_data_type const &mt) -> bool; +}; + class scalar_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::scalar; } - inline scalar_data_type(scalar_type type) : data_type_node(DTK::scalar), ty_(type) {} + static auto get(tinytc_compiler_context_t ctx, scalar_type ty) -> tinytc_data_type_t; inline scalar_type ty() const { return ty_; } + protected: + inline scalar_data_type(tinytc_compiler_context_t ctx, scalar_type type) + : data_type_node(DTK::scalar, ctx), ty_(type) {} + friend class compiler_context_cache; + private: scalar_type ty_; }; @@ -99,7 +125,9 @@ class scalar_data_type : public data_type_node { class void_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } - inline void_data_type() : data_type_node(DTK::void_) {} + + protected: + inline void_data_type(tinytc_compiler_context_t ctx) : data_type_node(DTK::void_, ctx) {} }; } // namespace tinytc diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index b90dc4c0..cd99c1e5 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -21,7 +21,7 @@ namespace tinytc { scalar_data_type *get_scalar_type(location const &loc, value const &v) { - auto m = dyn_cast(v->ty().get()); + auto m = dyn_cast(v->ty()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_scalar); } @@ -29,7 +29,7 @@ scalar_data_type *get_scalar_type(location const &loc, value const &v) { } memref_data_type *get_memref_type(location const &loc, value const &v) { - auto m = dyn_cast(v->ty().get()); + auto m = dyn_cast(v->ty()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_memref); } @@ -77,12 +77,12 @@ loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step } } -alloca_inst::alloca_inst(data_type ty, location const &lc) +alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); result(0) = make_value(std::move(ty)); - auto memref = dyn_cast(result(0)->ty().get()); + auto memref = dyn_cast(result(0)->ty()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } @@ -148,7 +148,7 @@ arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = make_value(at->ty()); + result(0) = make_value(at); } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, location const &lc) @@ -169,14 +169,14 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, locatio if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = make_value(at->ty()); + result(0) = make_value(at); } cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) : standard_inst{IK::cast} { op(op_a) = std::move(a); loc(lc); - result(0) = make_value(std::move(to_ty)); + result(0) = make_value(scalar_data_type::get(op(op_a)->context(), std::move(to_ty))); } compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location const &lc) @@ -192,7 +192,7 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons throw compilation_error(loc(), status::ir_scalar_mismatch); } - result(0) = make_value(scalar_type::i1); + result(0) = make_value(scalar_data_type::get(at->context(), scalar_type::i1)); } expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const &expand_shape0, @@ -255,7 +255,8 @@ expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const if (dyn_mode >= 0) { std::int64_t const s = size / prod; known_expand_shape[dyn_mode] = s; - expand_shape()[dyn_mode] = make_imm(s); + expand_shape()[dyn_mode] = + make_imm(s, scalar_data_type::get(m->context(), scalar_type::i64)); prod *= s; } if (prod != size) { @@ -284,10 +285,9 @@ expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } - auto r = std::make_unique(m->element_ty(), shape, stride); - r->addrspace(m->addrspace()); - result(0) = make_value(data_type(r.release())); + result(0) = make_value( + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); } fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location const &lc) @@ -322,10 +322,9 @@ fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location con shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } - auto r = std::make_unique(m->element_ty(), shape, stride); - r->addrspace(m->addrspace()); - result(0) = make_value(data_type(r.release())); + result(0) = make_value( + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); } load_inst::load_inst(value op0, std::vector const &index_list0, location const &lc) @@ -347,7 +346,7 @@ load_inst::load_inst(value op0, std::vector const &index_list0, location if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result(0) = make_value(m.element_ty()); + result(0) = make_value(scalar_data_type::get(m.context(), m.element_ty())); }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, *operand()->ty()); @@ -474,7 +473,7 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu } if_inst::if_inst(value condition, region then0, region otherwise0, - std::vector const &return_types, location const &lc) + std::vector const &return_types, location const &lc) : standard_inst{IK::if_, 1, static_cast(return_types.size()), otherwise0 ? 2 : 1} { op(0) = std::move(condition); child_region(child_region_then) = std::move(then0); @@ -502,7 +501,7 @@ size_inst::size_inst(value op0, std::int64_t mode, location const &lc) throw compilation_error(loc(), status::ir_out_of_bounds); } - result(0) = make_value(scalar_type::index); + result(0) = make_value(scalar_data_type::get(op(0)->context(), scalar_type::index)); } subview_inst::subview_inst(value op0, std::vector const &offset_list0, @@ -515,8 +514,9 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, for (auto const &val : offset_list0) { op(i++) = val; } + auto index_ty = scalar_data_type::get(op(0)->context(), scalar_type::index); for (auto const &val : size_list0) { - op(i++) = val ? val : make_index(0); + op(i++) = val ? val : make_imm(0, index_ty); } } loc(lc); @@ -569,10 +569,9 @@ subview_inst::subview_inst(value op0, std::vector const &offset_list0, stride.push_back(m->stride(i)); } } - auto r = std::make_unique(m->element_ty(), shape, stride); - r->addrspace(m->addrspace()); - result(0) = make_value(data_type(r.release())); + result(0) = make_value( + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); } store_inst::store_inst(value val0, value op0, std::vector const &index_list0, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 56fa5fe5..bef8fe0e 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -5,6 +5,7 @@ #define INST_NODE_20230327_HPP #include "error.hpp" +#include "node/data_type_node.hpp" #include "support/ilist.hpp" #include "support/type_list.hpp" #include "support/util.hpp" @@ -339,7 +340,7 @@ class loop_inst : public standard_inst<4, 0, 1> { class alloca_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::alloca; } - alloca_inst(data_type ty, location const &loc = {}); + alloca_inst(tinytc_data_type_t ty, location const &loc = {}); inline std::int64_t stack_ptr() const { return stack_ptr_; } inline void stack_ptr(std::int64_t ptr) { stack_ptr_ = ptr; } @@ -468,18 +469,20 @@ class load_inst : public standard_inst { class group_id_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_id; } - inline group_id_inst(location const &lc = {}) : standard_inst{IK::group_id} { + inline group_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::group_id} { loc(lc); - result(0) = make_value(scalar_type::index); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::index)); } }; class group_size_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_size; } - inline group_size_inst(location const &lc = {}) : standard_inst{IK::group_size} { + inline group_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::group_size} { loc(lc); - result(0) = make_value(scalar_type::index); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::index)); } }; @@ -558,7 +561,7 @@ class if_inst : public standard_inst<1, dynamic, 2> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::if_; } enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; if_inst(value condition, region then, region otherwise = {}, - std::vector const &return_types = {}, location const &lc = {}); + std::vector const &return_types = {}, location const &lc = {}); inline auto condition() const -> value const & { return op(0); } inline auto then() -> tinytc_region & { return *child_region(child_region_then); } inline auto then() const -> tinytc_region const & { return *child_region(child_region_then); } @@ -572,9 +575,10 @@ class if_inst : public standard_inst<1, dynamic, 2> { class num_subgroups_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::num_subgroups; } - inline num_subgroups_inst(location const &lc = {}) : standard_inst{IK::num_subgroups} { + inline num_subgroups_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::num_subgroups} { loc(lc); - result(0) = make_value(scalar_type::i32); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); } }; @@ -602,27 +606,30 @@ class size_inst : public standard_inst<1, 1> { class subgroup_id_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_id; } - inline subgroup_id_inst(location const &lc = {}) : standard_inst{IK::subgroup_id} { + inline subgroup_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::subgroup_id} { loc(lc); - result(0) = make_value(scalar_type::i32); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); } }; class subgroup_local_id_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_local_id; } - inline subgroup_local_id_inst(location const &lc = {}) : standard_inst{IK::subgroup_local_id} { + inline subgroup_local_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::subgroup_local_id} { loc(lc); - result(0) = make_value(scalar_type::i32); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); } }; class subgroup_size_inst : public standard_inst<0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_size; } - inline subgroup_size_inst(location const &lc = {}) : standard_inst{IK::subgroup_size} { + inline subgroup_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) + : standard_inst{IK::subgroup_size} { loc(lc); - result(0) = make_value(scalar_type::i32); + result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); } }; diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 2dc33ea0..4d75b1dc 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -4,6 +4,7 @@ #ifndef VALUE_NODE_20230309_HPP #define VALUE_NODE_20230309_HPP +#include "node/data_type_node.hpp" #include "reference_counted.hpp" #include "support/type_list.hpp" #include "tinytc/tinytc.hpp" @@ -21,21 +22,25 @@ struct tinytc_value : tinytc::reference_counted { public: using leaves = tinytc::value_nodes; - inline tinytc_value(tinytc::VK tid) : tid_(tid) {} + inline tinytc_value(tinytc::VK tid, tinytc_data_type_t ty) : tid_(tid), ty_(std::move(ty)) {} virtual ~tinytc_value() = default; inline auto type_id() const -> tinytc::VK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - virtual tinytc::data_type ty() const = 0; - virtual void ty(tinytc::data_type ty) = 0; + inline tinytc_data_type_t ty() const { return ty_; } + inline void ty(tinytc_data_type_t ty) { ty_ = std::move(ty); } + + inline auto context() const -> tinytc_compiler_context_t { return ty_->context(); } + virtual auto name() const -> char const * = 0; virtual void name(std::string name) = 0; virtual auto has_name() const -> bool = 0; private: tinytc::VK tid_; + tinytc_data_type_t ty_; tinytc::location loc_; }; @@ -46,13 +51,11 @@ using value_node = ::tinytc_value; class float_imm : public value_node { public: inline static bool classof(value_node const &v) { return v.type_id() == VK::float_; } - inline float_imm(double v, scalar_type ty = scalar_type::f64, location const &lc = {}) - : value_node(VK::float_), ty_{make_scalar(ty)}, value_(v) { + inline float_imm(double v, tinytc_data_type_t ty, location const &lc = {}) + : value_node(VK::float_, ty), value_(v) { loc(lc); } - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } inline auto name() const -> char const * override { return ""; } inline void name(std::string) override {} auto has_name() const -> bool override { return false; } @@ -60,20 +63,17 @@ class float_imm : public value_node { inline double value() const { return value_; } private: - data_type ty_; double value_; }; class int_imm : public value_node { public: inline static bool classof(value_node const &v) { return v.type_id() == VK::int_; } - inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64, location const &lc = {}) - : value_node(VK::int_), ty_{make_scalar(ty)}, value_(v) { + inline int_imm(std::int64_t v, tinytc_data_type_t ty, location const &lc = {}) + : value_node(VK::int_, ty), value_(v) { loc(lc); } - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } inline auto name() const -> char const * override { return ""; } inline void name(std::string) override {} auto has_name() const -> bool override { return false; } @@ -81,25 +81,21 @@ class int_imm : public value_node { inline std::int64_t value() const { return value_; } private: - data_type ty_; std::int64_t value_; }; class val : public value_node { public: inline static bool classof(value_node const &v) { return v.type_id() == VK::val; } - inline val(data_type ty, location const &lc = {}) : value_node(VK::val), ty_(std::move(ty)) { + inline val(tinytc_data_type_t ty, location const &lc = {}) : value_node(VK::val, ty) { loc(lc); } - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } inline auto name() const -> char const * override { return name_.c_str(); } inline void name(std::string name) override { name_ = std::move(name); } virtual auto has_name() const -> bool override { return !name_.empty(); } private: - data_type ty_; std::string name_; }; diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 251286cb..1a930db3 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -29,7 +29,7 @@ class parse_context { void report_error(location const &loc, std::string const &what); - auto get_compiler_context() -> compiler_context const & { return compiler_ctx_; } + auto cctx() -> compiler_context const & { return compiler_ctx_; } private: compiler_context compiler_ctx_; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 15d1bb8b..646d770b 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -21,7 +21,6 @@ } %code { - #include "analysis/equal.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" @@ -37,34 +36,32 @@ #include #include #include - #include #include + #include namespace tinytc { - void check_scalar_type(value & val, scalar_type const& sty, location & loc1, - location & loc2) { - visit( - overloaded{[&](int_imm &i) { i.ty(make_scalar(sty)); }, - [&](float_imm &i) { i.ty(make_scalar(sty)); }, - [&](auto &) { - if (!val->ty() || !is_equal(*val->ty(), *make_scalar(sty))) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error( - loc, "Type of SSA value does not match operand type"); - } - }}, - *val); + void check_scalar_type(compiler_context const &ctx, value &val, scalar_type const &sty, + location &loc1, location &loc2) { + visit(overloaded{[&](int_imm &i) { i.ty(get_scalar(ctx, sty)); }, + [&](float_imm &i) { i.ty(get_scalar(ctx, sty)); }, + [&](auto &) { + if (val->ty() != get_scalar(ctx, sty)) { + auto loc = loc1; + loc.end = loc2.end; + throw parser::syntax_error( + loc, "Type of SSA value does not match operand type"); + } + }}, + *val); } - void check_type(value & val, data_type & ty, location & loc1, - location & loc2) { - if (!val->ty() || !is_equal(*val->ty(), *ty)) { + void check_type(value &val, tinytc_data_type_t ty, location &loc1, location &loc2) { + if (val->ty() != ty) { auto loc = loc1; loc.end = loc2.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); } }; - } + } // namespace tinytc } %header @@ -157,17 +154,17 @@ %nterm <::tinytc::value> argument %nterm >> attributes %nterm > attribute -%nterm data_type +%nterm data_type %nterm scalar_type -%nterm memref_type +%nterm memref_type %nterm optional_address_space %nterm > mode_list %nterm > optional_stride_list %nterm > stride_list %nterm constant_or_dynamic -%nterm group_type +%nterm group_type %nterm group_offset -%nterm memref_or_group_type +%nterm memref_or_group_type %nterm region %nterm <::tinytc::value> var %nterm > instructions @@ -189,9 +186,9 @@ %nterm foreach_inst %nterm hadamard_inst %nterm if_inst -%nterm > optional_returned_values -%nterm > optional_scalar_type_list -%nterm > scalar_type_list +%nterm > optional_returned_values +%nterm > optional_scalar_type_list +%nterm > scalar_type_list %nterm else_region %nterm sum_inst %nterm yield_inst @@ -230,7 +227,7 @@ %% prog: func_list { - auto p = prog { std::make_unique(ctx.get_compiler_context(), @prog).release() }; + auto p = prog { std::make_unique(ctx.cctx(), @prog).release() }; ctx.program(p); $$ = std::move(p); for (auto& f : $func_list) { @@ -303,7 +300,7 @@ attribute: data_type: - scalar_type { $$ = make_scalar($scalar_type); } + scalar_type { $$ = get_scalar(ctx.cctx(), $scalar_type); } | memref_type | group_type ; @@ -316,12 +313,8 @@ scalar_type: memref_type: MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::vector{}, $optional_address_space, - @memref_type) - .release() - }; + $$ = + get_memref(ctx.cctx(), $scalar_type, $mode_list, {}, $optional_address_space, @memref_type); } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; @@ -334,12 +327,8 @@ memref_type: throw syntax_error(loc, "Shape and stride list must have the same length"); } try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::move($optional_stride_list), - $optional_address_space, @memref_type) - .release() - }; + $$ = get_memref(ctx.cctx(), $scalar_type, $mode_list, $optional_stride_list, + $optional_address_space, @memref_type); } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; @@ -375,7 +364,7 @@ constant_or_dynamic: group_type: GROUP LCHEV memref_type group_offset RCHEV { - $$ = make_group(std::move($memref_type), $group_offset, @group_type); + $$ = get_group(ctx.cctx(), std::move($memref_type), $group_offset, @group_type); } ; @@ -434,9 +423,9 @@ axpby_inst: AXPBY transpose[ta] atomic identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($b, $mb, @b, @mb); try { $$ = inst { @@ -458,8 +447,16 @@ atomic: identifier_or_constant: var { $$ = $var; } - | INTEGER_CONSTANT { $$ = make_imm($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | FLOATING_CONSTANT { $$ = make_imm($FLOATING_CONSTANT); $$->loc(@FLOATING_CONSTANT); } + | INTEGER_CONSTANT { + auto i64_ty = get_scalar(ctx.cctx(), scalar_type::i64); + $$ = make_imm($INTEGER_CONSTANT, i64_ty); + $$->loc(@INTEGER_CONSTANT); + } + | FLOATING_CONSTANT { + auto f64_ty = get_scalar(ctx.cctx(), scalar_type::f64); + $$ = make_fimm($FLOATING_CONSTANT, f64_ty); + $$->loc(@FLOATING_CONSTANT); + } ; optional_identifier_or_constant_list: @@ -504,10 +501,10 @@ gemm_inst: identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -528,10 +525,10 @@ gemv_inst: identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -556,10 +553,10 @@ ger_inst: identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -578,12 +575,12 @@ for_inst: FOR LOCAL_IDENTIFIER[loop_var] EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] optional_step for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); + check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); + check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); if ($optional_step) { - check_scalar_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); + check_scalar_type(ctx.cctx(), $optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); } - auto v = make_value($for_loop_var_type); + auto v = make_value(get_scalar(ctx.cctx(), $for_loop_var_type)); v.name($loop_var); ctx.val($loop_var, std::move(v), @loop_var); } region { @@ -607,9 +604,9 @@ optional_step: foreach_inst: FOREACH LOCAL_IDENTIFIER[loop_var] EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); - auto v = make_value($for_loop_var_type); + check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); + check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); + auto v = make_value(get_scalar(ctx.cctx(), $for_loop_var_type)); v.name($loop_var); ctx.val($loop_var, std::move(v), @loop_var); } region { @@ -658,10 +655,10 @@ hadamard_inst: identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -681,9 +678,9 @@ sum_inst: SUM transpose[ta] atomic identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); + check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); + check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); check_type($b, $mb, @b, @mb); try { $$ = inst { @@ -706,7 +703,11 @@ yield_inst: throw syntax_error(loc, "Identifier and scalar type list must have the same length"); } for (std::size_t i = 0; i < $vals.size(); ++i) { - check_scalar_type($vals[i], $tys[i], @vals, @tys); + if (auto ty = dyn_cast($tys[i]); ty) { + check_scalar_type(ctx.cctx(), $vals[i], ty->ty(), @vals, @tys); + } else { + throw syntax_error(@tys, "Yield only accepts scalar types"); + } } $$ = inst{std::make_unique(std::move($vals)).release()}; } @@ -747,8 +748,8 @@ alloca_inst: arith_inst: ARITH ARITHMETIC identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); + check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); + check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); try { $$ = inst { std::make_unique($ARITHMETIC, std::move($a), std::move($b), @arith_inst) @@ -763,7 +764,7 @@ arith_inst: arith_unary_inst: ARITH ARITHMETIC_UNARY identifier_or_constant[a] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); + check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); try { $$ = inst { std::make_unique($ARITHMETIC_UNARY, std::move($a), @@ -780,7 +781,7 @@ arith_unary_inst: cast_inst: CAST identifier_or_constant[a] COLON scalar_type[from] RETURNS scalar_type[to] { - check_scalar_type($a, $from, @a, @from); + check_scalar_type(ctx.cctx(), $a, $from, @a, @from); try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; } catch (compilation_error const &e) { @@ -792,8 +793,8 @@ cast_inst: compare_inst: CMP CMP_CONDITION identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); + check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); + check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); try { $$ = inst { std::make_unique($CMP_CONDITION, std::move($a), std::move($b), @@ -809,7 +810,7 @@ compare_inst: expand_inst: EXPAND var LSQBR INTEGER_CONSTANT[mode] RETURNS expand_shape RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { + if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -836,16 +837,20 @@ expand_shape: constant_or_dynamic_or_identifier: var { - check_scalar_type($var, scalar_type::index, @var, @var); + check_scalar_type(ctx.cctx(), $var, scalar_type::index, @var, @var); $$ = $var; } - | INTEGER_CONSTANT { $$ = make_index($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | DYNAMIC { $$ = make_dynamic(); $$->loc(@DYNAMIC); } + | INTEGER_CONSTANT { + auto index_ty = get_scalar(ctx.cctx(), scalar_type::index); + $$ = make_imm($INTEGER_CONSTANT, index_ty); + $$->loc(@INTEGER_CONSTANT); + } + | DYNAMIC { $$ = make_imm(dynamic, get_scalar(ctx.cctx(), scalar_type::i64)); $$->loc(@DYNAMIC); } ; fuse_inst: FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { + if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -863,7 +868,7 @@ fuse_inst: load_inst: LOAD var LSQBR optional_index_list RSQBR COLON memref_or_group_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_or_group_type)) { + if ($var->ty() != $memref_or_group_type) { auto loc = @var; loc.end = @memref_or_group_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -893,18 +898,18 @@ index_list: index_identifier_or_const: var { - check_scalar_type($var, scalar_type::index, @var, @var); + check_scalar_type(ctx.cctx(), $var, scalar_type::index, @var, @var); $$ = $var; } | INTEGER_CONSTANT { - $$ = make_index($INTEGER_CONSTANT); + $$ = make_imm($INTEGER_CONSTANT, get_scalar(ctx.cctx(), scalar_type::index)); $$->loc(@INTEGER_CONSTANT); } ; store_inst: STORE var[a] COMMA var[b] LSQBR optional_index_list RSQBR COLON memref_type { - if (!$b->ty() || !is_equal(*$b->ty(), *$memref_type)) { + if ($b->ty() != $memref_type) { auto loc = @b; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -923,16 +928,16 @@ store_inst: ; group_id_inst: - GROUP_ID { $$ = inst{std::make_unique(@GROUP_ID).release()}; } + GROUP_ID { $$ = inst{std::make_unique(ctx.cctx().get(), @GROUP_ID).release()}; } ; group_size_inst: - GROUP_SIZE { $$ = inst{std::make_unique(@GROUP_SIZE).release()}; } + GROUP_SIZE { $$ = inst{std::make_unique(ctx.cctx().get(), @GROUP_SIZE).release()}; } ; if_inst: IF identifier_or_constant[condition] optional_returned_values region else_region { - check_scalar_type($condition, scalar_type::i1, @condition, @condition); + check_scalar_type(ctx.cctx(), $condition, scalar_type::i1, @condition, @condition); $$ = inst{std::make_unique(std::move($condition), std::move($region), std::move($else_region), std::move($optional_returned_values)) @@ -957,12 +962,14 @@ optional_scalar_type_list: ; scalar_type_list: - scalar_type { $$.push_back($scalar_type); } - | scalar_type_list COMMA scalar_type { $$ = std::move($1); $$.push_back($scalar_type); } + scalar_type { $$.push_back(get_scalar(ctx.cctx(), $scalar_type)); } + | scalar_type_list COMMA scalar_type { + $$ = std::move($1); $$.push_back(get_scalar(ctx.cctx(), $scalar_type)); + } ; num_subgroups_inst: - NUM_SUBGROUPS { $$ = inst{std::make_unique(@NUM_SUBGROUPS).release()}; } + NUM_SUBGROUPS { $$ = inst{std::make_unique(ctx.cctx().get(), @NUM_SUBGROUPS).release()}; } ; parallel_inst: @@ -973,7 +980,7 @@ parallel_inst: size_inst: SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { + if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -988,20 +995,22 @@ size_inst: ; subgroup_id_inst: - SUBGROUP_ID { $$ = inst{std::make_unique(@SUBGROUP_ID).release()}; } + SUBGROUP_ID { $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_ID).release()}; } ; subgroup_local_id_inst: - SUBGROUP_LOCAL_ID { $$ = inst{std::make_unique(@SUBGROUP_LOCAL_ID).release()}; } + SUBGROUP_LOCAL_ID { + $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_LOCAL_ID).release()}; + } ; subgroup_size_inst: - SUBGROUP_SIZE { $$ = inst{std::make_unique(@SUBGROUP_SIZE).release()}; } + SUBGROUP_SIZE { $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_SIZE).release()}; } ; subview_inst: SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { + if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); @@ -1037,14 +1046,18 @@ slice_list: ; slice: - COLON { $$ = std::make_pair(make_index(0), make_dynamic()); } + COLON { + auto index_ty = get_scalar(ctx.cctx(), scalar_type::index); + auto i64_ty = get_scalar(ctx.cctx(), scalar_type::i64); + $$ = std::make_pair(make_imm(0, index_ty), make_imm(dynamic, i64_ty)); + } | index_identifier_or_const slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } ; slice_size: %empty { $$ = {}; } | COLON index_identifier_or_const { $$ = $2; } - | COLON DYNAMIC { $$ = make_dynamic(); } + | COLON DYNAMIC { $$ = make_imm(dynamic, get_scalar(ctx.cctx(), scalar_type::i64)); } ; %% diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index febe6448..ab6f9e98 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -31,6 +31,8 @@ #include #include +#include + namespace tinytc { std::string var_name(std::string name) { @@ -49,7 +51,7 @@ dope_vector dope_vector::from_value(value_node const &v, decl_fun_t declare) { dt = to_clir_ty(scalar_type::index); }, [&](group_data_type const &g) { - m = dyn_cast(g.ty().get()); + m = dyn_cast(g.ty()); dt = clir::pointer_to( to_clir_ty(scalar_type::index, clir::address_space::global_t)); }, @@ -144,7 +146,7 @@ clir::var convert_to_opencl_pass::declare(value_node const &v) { auto convert_to_opencl_pass::get_memref_type(value_node const &v) const -> const memref_data_type * { - auto t = dyn_cast(v.ty().get()); + auto t = dyn_cast(v.ty()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); } @@ -215,7 +217,7 @@ std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); } auto result_var = declare(*a.result()); - auto t = dyn_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } @@ -567,7 +569,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { *e.operand()->ty()); auto lhs = declare(*e.result()); - auto result_type = e.result()->ty().get(); + auto result_type = e.result()->ty(); if (result_type == nullptr) { throw compilation_error(e.loc(), status::internal_compiler_error, "Expected type"); } diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 7556edd2..58639448 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -129,12 +129,12 @@ void insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa) auto const get_rw = [](inst_node &in) -> reads_writes { auto rw = reads_writes{}; auto const emplace_read = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty().get()); m) { + if (auto *m = dyn_cast(v->ty()); m) { rw.emplace_read(m->addrspace(), v.get()); } }; auto const emplace_write = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty().get()); m) { + if (auto *m = dyn_cast(v->ty()); m) { rw.emplace_write(m->addrspace(), v.get()); } }; diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 9b5d939f..0bafb883 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -28,7 +28,7 @@ void set_stack_ptr_pass::run_on_function(function_node &fn) { walk(fn, [&allocs](inst_node &i) { visit(overloaded{ [&allocs](alloca_inst &a) { - auto t = dyn_cast(a.result()->ty().get()); + auto t = dyn_cast(a.result()->ty()); if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp index c042ed9d..da0a651a 100644 --- a/src/pass/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -23,7 +23,7 @@ namespace tinytc { auto get_memref_type(value_node &v) { - auto t = dyn_cast(v.ty().get()); + auto t = dyn_cast(v.ty()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); } diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index be3630ef..311b2fc9 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -81,33 +81,36 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( }; return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const ty_ = get_scalar(ctx_, enum_cast(ty)); auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); + auto const index_ty = get_scalar(ctx_, scalar_type::index); + auto const i64_ty = get_scalar(ctx_, scalar_type::i64); auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = - fb.argument(make_memref(ty_, {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, address_space::global, my_loc()), - "A", my_loc()); - auto B = - fb.argument(make_memref(ty_, {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, address_space::global, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, - address_space::global, my_loc()), + auto alpha = fb.argument(ty_, "alpha"); + auto A = fb.argument(get_memref(ctx_, enum_cast(ty), + {selA(M, K), selA(K, M), dynamic}, + {1, ldA, strideA}, address_space::global, my_loc()), + "A", my_loc()); + auto B = fb.argument(get_memref(ctx_, enum_cast(ty), + {selB(K, N), selB(N, K), dynamic}, + {1, ldB, strideB}, address_space::global, my_loc()), + "B", my_loc()); + auto beta_arg = fb.argument(ty_, "beta"); + auto C = fb.argument(get_memref(ctx_, enum_cast(ty), {M, N, dynamic}, + {1, ldC, strideC}, address_space::global, my_loc()), "C", my_loc()); - auto beta = is_beta_nonzero ? std::move(beta_arg) : make_imm(0.0, ty_, my_loc()); + auto beta = is_beta_nonzero ? std::move(beta_arg) : make_fimm(0.0, ty_, my_loc()); fb.body( [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto offsets = std::vector{make_index(0, my_loc()), - make_index(0, my_loc()), gid}; - auto size = std::vector{make_dynamic(my_loc()), - make_dynamic(my_loc()), value{}}; + auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto offsets = std::vector{make_imm(0, index_ty, my_loc()), + make_imm(0, index_ty, my_loc()), gid}; + auto size = + std::vector{make_imm(dynamic, i64_ty, my_loc()), + make_imm(dynamic, i64_ty, my_loc()), value{}}; auto a = bb.add(make_subview(A, offsets, size, my_loc())); auto b = bb.add(make_subview(B, offsets, size, my_loc())); auto c = bb.add(make_subview(C, offsets, size, my_loc())); @@ -129,7 +132,8 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( }(); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); - *recipe = std::make_unique(std::move(p), source(src), ty_) + *recipe = std::make_unique(std::move(p), source(src), + enum_cast(ty)) .release(); }, ctx_.get()); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 1bc11ad8..58a88cc2 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -91,9 +91,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const ty_ = get_scalar(ctx_, enum_cast(ty)); + auto const index_ty = get_scalar(ctx_, scalar_type::index); + auto const i64_ty = get_scalar(ctx_, scalar_type::i64); - auto const shapes = std::vector{blas_shape{ty_, {M_block_size, N}}}; + auto const shapes = + std::vector{blas_shape{enum_cast(ty), {M_block_size, N}}}; auto [sgs, tiling] = suggest_subgroup_size_and_tiling(shapes, *info); // We want to avoid working on too many columns in parallel as there is a high @@ -106,29 +109,30 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( value &C) { auto const gemm = [&](region_builder &bb, std::vector const &offsets, value const &block_size) { - auto a = bb.add( - make_subview(A, offsets, {block_size, make_index(K, my_loc())}, my_loc())); - auto c = bb.add( - make_subview(C, offsets, {block_size, make_index(N, my_loc())}, my_loc())); + auto a = bb.add(make_subview( + A, offsets, {block_size, make_imm(K, index_ty, my_loc())}, my_loc())); + auto c = bb.add(make_subview( + C, offsets, {block_size, make_imm(N, index_ty, my_loc())}, my_loc())); bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; - auto const block_size_imm = make_index(M_block_size, my_loc()); - auto gid = bb.add(make_group_id(my_loc())); - auto m = bb.add( - make_arith(arithmetic::mul, gid, make_index(M_block_size, my_loc()), my_loc())); - auto const offsets = std::vector{m, make_index(0, my_loc())}; + auto const block_size_imm = make_imm(M_block_size, index_ty, my_loc()); + auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto m = bb.add(make_arith(arithmetic::mul, gid, + make_imm(M_block_size, index_ty, my_loc()), my_loc())); + auto const offsets = std::vector{m, make_imm(0, index_ty, my_loc())}; if (!is_dynamic_value(M) && M % M_block_size == 0) { gemm(bb, offsets, block_size_imm); } else { - auto M_val = - is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) : make_index(M); + auto M_val = is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) + : make_imm(M, index_ty); auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); - auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, - make_index(M_block_size, my_loc()), my_loc())); - auto const dynamic_imm = make_dynamic(my_loc()); + auto cond = + bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, + make_imm(M_block_size, index_ty, my_loc()), my_loc())); + auto const dynamic_imm = make_imm(dynamic, i64_ty, my_loc()); bb.ifelse( cond, [&](region_builder &bb) { gemm(bb, offsets, dynamic_imm); }, [&](region_builder &bb) { gemm(bb, offsets, block_size_imm); }, {}, @@ -137,22 +141,22 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( }; auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = - fb.argument(make_memref(ty_, {M, K}, {1, ldA}, address_space::global, my_loc()), - "A", my_loc()); - auto B = - fb.argument(make_memref(ty_, {K, N}, {1, ldB}, address_space::global, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = - fb.argument(make_memref(ty_, {M, N}, {1, ldC}, address_space::global, my_loc()), - "C", my_loc()); + auto alpha = fb.argument(ty_, "alpha", my_loc()); + auto A = fb.argument(get_memref(ctx_, enum_cast(ty), {M, K}, {1, ldA}, + address_space::global, my_loc()), + "A", my_loc()); + auto B = fb.argument(get_memref(ctx_, enum_cast(ty), {K, N}, {1, ldB}, + address_space::global, my_loc()), + "B", my_loc()); + auto beta_arg = fb.argument(ty_, "beta", my_loc()); + auto C = fb.argument(get_memref(ctx_, enum_cast(ty), {M, N}, {1, ldC}, + address_space::global, my_loc()), + "C", my_loc()); fb.subgroup_size(sgs); auto const wgs = tiling.work_group_size(sgs); fb.work_group_size(wgs[0], wgs[1]); - auto beta = is_beta_nonzero ? beta_arg : make_imm(0.0, ty_, my_loc()); + auto beta = is_beta_nonzero ? beta_arg : make_fimm(0.0, ty_, my_loc()); fb.body([&](region_builder &bb) { body(bb, alpha, A, B, beta, C); }, my_loc()); }; @@ -168,8 +172,9 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( }(); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); - *recipe = std::make_unique(std::move(p), source(src), ty_, M, - ldA, ldB, ldC, M_block_size) + *recipe = std::make_unique(std::move(p), source(src), + enum_cast(ty), M, ldA, + ldB, ldC, M_block_size) .release(); }, ctx_.get()); diff --git a/src/support/util.hpp b/src/support/util.hpp index 2f26cd13..f6c91bba 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -4,6 +4,8 @@ #ifndef UTIL_20240201_HPP #define UTIL_20240201_HPP +#include +#include #include #include @@ -23,6 +25,29 @@ template class iterator_range_wrapper { ItT begin_, end_; }; +constexpr auto fnv1a0() -> std::uint64_t { return 0xcbf29ce484222325; } +constexpr auto fnv1a_step(std::uint64_t hash, char ch) -> std::uint64_t { + return (hash ^ ch) * 0x00000100000001b3; +} +template constexpr auto fnv1a_step(std::uint64_t hash, T &&t) -> std::uint64_t { + char buf[sizeof(T)]; + std::memcpy(buf, &t, sizeof(T)); + for (std::size_t i = 0; i < sizeof(T); ++i) { + hash = fnv1a_step(hash, buf[i]); + } + return hash; +} + +template +constexpr auto fnv1a_step(std::uint64_t hash, Head &&head, Tail &&...tail) -> std::uint64_t { + return fnv1a_step(fnv1a_step(hash, std::forward(tail)...), std::forward(head)); +} + +template +constexpr auto fnv1a(Head &&head, Tail &&...tail) -> std::uint64_t { + return fnv1a_step(fnv1a_step(fnv1a0(), std::forward(tail)...), std::forward(head)); +} + } // namespace tinytc #endif // UTIL_20240201_HPP diff --git a/src/value.cpp b/src/value.cpp index c87f73dd..3ed0ad4e 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -18,13 +18,13 @@ using namespace tinytc; namespace { template -tinytc_status_t create_imm(tinytc_value_t *vl, T imm, tinytc_scalar_type_t type, +tinytc_status_t create_imm(tinytc_value_t *vl, T imm, tinytc_data_type_t type, const tinytc_location_t *lc) { if (vl == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *vl = std::make_unique(imm, enum_cast(type)).release(); + *vl = std::make_unique(imm, type).release(); if (lc) { (*vl)->loc(*lc); } @@ -39,14 +39,14 @@ tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *vl = std::make_unique(data_type(type, true), get_optional(lc)).release(); }); + [&] { *vl = std::make_unique(type, get_optional(lc)).release(); }); } -tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, tinytc_scalar_type_t type, +tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, tinytc_data_type_t type, const tinytc_location_t *loc) { return create_imm(vl, imm, type, loc); } -tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, tinytc_scalar_type_t type, +tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, tinytc_data_type_t type, const tinytc_location_t *loc) { return create_imm(vl, imm, type, loc); } diff --git a/test/codegen/load.ir b/test/codegen/load.ir index 40536124..df94f0f6 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -15,6 +15,6 @@ func @kernel1(%a: memref, %b: memref, %c: group>) func @kernel2(%c: group, offset: 21>) { %0 = group_id - %1 = load %c[%0] : group> + %1 = load %c[%0] : group, offset: 21> ; CHECK: global float* x1 = *(c + x0) + 21; } diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index b2595e05..dbceb9ed 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -11,7 +11,7 @@ func @basic() { func @use1(%A: memref, %C: memref) { ; CHECK-LABEL: func @use1{{.*}} %B = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } @@ -19,10 +19,10 @@ func @use1(%A: memref, %C: memref) { func @use2(%A: memref, %C: memref) { ; CHECK-LABEL: func @use2{{.*}} %B = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref %B2 = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref - gemm.n.n 1.0, %A, %B2, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n 1.0, %A, %B2, 0.0, %C : f32, memref, memref, f32, memref ; CHECK: %B2 = {{.*}} ; CHECK-NEXT: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B @@ -33,9 +33,9 @@ func @use2(%A: memref, %C: memref) { func @use_alias(%A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} %B = alloca -> memref - %0 = fuse %B[1,3] : memref - %1 = subview %0[0:8,:] : memref - gemm.n.n 1.0, %A, %1, 0.0, %C : f32, memref, memref>, f32, memref + %0 = fuse %B[1,3] : memref + %1 = subview %0[0:8,:] : memref + gemm.n.n 1.0, %A, %1, 0.0, %C : f32, memref, memref,local>, f32, memref ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } @@ -47,8 +47,9 @@ func @region1() { %1 = alloca -> memref for %k=0,4 : index { %2 = alloca -> memref - gemm.n.n 1.0, %0, %1, 0.0, %2 : f32, memref, memref, f32, memref - axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + gemm.n.n 1.0, %0, %1, 0.0, %2 + : f32, memref, memref, f32, memref + axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref } } ; CHECK: gemm.n.n{{.*}} From ad5b3d1bf5e58112ef99fc4752377d75abdb1907 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 26 Sep 2024 19:01:24 +0200 Subject: [PATCH 028/297] Started removing immediate operands Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 133 +++++++++-------- include/tinytc/tinytc.h | 42 +++--- include/tinytc/tinytc.hpp | 59 ++++++-- include/tinytc/types.h | 2 +- include/tinytc/types.hpp | 2 +- src/error.cpp | 9 +- src/inst.cpp | 58 +++++--- src/node/inst_node.cpp | 166 ++++++++------------- src/node/inst_node.hpp | 65 ++++++--- src/parser/lexer.re | 1 + src/parser/parser_impl.yy | 215 +++++++++++++++------------- src/pass/convert_to_opencl.cpp | 111 +++++++------- src/pass/convert_to_opencl.hpp | 1 + src/pass/dump_ir.cpp | 65 +++++++-- src/pass/dump_ir.hpp | 1 + src/recipe/small_gemm_batched.cpp | 20 +-- src/recipe/tall_and_skinny.cpp | 45 +++--- test/codegen/atomic.ir | 12 +- test/codegen/axpby0.ir | 5 +- test/codegen/axpby1.ir | 27 ++-- test/codegen/dope_vector_group0.ir | 18 ++- test/codegen/expand.ir | 204 ++++++++++++++------------ test/codegen/for.ir | 12 +- test/codegen/fuse.ir | 20 +-- test/codegen/if.ir | 48 ++++--- test/codegen/load.ir | 5 +- test/codegen/store.ir | 5 +- test/codegen/subview_return_type.ir | 87 +++++------ test/codegen/type_mismatch1.ir | 7 +- test/opt/check-ir/nesting0.ir | 10 +- test/opt/check-ir/nesting1.ir | 8 +- test/opt/check-ir/nesting3.ir | 6 +- test/opt/insert-barrier.ir | 51 ++++--- test/opt/insert-lifetime-stop.ir | 27 ++-- test/opt/work-group-size.ir | 16 ++- 35 files changed, 878 insertions(+), 685 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index b35a29c8..7e0a08eb 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -61,8 +61,9 @@ Constants .. code:: abnf - sign = "-" / "+" + constant = floating-constant / integer-constant integer-constant = "true" / "false" / [sign] 1*DIGIT + sign = "-" / "+" floating-constant = [sign] *DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] mantissa-dec = *DIGIT "." 1*DIGIT / 1*DIGIT "." mantissa-hex = *HEXDIG "." 1*HEXDIG / 1*HEXDIG "." @@ -289,9 +290,8 @@ Axpby .. code:: abnf transpose = ".t" / ".n" - const-or-val = floating-constant / local-identifier instruction =/ "axpby" transpose [".atomic"] - const-or-val "," local-identifier "," const-or-val "," local-identifier + local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," scalar-type "," memref-type Overview @@ -333,7 +333,7 @@ Foreach .. code:: abnf - instruction =/ "foreach" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant + instruction =/ "foreach" local-identifier "=" local-identifier "," local-identifier [":" integer-type] region Overview @@ -342,9 +342,8 @@ Overview A foreach loop that executes the loop's range [from; to) without any sequence guarantee. The region of a foreach is a *spmd region*. -The loop's range [from; to) is given by the first integer constant and second integer constant, +The loop's range [from; to) is given by the first integer value and second integer value, and the trip count is stored in the local identifier. -The integer type of the loop variable is given after the colon. The integer type of the loop variable and the loop bounds is given after the colon. The default integer type is ``index``. @@ -354,7 +353,7 @@ GEMM .. code:: abnf instruction =/ "gemm" transpose transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + "," local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type Overview @@ -397,7 +396,7 @@ GEMV .. code:: abnf instruction =/ "gemv" transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + "," local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type Overview @@ -428,7 +427,7 @@ GER .. code:: abnf instruction =/ "ger" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type Overview @@ -458,7 +457,7 @@ Hadamard product .. code:: abnf instruction =/ "hadamard_product" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier + local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type Overview @@ -500,7 +499,7 @@ Sum .. code:: abnf instruction =/ "sum" transpose [".atomic"] - "," const-or-val "," local-identifier "," const-or-val "," local-identifier + "," local-identifier "," local-identifier "," local-identifier "," local-identifier ":" scalar-type "," memref-type "," scalar-type "," memref-type Overview @@ -543,7 +542,6 @@ Arithmetic (binary) .. code:: abnf - identifier-or-constant = local-identifier / integer-constant / floating-constant arith-binary-type = ".add" / ".sub" / ".mul" / @@ -554,8 +552,7 @@ Arithmetic (binary) ".and" / ".or" / ".xor" - value-instruction =/ "arith" arith-binary-type - identifier-or-constant "," identifier-or-constant ":" scalar-type + value-instruction =/ "arith" arith-binary-type local-identifier "," local-identifier ":" scalar-type Overview ~~~~~~~~ @@ -584,7 +581,7 @@ Arithmetic (unary) .. code:: abnf arith-unary-type = ".neg" / ".not" - value-instruction =/ "arith" arith-unary-type identifier-or-constant ":" scalar-type + value-instruction =/ "arith" arith-unary-type local-identifier ":" scalar-type Overview ~~~~~~~~ @@ -632,7 +629,7 @@ Cast .. code:: abnf - value-instruction =/ "cast" identifier-or-constant ":" scalar-type "->" scalar-type + value-instruction =/ "cast" local-identifier ":" scalar-type "->" scalar-type Overview ~~~~~~~~ @@ -645,7 +642,7 @@ Comparison .. code:: abnf value-instruction =/ "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") - identifier-or-constant "," identifier-or-constant ":" scalar-type + local-identifier "," local-identifier ":" scalar-type Overview ~~~~~~~~ @@ -664,14 +661,28 @@ Cond Description .le Less than or equal ==== ===================== +Constant +........ + +.. code:: abnf + + value-instruction =/ "constant" constant "->" scalar-type + +Overview +~~~~~~~~ + +Sets the result value to a constant value. +The type of the constant must match the scalar type +(e.g. an integer type requires an integer-constant and a floating type requires a floating-constant). + Expand ...... .. code:: abnf value-instruction =/ "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type - expand-shape = constant-or-dynamic-or-identifier 1*("x" constant-or-dynamic-or-identifier) - constant-or-dynamic-or-identifier = integer-constant / "?" / local-identifier + expand-shape = integer-constant-or-identifier 1*("x" integer-constant-or-identifier) + integer-constant-or-identifier = integer-constant / local-identifier Overview ~~~~~~~~ @@ -682,49 +693,51 @@ Arguments ~~~~~~~~~ The first argument must point to a value of memref type. -The integer constant in square brackets gives the mode that shall be expanded. -The expand shape gives the new shape of the mode. -Values in the expand shape must have index type. +The first integer constant before "->" gives the mode that shall be expanded. +The expand shape coming after "->" gives the new shape of the mode. +Dynamic values in the expand shape must have index type. The output type is a memref type according to the following rules: -#. **Shape:** The mode size is replaced with the expand shape. If one entry in expand shape is dynamic, - then either its size is inferred automatically if the mode size is known, or it determined automatically - at run-time if the mode size is dynamic. +#. **Shape:** The mode size is replaced with the expand shape. + The product of the expand shape must equal the size of the expanded mode. .. code:: - expand %0[1 -> 2x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref + expand %0[1 -> 2x8] : memref ; -> memref + expand %0[1 -> 2x2x2x2] : memref ; -> memref #. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. + The product of the dynamic expand shape must equal the size of the expanded mode. .. code:: - expand %0[1 -> %1 x ?] : memref ; -> memref - expand %0[1 -> %1 x ?] : memref ; -> memref - expand %0[1 -> %1 x %2] : memref ; -> memref - expand %0[1 -> 4 x %1] : memref ; -> memref + expand %0[1 -> %1 x 2] : memref ; -> memref + expand %0[1 -> 2 x %1] : memref ; -> memref + expand %0[1 -> %1 x 2] : memref ; -> memref + expand %0[1 -> %1 x 2] : memref ; -> memref + expand %0[1 -> %1 x %2 x 2] : memref ; -> memref + expand %0[1 -> %2 x 2 x %1] : memref ; -> memref + expand %0[1 -> %1 x %2] : memref ; -> memref + expand %0[1 -> %1 x %2] : memref ; -> memref + + *Note:* In the third example above, %1 must be equal to 8. + The output mode corresponding to %1 is still dynamic. #. **Stride:** A new stride entry is entered that follows the canonical stride computation. .. code:: - expand %0[0->4x8] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> - expand %0[0->?x4] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> + expand %0[0->4 x 8] : memref> ; -> memref> + expand %0[0->%1 x 4] : memref> ; -> memref> + expand %0[0->4 x %1] : memref> ; -> memref> Restrictions ~~~~~~~~~~~~ -At most one mode in expand-shape must be dynamic. - The product of the expand shape must be the same as the mode size. -If one entry in the expand shape is dynamic then the other must evenly divide the mode size. +If the product of the expand shape is only known at runtime, then it is undefined behaviour +if the dynamic product does not match the mode size. Fuse .... @@ -815,7 +828,7 @@ If .. code:: abnf - multi-value-instruction = "if" identifier-or-int-constant ["->" "(" scalar-type-list ")"] + multi-value-instruction = "if" local-identifier ["->" "(" scalar-type-list ")"] region ["else" region] type-list = scalar-type *("," scalar-type) @@ -852,9 +865,7 @@ Load .. code:: abnf - value-instruction =/ "load" local-identifier "[" [index-list] "]" ":" memref-or-group-type - index-list = identifier-or-int-constant *("," identifier-or-int-constant) - identifier-or-int-constant = integer-constant / local-identifier + value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" ":" memref-or-group-type memref-or-group-type = memref-type / group-type Overview @@ -898,8 +909,8 @@ For .. code:: abnf - instruction =/ "for" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - ["," identifier-or-int-constant] [":" integer-type] region + instruction =/ "for" local-identifier "=" local-identifier "," local-identifier + ["," local-identifier] [":" integer-type] region Overview ~~~~~~~~ @@ -961,7 +972,7 @@ Subview value-instruction =/ "subview" local-identifier "[" [index-or-slice-list] "]" ":" memref-type index-or-slice-list = index-or-slice *("," index-or-slice) - index-or-slice = identifier-or-int-constant [":" (identifier-or-int-constant / "?")] / ":" + index-or-slice = integer-constant-or-identifier [":" integer-constant-or-identifier] Overview ~~~~~~~~ @@ -985,10 +996,6 @@ to the index interval [%0, %0 + %1). to determine whether the mode size is known at compile-time or not. Therefore, we prefer the offset plus size notation. -A dynamic size ("?") means that the size is the mode size inferred from the memref type -minus the offset. -A plain colon is syntactic sugar for "0:?". - Zero sizes are used to encode that a rank-reduction is required, that is, the rank of size 0 is removed from the output memref type. A single index is syntactic sugar for offset plus size 0, e.g. %0 is syntactic sugar for %0:0. @@ -1016,8 +1023,8 @@ The output type is a memref type according to the following rules: .. code:: - subview %0[2:4, %1] : memref ; Returns memref> - subview %0[2:4, %1:0] : memref ; Returns memref> + subview %0[2:4, %1] : memref ; Returns memref + subview %0[2:4, %1:0] : memref ; Returns memref subview %0[2:4, %1:1] : memref ; Returns memref> #. **Output-mode size:** The size of the output mode is determined by the size field of a slice @@ -1030,21 +1037,12 @@ The output type is a memref type according to the following rules: subview %0[2:4, %2:%2, 6:7] : memref ; Returns memref subview %0[2:4, %2:%2, 6:7] : memref> ; Returns memref -#. **Dynamic size:** - - .. code:: - - subview %0[:] : memref ; Returns memref - subview %0[:] : memref ; Returns memref - subview %0[5:?] : memref ; Returns memref - subview %0[%2:?] : memref ; Returns memref - Store ..... .. code:: abnf - instruction =/ "store" local-identifier "," local-identifier "[" [index-list] "]" ":" memref-type + instruction =/ "store" local-identifier "," local-identifier "[" [local-identifier-list] "]" ":" memref-type Overview ~~~~~~~~ @@ -1067,7 +1065,6 @@ Yield .. code:: abnf instruction =/ "yield" [local-identifier-list] ":" [scalar-type-list] - identifier-or-constant-list = identifier-or-constant *("," identifier-or-constant) Overview ~~~~~~~~ @@ -1139,8 +1136,10 @@ where B and C are constant matrices and A and D are matrix batches. %1 = load %A[%0] : group> ; Returns memref %2 = subview %D[:,:,%0] : memref ; Returns memref %tmp0 = alloca -> memref - gemm.n.t 1.0, %1, %B, 0.0, %tmp0 + %zero = constant 0.0 : f32 + %one = constant 1.0 : f32 + gemm.n.t %one, %1, %B, %zero, %tmp0 : f32, memref, memref, f32, memref - gemm.n.n %alpha, %tmp0, %C, 1.0, %2 + gemm.n.n %alpha, %tmp0, %C, %one, %2 : f32, memref, memref, f32, memref } diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index edb1cdc2..8bc0525d 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -314,17 +314,20 @@ TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tin * * @param instr [out] pointer to the inst object created * @param a [in] operand - * @param mode [in] expanded mode - * @param expand_shape_size [in] dimension of expand shape; must be at least 2 - * @param expand_shape [in][range(2, expand_shape_size)] expand shape array + * @param expanded_mode [in] expanded mode + * @param static_expand_shape_size [in] dimension of static expand shape; must be at least 2 + * @param static_expand_shape [in][range(2, static expand_shape_size)] static expand shape array + * @param expand_shape_size [in][optional] dimension of expand shape; must match number of entries + * equal to TINYTC_DYNAMIC in static_expand_shape array; can be 0 + * @param expand_shape [in][optional][range(0, expand_shape_size)] expand shape array * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t mode, uint32_t expand_shape_size, - tinytc_value_t *expand_shape, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, + uint32_t static_expand_shape_size, int64_t *static_expand_shape, uint32_t expand_shape_size, + tinytc_value_t *expand_shape, const tinytc_location_t *loc); /** * @brief Create fuse instruction @@ -594,21 +597,24 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *in * * @param instr [out] pointer to the inst object created * @param a [in] operand - * @param slice_list_size [in] number of slices - * @param offset_list [in][range(0, slice_list_size)] offset array; may be nullptr if - * slice_list_size is 0 - * @param size_list [in][range(0, slice_list_size)] size array; may be nullptr if slice_list_size - * is 0; size_list[i] may be nullptr if a single offset shall be passed instead of a range for the - * i-th mode + * @param static_list_size [in] number of slices + * @param static_offset_list [in][range(0, static_list_size)] offsets (need to add value to + * offset_list if static_offset_list[i] == TINYTC_DYNAMIC); may be nullptr if static_offset_list = 0 + * @param static_size_list [in][range(0, static_list_size)] sizes (need to add value to size_list + * if static_size_list[i] == TINYTC_DYNAMIC); may be nullptr if static_offset_list = 0 + * @param offset_list_size [in] number of dynamic offsets + * @param offset_list [in][range(0, offset_list_size)] offset array; may be nullptr if + * offset_list_size is 0 + * @param size_list_size [in] number of dynamic sizes + * @param size_list [in][range(0, size_list_size)] size array; may be nullptr if size_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, - tinytc_value_t *offset_list, - tinytc_value_t *size_list, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, int64_t *static_offset_list, + int64_t *static_size_list, uint32_t offset_list_size, tinytc_value_t *offset_list, + uint32_t size_list_size, tinytc_value_t *size_list, const tinytc_location_t *loc); /** * @brief Create store instruction diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 45e5a8f9..c2680ab6 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -758,23 +758,32 @@ inline inst make_axpby(transpose tA, bool atomic, value const &alpha, value cons * @brief Make expand instruction * * @param a Operand - * @param mode Expanded mode - * @param expand_shape New shape of mode + * @param expanded_mode Expanded mode + * @param static_expand_shape Static expand shape + * @param expand_shape Dynamic expand shape * @param loc Source code location * * @return Instruction */ -inline inst make_expand(value const &a, std::int64_t mode, std::vector const &expand_shape, - location const &loc = {}) { +inline inst make_expand(value const &a, std::int64_t expanded_mode, + std::vector const &static_expand_shape, + std::vector const &expand_shape, location const &loc = {}) { static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; + auto static_len = static_expand_shape.size(); + if (static_len > std::numeric_limits::max()) { + throw std::out_of_range("static expand shape too large"); + } auto len = expand_shape.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("expand shape too large"); } tinytc_value_t *eshape = const_cast(reinterpret_cast(expand_shape.data())); - CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a.get(), mode, len, eshape, &loc), loc); + CHECK_STATUS_LOC(tinytc_expand_inst_create( + &instr, a.get(), expanded_mode, static_len, + const_cast(static_expand_shape.data()), len, eshape, &loc), + loc); return inst(instr); } @@ -1029,28 +1038,48 @@ inline inst make_subgroup_size(compiler_context const &ctx, location const &loc * @brief Make subview instruction * * @param a Operand - * @param offset_list Vector of offsets - * @param size_list Vector of sizes; initialize with empty value if only offset is required + * @param static_offset_list Static offsets + * @param static_size_list Static sizes + * @param offset_list Vector of offsets; need to add dynamic offsets here if static_offset_list + * contains "dynamic" + * @param size_list Vector of sizes; need to add dynamic sizes here if static_size_list contains + * "dynamic" * @param loc Source code location * * @return Instruction */ -inline inst make_subview(value const &a, std::vector const &offset_list, - std::vector const &size_list, location const &loc = {}) { +inline inst make_subview(value const &a, std::vector const &static_offset_list, + std::vector const &static_size_list, + std::vector const &offset_list, std::vector const &size_list, + location const &loc = {}) { static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; - if (offset_list.size() != size_list.size()) { - throw std::invalid_argument("offset list must have the same length as the size list"); + if (static_offset_list.size() != static_size_list.size()) { + throw std::invalid_argument( + "static offset list must have the same length as the static size list"); } - auto len = offset_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); + auto static_len = static_offset_list.size(); + if (static_len > std::numeric_limits::max()) { + throw std::out_of_range("static slice list too long"); + } + auto offset_len = offset_list.size(); + if (offset_len > std::numeric_limits::max()) { + throw std::out_of_range("dynamic offset list too long"); + } + auto size_len = offset_list.size(); + if (size_len > std::numeric_limits::max()) { + throw std::out_of_range("dynamic size list too long"); } tinytc_value_t *ol = const_cast(reinterpret_cast(offset_list.data())); tinytc_value_t *sl = const_cast(reinterpret_cast(size_list.data())); - CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a.get(), len, ol, sl, &loc), loc); + CHECK_STATUS_LOC( + tinytc_subview_inst_create(&instr, a.get(), static_len, + const_cast(static_offset_list.data()), + const_cast(static_size_list.data()), offset_len, + ol, size_len, sl, &loc), + loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 422f30d7..154d5312 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -55,7 +55,7 @@ typedef enum { tinytc_status_ir_expected_vector_or_matrix = 0x10a, ///< Expected a vector or marix tinytc_status_ir_unexpected_yield = 0x10b, ///< Unexpected yield instruction tinytc_status_ir_yield_mismatch = 0x10c, ///< Wrong number of yielded values - tinytc_status_ir_multiple_dynamic_modes = 0x10d, ///< At most one mode must be dynamic + tinytc_status_ir_subview_mismatch = 0x10d, ///< Mismatch in subview tinytc_status_ir_invalid_slice = 0x10e, ///< Invalid slice tinytc_status_ir_expand_shape_order_too_small = 0x10f, ///< Expand shape too small tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 8092e86c..21f600d0 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -65,7 +65,7 @@ enum class status { ir_expected_vector_or_matrix = tinytc_status_ir_expected_vector_or_matrix, ir_unexpected_yield = tinytc_status_ir_unexpected_yield, ir_yield_mismatch = tinytc_status_ir_yield_mismatch, - ir_multiple_dynamic_modes = tinytc_status_ir_multiple_dynamic_modes, + ir_subview_mismatch = tinytc_status_ir_subview_mismatch, ir_invalid_slice = tinytc_status_ir_invalid_slice, ir_expand_shape_order_too_small = tinytc_status_ir_expand_shape_order_too_small, ir_expand_shape_mismatch = tinytc_status_ir_expand_shape_mismatch, diff --git a/src/error.cpp b/src/error.cpp index 4be98f2f..a8b91ee6 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -144,14 +144,15 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Yield encountered in non-yielding region"; case tinytc_status_ir_yield_mismatch: return "Number of yielded values does not match number of values yielded by region"; - case tinytc_status_ir_multiple_dynamic_modes: - return "At most one mode must be dynamic ('?')"; + case tinytc_status_ir_subview_mismatch: + return "Number of dynamic offsets and sizes must match number of dynamic operands"; case tinytc_status_ir_invalid_slice: - return "Offset must be non-negative and must not be '?'; size must be non-negative or '?'"; + return "Static offset and size must be non-negative or dynamic ('?')"; case tinytc_status_ir_expand_shape_order_too_small: return "Expand shape must have at least 2 entries"; case tinytc_status_ir_expand_shape_mismatch: - return "Product of expand shape must equal mode size"; + return "Number of dynamic expand shape operands must equal number of dynamic modes in " + "static expand shape"; case tinytc_status_ir_collective_called_from_spmd: return "Collective instruction must not be called from SPMD region"; case tinytc_status_ir_fp_unsupported: diff --git a/src/inst.cpp b/src/inst.cpp index ac458d51..bafa5e2b 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -170,21 +170,30 @@ tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_ }); } -tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - uint32_t expand_shape_size, tinytc_value_t *expand_shape, +tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + int64_t expanded_mode, uint32_t static_expand_shape_size, + int64_t *static_expand_shape, uint32_t expand_shape_size, + tinytc_value_t *expand_shape, const tinytc_location_t *loc) { - if (instr == nullptr || expand_shape == nullptr) { + if (instr == nullptr || static_expand_shape == nullptr || + (expand_shape_size > 0 && expand_shape == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto eshape_vec = std::vector(); - eshape_vec.reserve(expand_shape_size); + auto static_shape = std::vector{}; + static_shape.reserve(static_expand_shape_size); + for (uint32_t i = 0; i < static_expand_shape_size; ++i) { + static_shape.emplace_back(static_expand_shape[i]); + } + auto dynamic_shape = std::vector{}; + dynamic_shape.reserve(expand_shape_size); for (uint32_t i = 0; i < expand_shape_size; ++i) { - eshape_vec.emplace_back(value(expand_shape[i], true)); + dynamic_shape.emplace_back(value(expand_shape[i], true)); } - *instr = std::make_unique(value(a, true), mode, std::move(eshape_vec), - get_optional(loc)) - .release(); + *instr = + std::make_unique(value(a, true), expanded_mode, std::move(static_shape), + std::move(dynamic_shape), get_optional(loc)) + .release(); }); } @@ -356,25 +365,40 @@ tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, } tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, tinytc_value_t *offset_list, + uint32_t static_list_size, int64_t *static_offset_list, + int64_t *static_size_list, uint32_t offset_list_size, + tinytc_value_t *offset_list, uint32_t size_list_size, tinytc_value_t *size_list, const tinytc_location_t *loc) { if (instr == nullptr || - (slice_list_size > 0 && (offset_list == nullptr || size_list == nullptr))) { + (static_list_size > 0 && (static_offset_list == nullptr || static_size_list == nullptr)) || + (offset_list_size > 0 && offset_list == nullptr) || + (size_list_size > 0 && size_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { + auto static_offset_vec = + static_list_size > 0 + ? std::vector(static_offset_list, static_offset_list + static_list_size) + : std::vector{}; + auto static_size_vec = + static_list_size > 0 + ? std::vector(static_size_list, static_size_list + static_list_size) + : std::vector{}; auto offset_vec = std::vector(); auto size_vec = std::vector(); - offset_vec.reserve(slice_list_size); - size_vec.reserve(slice_list_size); - for (uint32_t i = 0; i < slice_list_size; ++i) { + offset_vec.reserve(offset_list_size); + size_vec.reserve(size_list_size); + for (uint32_t i = 0; i < offset_list_size; ++i) { offset_vec.emplace_back(value(offset_list[i], true)); + } + for (uint32_t i = 0; i < size_list_size; ++i) { size_vec.emplace_back(value(size_list[i], true)); } - *instr = - std::make_unique(value(a, true), offset_vec, size_vec, get_optional(loc)) - .release(); + *instr = std::make_unique(value(a, true), std::move(static_offset_vec), + std::move(static_size_vec), std::move(offset_vec), + std::move(size_vec), get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index cd99c1e5..b19190b2 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -14,6 +14,7 @@ #include +#include #include #include #include @@ -195,9 +196,28 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons result(0) = make_value(scalar_data_type::get(at->context(), scalar_type::i1)); } -expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const &expand_shape0, - location const &lc) - : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, mode_(mode) { +constant_inst::constant_inst(std::variant const &value, tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::constant}, value_(value) { + loc(lc); + + if (auto st = dyn_cast(ty); st) { + if ((is_floating_type(st->ty()) && std::holds_alternative(value_)) || + (!is_floating_type(st->ty()) && std::holds_alternative(value_))) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + } else { + throw compilation_error(loc(), status::ir_expected_scalar); + } + + result(0) = make_value(ty); +} + +expand_inst::expand_inst(value op0, std::int64_t expanded_mode, + std::vector static_expand_shape0, + std::vector const &expand_shape0, location const &lc) + : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, + expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { op(0) = std::move(op0); for (std::size_t i = 0; i < expand_shape0.size(); ++i) { op(1 + i) = expand_shape0[i]; @@ -205,83 +225,37 @@ expand_inst::expand_inst(value op0, std::int64_t mode, std::vector const loc(lc); auto m = get_memref_type(loc(), operand()); - bool const range_ok = 0 <= mode_ && mode_ < m->dim(); + bool const range_ok = 0 <= expanded_mode_ && expanded_mode_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - if (expand_shape().size() < 2) { + if (static_expand_shape_.size() < 2) { throw compilation_error(loc(), status::ir_expand_shape_order_too_small); } - - auto known_expand_shape = std::vector(); - known_expand_shape.reserve(expand_shape().size()); - std::size_t dyn_count = 0, non_imm_count = 0; - for (auto &s : expand_shape()) { - visit(overloaded{[&](int_imm &i) { - if (is_dynamic_value(i.value())) { - known_expand_shape.push_back(dynamic); - ++dyn_count; - return; - } - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_shape); - } - known_expand_shape.push_back(i.value()); - }, - [&](auto &) { - known_expand_shape.push_back(dynamic); - ++non_imm_count; - }}, - *s); - } - - if (dyn_count > 1) { - throw compilation_error(loc(), status::ir_multiple_dynamic_modes); - } - - auto size = m->shape(mode_); - if (!is_dynamic_value(size) && non_imm_count == 0) { - std::int64_t prod = 1; - std::int64_t dyn_mode = -1; - for (std::size_t i = 0; i < known_expand_shape.size(); ++i) { - auto const s = known_expand_shape[i]; - if (is_dynamic_value(s)) { - dyn_mode = i; - } else { - prod *= s; - } - } - if (dyn_mode >= 0) { - std::int64_t const s = size / prod; - known_expand_shape[dyn_mode] = s; - expand_shape()[dyn_mode] = - make_imm(s, scalar_data_type::get(m->context(), scalar_type::i64)); - prod *= s; - } - if (prod != size) { - throw compilation_error(loc(), status::ir_expand_shape_mismatch); - } + if (std::count(static_expand_shape_.begin(), static_expand_shape_.end(), dynamic) != + num_operands() - 1) { + throw compilation_error(loc(), status::ir_expand_shape_mismatch); } auto shape = std::vector{}; auto stride = std::vector{}; - shape.reserve(m->dim() + known_expand_shape.size() - 1); - stride.reserve(m->dim() + known_expand_shape.size() - 1); - for (std::int64_t i = 0; i < mode_; ++i) { + shape.reserve(m->dim() + static_expand_shape_.size() - 1); + stride.reserve(m->dim() + static_expand_shape_.size() - 1); + for (std::int64_t i = 0; i < expanded_mode_; ++i) { shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } - stride.push_back(m->stride(mode_)); - shape.push_back(known_expand_shape[0]); - for (std::size_t j = 1; j < known_expand_shape.size(); ++j) { + stride.push_back(m->stride(expanded_mode_)); + shape.push_back(static_expand_shape_[0]); + for (std::size_t j = 1; j < static_expand_shape_.size(); ++j) { stride.push_back(is_dynamic_value(stride.back()) || is_dynamic_value(shape.back()) ? dynamic : stride.back() * shape.back()); - shape.push_back(known_expand_shape[j]); + shape.push_back(static_expand_shape_[j]); } - for (std::int64_t i = mode_ + 1; i < m->dim(); ++i) { + for (std::int64_t i = expanded_mode_ + 1; i < m->dim(); ++i) { shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } @@ -504,68 +478,48 @@ size_inst::size_inst(value op0, std::int64_t mode, location const &lc) result(0) = make_value(scalar_data_type::get(op(0)->context(), scalar_type::index)); } -subview_inst::subview_inst(value op0, std::vector const &offset_list0, - std::vector const &size_list0, location const &lc) - : standard_inst{IK::subview, - static_cast(1 + offset_list0.size() + size_list0.size())} { +subview_inst::subview_inst(value op0, std::vector static_offsets0, + std::vector static_sizes0, + std::vector const &offsets0, std::vector const &sizes0, + location const &lc) + : standard_inst{IK::subview, static_cast(1 + offsets0.size() + sizes0.size())}, + static_offsets_(std::move(static_offsets0)), static_sizes_(std::move(static_sizes0)) { op(0) = std::move(op0); { std::size_t i = 1; - for (auto const &val : offset_list0) { + for (auto const &val : offsets0) { op(i++) = val; } - auto index_ty = scalar_data_type::get(op(0)->context(), scalar_type::index); - for (auto const &val : size_list0) { - op(i++) = val ? val : make_imm(0, index_ty); + num_dyn_offsets_ = i - 1; + for (auto const &val : sizes0) { + op(i++) = val; } } loc(lc); auto m = get_memref_type(loc(), operand()); - if (m->dim() != static_cast(offset_list0.size()) || - m->dim() != static_cast(size_list0.size())) { + if (m->dim() != static_cast(static_offsets_.size()) || + m->dim() != static_cast(static_sizes_.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } + if (std::count(static_offsets_.begin(), static_offsets_.end(), dynamic) != num_dyn_offsets_ || + std::count(static_sizes_.begin(), static_sizes_.end(), dynamic) != + num_operands() - num_dyn_offsets_ - 1) { + throw compilation_error(loc(), status::ir_subview_mismatch); + } auto shape = std::vector{}; auto stride = std::vector{}; shape.reserve(m->dim()); stride.reserve(m->dim()); for (std::int64_t i = 0; i < m->dim(); ++i) { - auto &offset = offset_list()[i]; - auto &size = size_list()[i]; - visit(overloaded{[&](int_imm &i) { - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *offset); - visit(overloaded{[&](int_imm &i) { - if (i.value() < 0 && !is_dynamic_value(i.value())) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *size); - auto size_value = visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return is_dynamic_value(m->shape(i)) - ? dynamic - : m->shape(i) - offset.value(); - } - return size.value(); - }, - [&](val &, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return dynamic; - } - return size.value(); - }, - [](auto &, auto &) -> std::int64_t { return dynamic; }}, - *offset, *size); - if (size_value > 0 || is_dynamic_value(size_value)) { - shape.push_back(size_value); + auto offset = static_offsets_[i]; + auto size = static_sizes_[i]; + if ((offset < 0 && !is_dynamic_value(offset)) || (size < 0 && !is_dynamic_value(size))) { + throw compilation_error(loc(), status::ir_invalid_slice); + } + if (size > 0 || is_dynamic_value(size)) { + shape.push_back(size); stride.push_back(m->stride(i)); } } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index bef8fe0e..d6d637e6 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace tinytc { @@ -35,6 +36,7 @@ enum class IK { barrier, cast, compare, + constant, expand, fuse, load, @@ -71,13 +73,14 @@ enum class IK { }; using inst_nodes = type_list; + class arith_unary_inst, class cast_inst, class compare_inst, class constant_inst, + class expand_inst, class fuse_inst, class load_inst, class group_id_inst, + class group_size_inst, class lifetime_stop_inst, class gemm_inst, class gemv_inst, + class ger_inst, class for_inst, class foreach_inst, class hadamard_inst, + class if_inst, class num_subgroups_inst, class parallel_inst, class size_inst, + class subview_inst, class store_inst, class subgroup_id_inst, + class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, + class yield_inst>; using value_range = iterator_range_wrapper; using const_value_range = iterator_range_wrapper; @@ -178,6 +181,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent case tinytc::IK::arith_unary: case tinytc::IK::cast: case tinytc::IK::compare: + case tinytc::IK::constant: case tinytc::IK::expand: case tinytc::IK::fuse: case tinytc::IK::load: @@ -428,20 +432,37 @@ class compare_inst : public standard_inst<2, 1> { cmp_condition cond_; }; +class constant_inst : public standard_inst<0, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::constant; } + constant_inst(std::variant const &value, tinytc_data_type_t ty, + location const &lc = {}); + + auto value() const -> std::variant const & { return value_; } + + private: + std::variant value_; +}; + class expand_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } - expand_inst(value op, std::int64_t mode, std::vector const &expand_shape, - location const &lc = {}); + expand_inst(value op, std::int64_t expanded_mode, std::vector static_expand_shape, + std::vector const &expand_shape, location const &lc = {}); + + inline std::int64_t expanded_mode() const { return expanded_mode_; } + inline auto static_expand_shape() const -> std::vector const & { + return static_expand_shape_; + } inline auto operand() const -> value const & { return op(0); } - inline std::int64_t mode() const { return mode_; } inline auto expand_shape() { return operands() | std::views::drop(1); } inline auto expand_shape() const { return operands() | std::views::drop(1); } inline auto expand_shape(std::int64_t i) const -> value const & { return op(i + 1); } private: - std::int64_t mode_; + std::int64_t expanded_mode_; + std::vector static_expand_shape_; }; class fuse_inst : public standard_inst<1, 1> { @@ -636,16 +657,24 @@ class subgroup_size_inst : public standard_inst<0, 1> { class subview_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } - subview_inst(value op, std::vector const &offset_list, - std::vector const &size_list, location const &lc = {}); + subview_inst(value op, std::vector static_offsets, + std::vector static_sizes, std::vector const &offsets, + std::vector const &sizes, location const &lc = {}); + + inline auto static_offsets() const -> std::vector const & { + return static_offsets_; + } + inline auto static_sizes() const -> std::vector const & { return static_sizes_; } inline auto operand() const -> value const & { return op(0); } - // We have num_operands() = 1 + 2 * num_indices() - inline auto num_indices() const { return (num_operands() - 1) / 2; } - inline auto offset_list() const { - return operands() | std::views::drop(1) | std::views::take(num_indices()); + inline auto offsets() const { + return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); } - inline auto size_list() const { return operands() | std::views::drop(1 + num_indices()); } + inline auto sizes() const { return operands() | std::views::drop(1 + num_dyn_offsets_); } + + private: + std::vector static_offsets_, static_sizes_; + std::int32_t num_dyn_offsets_; }; class store_inst : public standard_inst { diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 76037996..302d278f 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -136,6 +136,7 @@ lex: "alloca" { adv_loc(); return parser::make_ALLOCA(loc_); } "cast" { adv_loc(); return parser::make_CAST(loc_); } "cmp" { adv_loc(); return parser::make_CMP(loc_); } + "constant" { adv_loc(); return parser::make_CONSTANT(loc_); } "expand" { adv_loc(); return parser::make_EXPAND(loc_); } "fuse" { adv_loc(); return parser::make_FUSE(loc_); } "load" { adv_loc(); return parser::make_LOAD(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 646d770b..d687fb1e 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -17,6 +17,8 @@ namespace tinytc { class parse_context; class lexer; + + using int_or_val = std::variant; } } @@ -37,6 +39,7 @@ #include #include #include + #include #include namespace tinytc { @@ -118,6 +121,7 @@ ALLOCA "alloca" CAST "cast" CMP "cmp" + CONSTANT "constant" EXPAND "expand" FUSE "fuse" LOAD "load" @@ -171,9 +175,8 @@ %nterm instruction %nterm axpby_inst %nterm atomic -%nterm <::tinytc::value> identifier_or_constant -%nterm > optional_identifier_or_constant_list -%nterm > identifier_or_constant_list +%nterm > optional_value_list +%nterm > value_list %nterm barrier_inst %nterm optional_global_attr %nterm optional_local_attr @@ -201,14 +204,12 @@ %nterm arith_unary_inst %nterm cast_inst %nterm compare_inst +%nterm constant_inst %nterm expand_inst -%nterm <::tinytc::value> constant_or_dynamic_or_identifier -%nterm > expand_shape +%nterm integer_constant_or_identifier +%nterm > expand_shape %nterm fuse_inst %nterm load_inst -%nterm > optional_index_list -%nterm > index_list -%nterm <::tinytc::value> index_identifier_or_const %nterm group_id_inst %nterm group_size_inst %nterm num_subgroups_inst @@ -219,10 +220,10 @@ %nterm subgroup_size_inst %nterm store_inst %nterm subview_inst -%nterm , std::vector<::tinytc::value>>> optional_slice_list -%nterm , std::vector<::tinytc::value>>> slice_list -%nterm > slice -%nterm <::tinytc::value> slice_size +%nterm >> optional_slice_list +%nterm >> slice_list +%nterm > slice +%nterm slice_size %% prog: @@ -421,7 +422,7 @@ instruction: axpby_inst: AXPBY transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] + var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); @@ -445,30 +446,16 @@ atomic: | ATOMIC { $$ = true; } ; -identifier_or_constant: - var { $$ = $var; } - | INTEGER_CONSTANT { - auto i64_ty = get_scalar(ctx.cctx(), scalar_type::i64); - $$ = make_imm($INTEGER_CONSTANT, i64_ty); - $$->loc(@INTEGER_CONSTANT); - } - | FLOATING_CONSTANT { - auto f64_ty = get_scalar(ctx.cctx(), scalar_type::f64); - $$ = make_fimm($FLOATING_CONSTANT, f64_ty); - $$->loc(@FLOATING_CONSTANT); - } -; - -optional_identifier_or_constant_list: +optional_value_list: %empty {} - | identifier_or_constant_list { $$ = std::move($1); } + | value_list { $$ = std::move($1); } ; -identifier_or_constant_list: - identifier_or_constant { $$.push_back(std::move($identifier_or_constant)); } - | identifier_or_constant_list COMMA identifier_or_constant { +value_list: + var { $$.push_back(std::move($var)); } + | value_list COMMA var { $$ = std::move($1); - $$.push_back(std::move($identifier_or_constant)); + $$.push_back(std::move($var)); } ; @@ -498,7 +485,7 @@ optional_local_attr: gemm_inst: GEMM transpose[ta] transpose[tb] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] + var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); @@ -522,7 +509,7 @@ gemm_inst: gemv_inst: GEMV transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] + var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); @@ -550,7 +537,7 @@ transpose: ger_inst: GER atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] + var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); @@ -573,7 +560,7 @@ ger_inst: for_inst: FOR LOCAL_IDENTIFIER[loop_var] - EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] optional_step + EQUALS var[from] COMMA var[to] optional_step for_loop_var_type { check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); @@ -599,11 +586,11 @@ for_inst: optional_step: %empty { $$ = {}; } - | COMMA identifier_or_constant { $$ = $identifier_or_constant; } + | COMMA var { $$ = $var; } foreach_inst: FOREACH LOCAL_IDENTIFIER[loop_var] - EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] for_loop_var_type { + EQUALS var[from] COMMA var[to] for_loop_var_type { check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); auto v = make_value(get_scalar(ctx.cctx(), $for_loop_var_type)); @@ -652,7 +639,7 @@ identifier_list: hadamard_inst: HADAMARD atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] + var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); @@ -676,7 +663,7 @@ hadamard_inst: sum_inst: SUM transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] + var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); @@ -696,7 +683,7 @@ sum_inst: ; yield_inst: - YIELD optional_identifier_or_constant_list[vals] COLON optional_scalar_type_list[tys] { + YIELD optional_value_list[vals] COLON optional_scalar_type_list[tys] { if ($vals.size() != $tys.size()) { location loc = @vals; loc.end = @tys.end; @@ -719,6 +706,7 @@ valued_inst: | arith_unary_inst { $$ = std::move($1); } | cast_inst { $$ = std::move($1); } | compare_inst { $$ = std::move($1); } + | constant_inst { $$ = std::move($1); } | expand_inst { $$ = std::move($1); } | fuse_inst { $$ = std::move($1); } | group_id_inst { $$ = std::move($1); } @@ -747,7 +735,7 @@ alloca_inst: ; arith_inst: - ARITH ARITHMETIC identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { + ARITH ARITHMETIC var[a] COMMA var[b] COLON scalar_type[ty] { check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); try { @@ -763,7 +751,7 @@ arith_inst: ; arith_unary_inst: - ARITH ARITHMETIC_UNARY identifier_or_constant[a] COLON scalar_type[ty] { + ARITH ARITHMETIC_UNARY var[a] COLON scalar_type[ty] { check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); try { $$ = inst { @@ -780,7 +768,7 @@ arith_unary_inst: cast_inst: - CAST identifier_or_constant[a] COLON scalar_type[from] RETURNS scalar_type[to] { + CAST var[a] COLON scalar_type[from] RETURNS scalar_type[to] { check_scalar_type(ctx.cctx(), $a, $from, @a, @from); try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; @@ -792,7 +780,7 @@ cast_inst: ; compare_inst: - CMP CMP_CONDITION identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { + CMP CMP_CONDITION var[a] COMMA var[b] COLON scalar_type[ty] { check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); try { @@ -808,44 +796,79 @@ compare_inst: } ; +constant_inst: + CONSTANT FLOATING_CONSTANT RETURNS data_type { + try { + $$ = inst { + std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } + | CONSTANT INTEGER_CONSTANT RETURNS data_type { + try { + $$ = inst { + std::make_unique($INTEGER_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + expand_inst: - EXPAND var LSQBR INTEGER_CONSTANT[mode] RETURNS expand_shape RSQBR COLON memref_type { + EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] RETURNS expand_shape RSQBR COLON memref_type { if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); } try { + auto static_shape = std::vector{}; + static_shape.reserve($expand_shape.size()); + auto dynamic_shape = std::vector{}; + dynamic_shape.reserve($expand_shape.size()); + for (auto &s : $expand_shape) { + std::visit(overloaded{ + [&](std::int64_t i) { static_shape.push_back(i); }, + [&](value const &v) { + static_shape.push_back(dynamic); + dynamic_shape.push_back(v); + }, + }, s); + } $$ = inst { - std::make_unique(std::move($var), $mode, std::move($expand_shape), - @expand_inst) + std::make_unique(std::move($var), $expanded_mode, std::move(static_shape), + std::move(dynamic_shape), @expand_inst) .release() }; } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; + } catch (std::exception const& e) { + error(@expand_inst, e.what()); } } ; expand_shape: - constant_or_dynamic_or_identifier[a] TIMES constant_or_dynamic_or_identifier[b] { - $$ = std::vector{$a, $b}; + integer_constant_or_identifier[a] TIMES integer_constant_or_identifier[b] { + $$ = std::vector>{$a, $b}; } - | expand_shape TIMES constant_or_dynamic_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } + | expand_shape TIMES integer_constant_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } ; -constant_or_dynamic_or_identifier: +integer_constant_or_identifier: var { check_scalar_type(ctx.cctx(), $var, scalar_type::index, @var, @var); $$ = $var; } | INTEGER_CONSTANT { - auto index_ty = get_scalar(ctx.cctx(), scalar_type::index); - $$ = make_imm($INTEGER_CONSTANT, index_ty); - $$->loc(@INTEGER_CONSTANT); + $$ = $INTEGER_CONSTANT; } - | DYNAMIC { $$ = make_imm(dynamic, get_scalar(ctx.cctx(), scalar_type::i64)); $$->loc(@DYNAMIC); } ; fuse_inst: @@ -867,7 +890,7 @@ fuse_inst: ; load_inst: - LOAD var LSQBR optional_index_list RSQBR COLON memref_or_group_type { + LOAD var LSQBR optional_value_list RSQBR COLON memref_or_group_type { if ($var->ty() != $memref_or_group_type) { auto loc = @var; loc.end = @memref_or_group_type.end; @@ -875,7 +898,7 @@ load_inst: } try { $$ = inst { - std::make_unique(std::move($var), std::move($optional_index_list), + std::make_unique(std::move($var), std::move($optional_value_list), @load_inst) .release() }; @@ -886,29 +909,8 @@ load_inst: } ; -optional_index_list: - %empty {} - | index_list { $$ = std::move($1); } -; - -index_list: - index_identifier_or_const { $$.push_back($index_identifier_or_const); } - | index_list COMMA index_identifier_or_const { $$ = std::move($1); $$.push_back($index_identifier_or_const); } -; - -index_identifier_or_const: - var { - check_scalar_type(ctx.cctx(), $var, scalar_type::index, @var, @var); - $$ = $var; - } - | INTEGER_CONSTANT { - $$ = make_imm($INTEGER_CONSTANT, get_scalar(ctx.cctx(), scalar_type::index)); - $$->loc(@INTEGER_CONSTANT); - } -; - store_inst: - STORE var[a] COMMA var[b] LSQBR optional_index_list RSQBR COLON memref_type { + STORE var[a] COMMA var[b] LSQBR optional_value_list RSQBR COLON memref_type { if ($b->ty() != $memref_type) { auto loc = @b; loc.end = @memref_type.end; @@ -917,7 +919,7 @@ store_inst: try { $$ = inst { std::make_unique(std::move($a), std::move($b), - std::move($optional_index_list), @store_inst) + std::move($optional_value_list), @store_inst) .release() }; } catch (compilation_error const &e) { @@ -936,7 +938,7 @@ group_size_inst: ; if_inst: - IF identifier_or_constant[condition] optional_returned_values region else_region { + IF var[condition] optional_returned_values region else_region { check_scalar_type(ctx.cctx(), $condition, scalar_type::i1, @condition, @condition); $$ = inst{std::make_unique(std::move($condition), std::move($region), std::move($else_region), @@ -1016,15 +1018,42 @@ subview_inst: throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); } try { + auto static_offsets = std::vector{}; + auto static_sizes = std::vector{}; + auto offsets = std::vector{}; + auto sizes = std::vector{}; + static_offsets.reserve($optional_slice_list.size()); + static_sizes.reserve($optional_slice_list.size()); + offsets.reserve($optional_slice_list.size()); + sizes.reserve($optional_slice_list.size()); + for (auto &s : $optional_slice_list) { + std::visit(overloaded{ + [&](std::int64_t i) { static_offsets.push_back(i); }, + [&](value const &v) { + static_offsets.push_back(dynamic); + offsets.push_back(v); + }, + }, s.first); + std::visit(overloaded{ + [&](std::int64_t i) { static_sizes.push_back(i); }, + [&](value const &v) { + static_sizes.push_back(dynamic); + sizes.push_back(v); + }, + }, s.second); + } $$ = inst { - std::make_unique(std::move($var), $optional_slice_list.first, - $optional_slice_list.second, @subview_inst) + std::make_unique(std::move($var), std::move(static_offsets), std::move(static_sizes), + std::move(offsets), std::move(sizes), @subview_inst) .release() }; } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; + } catch (std::exception const& e) { + error(@subview_inst, e.what()); } + } ; @@ -1035,29 +1064,21 @@ optional_slice_list: slice_list: slice { - $$.first.emplace_back(std::move($slice.first)); - $$.second.emplace_back(std::move($slice.second)); + $$.emplace_back(std::move($slice)); } | slice_list COMMA slice { $$ = std::move($1); - $$.first.emplace_back(std::move($slice.first)); - $$.second.emplace_back(std::move($slice.second)); + $$.emplace_back(std::move($slice)); } ; slice: - COLON { - auto index_ty = get_scalar(ctx.cctx(), scalar_type::index); - auto i64_ty = get_scalar(ctx.cctx(), scalar_type::i64); - $$ = std::make_pair(make_imm(0, index_ty), make_imm(dynamic, i64_ty)); - } - | index_identifier_or_const slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } + integer_constant_or_identifier slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } ; slice_size: %empty { $$ = {}; } - | COLON index_identifier_or_const { $$ = $2; } - | COLON DYNAMIC { $$ = make_imm(dynamic, get_scalar(ctx.cctx(), scalar_type::i64)); } + | COLON integer_constant_or_identifier { $$ = $2; } ; %% diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index ab6f9e98..36c5b1e3 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -30,8 +30,7 @@ #include #include #include - -#include +#include namespace tinytc { @@ -417,11 +416,22 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c make(c.cond(), visit(*this, *c.a()), visit(*this, *c.b())))}; } +std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { + auto v = declare(*c.result()); + auto ty = get_scalar_type(*c.result()); + auto rhs = std::visit( + overloaded{[&](std::int64_t i) { return clir::expr(i, static_cast(size(ty) * 8)); }, + [&](double d) { return clir::expr(d, static_cast(size(ty) * 8)); }}, + c.value()); + return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), std::move(rhs))}; +} + std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { auto result_var = declare(*e.result()); auto m = get_memref_type(*e.operand()); auto &dv = get_dope_vector(e.operand().get()); - auto eshape = e.expand_shape(); + auto static_shape = e.static_expand_shape(); + auto dyn_shape = e.expand_shape(); auto rhs = visit(*this, *e.operand()); auto clinst = std::vector{}; @@ -430,48 +440,32 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) auto shape = std::vector{}; auto stride = std::vector{}; - shape.reserve(m->dim() + eshape.size() - 1); - stride.reserve(m->dim() + eshape.size() - 1); + shape.reserve(m->dim() + static_shape.size() - 1); + stride.reserve(m->dim() + static_shape.size() - 1); std::int64_t i = 0; - for (; i < e.mode(); ++i) { + for (; i < e.expanded_mode(); ++i) { shape.push_back(dv.shape(i)); stride.push_back(dv.stride(i)); } auto eshape_cl = std::vector{}; - eshape_cl.reserve(eshape.size()); - for (auto &s : eshape) { - eshape_cl.push_back(visit(*this, *s)); - } - - auto const get_shape = [&](std::size_t j) -> clir::expr { - auto is_dynamic = - visit(overloaded{[&](int_imm const &i) { return is_dynamic_value(i.value()); }, - [](auto const &) { return false; }}, - *eshape[j]); - if (is_dynamic) { - clir::expr prod = 1; - for (std::size_t k = 0; k < eshape_cl.size(); ++k) { - if (j != k) { - prod = prod * eshape_cl[k]; - } - } - auto inferred_size = clir::var("inferred_size"); - clinst.emplace_back(clir::declaration_assignment(to_clir_ty(scalar_type::index), - inferred_size, - std::move(prod) / dv.shape(e.mode()))); - return inferred_size; + eshape_cl.reserve(static_shape.size()); + int j = 0; + for (auto &s : static_shape) { + if (is_dynamic_value(s)) { + eshape_cl.emplace_back(visit(*this, *dyn_shape[j++])); + } else { + eshape_cl.emplace_back(clir::expr(s, static_cast(size(scalar_type::index) * 8))); } - return eshape_cl[j]; - }; + } - stride.push_back(m->stride(e.mode())); - shape.push_back(get_shape(0)); - for (std::size_t j = 1; j < eshape.size(); ++j) { + stride.push_back(m->stride(e.expanded_mode())); + shape.push_back(eshape_cl[0]); + for (std::size_t j = 1; j < eshape_cl.size(); ++j) { stride.push_back(stride.back() * shape.back()); - shape.push_back(get_shape(j)); + shape.push_back(eshape_cl[j]); } - for (i = e.mode() + 1; i < m->dim(); ++i) { + for (i = e.expanded_mode() + 1; i < m->dim(); ++i) { shape.push_back(dv.shape(i)); stride.push_back(dv.stride(i)); } @@ -918,10 +912,6 @@ std::vector convert_to_opencl_pass::operator()(subgroup_size_inst co std::vector convert_to_opencl_pass::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); auto t = get_memref_type(*s.operand()); - if (t->dim() != static_cast(s.num_indices())) { - throw compilation_error(s.loc(), status::ir_invalid_number_of_indices); - } - auto &dv = get_dope_vector(s.operand().get()); auto rhs = visit(*this, *s.operand()); @@ -930,25 +920,30 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s auto stride_out = std::vector{}; shape_out.reserve(t->dim()); stride_out.reserve(t->dim()); - for (std::int64_t i = 0; i < t->dim(); ++i) { - auto &offset = s.offset_list()[i]; - auto &size = s.size_list()[i]; - rhs = rhs + visit(*this, *offset) * dv.stride(j); - - auto size_value = - visit(overloaded{[&](int_imm &s) -> clir::expr { - if (s.value() == 0) { - return nullptr; - } else if (is_dynamic_value(s.value())) { - return dv.shape(j) - visit(*this, *offset); - } - return this->operator()(s); - }, - [&](value_node &s) -> clir::expr { return visit(*this, s); }}, - *size); - - if (size_value) { - shape_out.emplace_back(size_value); + auto dyn_offsets = s.offsets(); + auto dyn_sizes = s.sizes(); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < t->dim(); ++i) { + auto offset = s.static_offsets()[i]; + + auto offset_cl = clir::expr{}; + if (is_dynamic_value(offset)) { + offset_cl = visit(*this, *dyn_offsets[joffset++]); + } else { + offset_cl = + clir::expr(offset, static_cast(tinytc::size(scalar_type::index) * 8)); + } + rhs = rhs + offset_cl * dv.stride(j); + + auto size = s.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto size_cl = clir::expr{}; + if (is_dynamic_value(size)) { + size_cl = visit(*this, *dyn_sizes[jsize++]); + } else { + size_cl = + clir::expr(size, static_cast(tinytc::size(scalar_type::index) * 8)); + } + shape_out.emplace_back(size_cl); stride_out.emplace_back(dv.stride(j)); } diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 6de99374..f3b56d39 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -78,6 +78,7 @@ class convert_to_opencl_pass { std::vector operator()(arith_unary_inst const &a); std::vector operator()(cast_inst const &c); std::vector operator()(compare_inst const &c); + std::vector operator()(constant_inst const &c); std::vector operator()(expand_inst const &e); std::vector operator()(fuse_inst const &f); std::vector operator()(load_inst const &e); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index a9993d22..4012b365 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include namespace tinytc { @@ -167,14 +168,43 @@ void dump_ir_pass::operator()(compare_inst const &a) { visit(*this, *a.a()->ty()); } +void dump_ir_pass::operator()(constant_inst const &c) { + visit(*this, *c.result()); + *os_ << " = constant "; + std::visit(overloaded{[&](std::int64_t i) { + if (is_dynamic_value(i)) { + *os_ << "?"; + } else { + *os_ << i; + } + }, + [&](double d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << d; + os_->flags(flags); + }}, + c.value()); + *os_ << " -> "; + visit(*this, *c.result()->ty()); +} + void dump_ir_pass::operator()(expand_inst const &e) { visit(*this, *e.result()); *os_ << " = expand "; visit(*this, *e.operand()); - *os_ << "[" << e.mode() << "->"; - do_with_infix( - e.expand_shape().begin(), e.expand_shape().end(), - [this](auto const &i) { visit(*this, *i); }, "x"); + *os_ << "[" << e.expanded_mode() << "->"; + auto const &ses = e.static_expand_shape(); + auto es = e.expand_shape(); + for (std::size_t i = 0, j = 0; i < ses.size(); ++i) { + if (i != 0) { + *os_ << " x "; + } + if (is_dynamic_value(ses[i])) { + visit(*this, *es[j++]); + } else { + *os_ << ses[i]; + } + } *os_ << "] : "; visit(*this, *e.operand()->ty()); } @@ -317,15 +347,28 @@ void dump_ir_pass::operator()(subview_inst const &s) { *os_ << " = subview "; visit(*this, *s.operand()); *os_ << "["; - auto irange = std::ranges::iota_view{std::size_t{0}, s.offset_list().size()}; - do_with_infix(irange.begin(), irange.end(), [&](auto const &i) { - visit(*this, *s.offset_list()[i]); - auto &size = s.size_list()[i]; - if (size) { + auto dyn_offsets = s.offsets(); + auto dyn_sizes = s.sizes(); + for (std::size_t i = 0, joffset = 0, jsize = 0; i < s.static_offsets().size(); ++i) { + if (i != 0) { + *os_ << ","; + } + auto offset = s.static_offsets()[i]; + if (is_dynamic_value(offset)) { + visit(*this, *dyn_offsets[joffset++]); + } else { + *os_ << offset; + } + auto size = s.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { *os_ << ":"; - visit(*this, *size); + if (is_dynamic_value(size)) { + visit(*this, *dyn_sizes[jsize++]); + } else { + *os_ << size; + } } - }); + } *os_ << "]"; *os_ << " : "; visit(*this, *s.operand()->ty()); diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index 28c5f99c..a5a037c6 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -41,6 +41,7 @@ class dump_ir_pass { void operator()(barrier_inst const &b); void operator()(cast_inst const &c); void operator()(compare_inst const &c); + void operator()(constant_inst const &c); void operator()(expand_inst const &e); void operator()(fuse_inst const &f); void operator()(load_inst const &e); diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 311b2fc9..ef1ddf1d 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -84,8 +84,6 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const ty_ = get_scalar(ctx_, enum_cast(ty)); auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); - auto const index_ty = get_scalar(ctx_, scalar_type::index); - auto const i64_ty = get_scalar(ctx_, scalar_type::i64); auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { auto alpha = fb.argument(ty_, "alpha"); @@ -106,14 +104,16 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( fb.body( [&](region_builder &bb) { auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto offsets = std::vector{make_imm(0, index_ty, my_loc()), - make_imm(0, index_ty, my_loc()), gid}; - auto size = - std::vector{make_imm(dynamic, i64_ty, my_loc()), - make_imm(dynamic, i64_ty, my_loc()), value{}}; - auto a = bb.add(make_subview(A, offsets, size, my_loc())); - auto b = bb.add(make_subview(B, offsets, size, my_loc())); - auto c = bb.add(make_subview(C, offsets, size, my_loc())); + auto const static_offsets = std::vector{0, 0}; + auto const A_static_sizes = std::vector{M, K}; + auto const B_static_sizes = std::vector{K, N}; + auto const C_static_sizes = std::vector{M, N}; + auto a = bb.add( + make_subview(A, static_offsets, A_static_sizes, {}, {}, my_loc())); + auto b = bb.add( + make_subview(B, static_offsets, B_static_sizes, {}, {}, my_loc())); + auto c = bb.add( + make_subview(C, static_offsets, C_static_sizes, {}, {}, my_loc())); bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta, std::move(c), my_loc())); }, diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 58a88cc2..4623c9e1 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -93,7 +93,6 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( [&] { auto const ty_ = get_scalar(ctx_, enum_cast(ty)); auto const index_ty = get_scalar(ctx_, scalar_type::index); - auto const i64_ty = get_scalar(ctx_, scalar_type::i64); auto const shapes = std::vector{blas_shape{enum_cast(ty), {M_block_size, N}}}; @@ -107,24 +106,38 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, value &beta, value &C) { - auto const gemm = [&](region_builder &bb, std::vector const &offsets, - value const &block_size) { - auto a = bb.add(make_subview( - A, offsets, {block_size, make_imm(K, index_ty, my_loc())}, my_loc())); - auto c = bb.add(make_subview( - C, offsets, {block_size, make_imm(N, index_ty, my_loc())}, my_loc())); - bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, - my_loc())); - }; - auto const block_size_imm = make_imm(M_block_size, index_ty, my_loc()); auto gid = bb.add(make_group_id(ctx_, my_loc())); auto m = bb.add(make_arith(arithmetic::mul, gid, make_imm(M_block_size, index_ty, my_loc()), my_loc())); - auto const offsets = std::vector{m, make_imm(0, index_ty, my_loc())}; + + auto const static_offsets = std::vector{dynamic, 0}; + auto const offsets = std::vector{m}; + + auto const static_gemm = [&](region_builder &bb) { + auto const A_static_sizes = std::vector{M_block_size, K}; + auto const C_static_sizes = std::vector{M_block_size, N}; + auto a = bb.add( + make_subview(A, static_offsets, A_static_sizes, offsets, {}, my_loc())); + auto c = bb.add( + make_subview(C, static_offsets, C_static_sizes, offsets, {}, my_loc())); + bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, + my_loc())); + }; + auto const dynamic_gemm = [&](region_builder &bb, value const &dyn_block_size) { + auto const A_static_sizes = std::vector{dynamic, K}; + auto const C_static_sizes = std::vector{dynamic, N}; + auto const sizes = std::vector{dyn_block_size}; + auto a = bb.add( + make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); + auto c = bb.add( + make_subview(C, static_offsets, C_static_sizes, offsets, sizes, my_loc())); + bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, + my_loc())); + }; if (!is_dynamic_value(M) && M % M_block_size == 0) { - gemm(bb, offsets, block_size_imm); + static_gemm(bb); } else { auto M_val = is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) : make_imm(M, index_ty); @@ -132,11 +145,9 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, make_imm(M_block_size, index_ty, my_loc()), my_loc())); - auto const dynamic_imm = make_imm(dynamic, i64_ty, my_loc()); bb.ifelse( - cond, [&](region_builder &bb) { gemm(bb, offsets, dynamic_imm); }, - [&](region_builder &bb) { gemm(bb, offsets, block_size_imm); }, {}, - my_loc()); + cond, [&](region_builder &bb) { dynamic_gemm(bb, M_val_sub_m); }, + [&](region_builder &bb) { static_gemm(bb); }, {}, my_loc()); } }; diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index 3d7ebfe4..a876dec7 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -3,13 +3,15 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @axpby_atomic_store(%alpha: f64, %A: memref, %B: memref) { - axpby.n.atomic %alpha, %A, 0.0, %B : f64, memref, f64, memref + %zero = constant 0.0 -> f64 + axpby.n.atomic %alpha, %A, %zero, %B : f64, memref, f64, memref ; CHECK: global double* b = B + (blck + m) * 1; ; CHECK-NEXT: atomic_store_explicit((global volatile atomic_double*) b, alpha * A[(blck + m) * 1], memory_order_relaxed, memory_scope_work_group); } func @axpby_atomic_add(%alpha: f32, %A: memref, %B: memref) { - axpby.n.atomic %alpha, %A, 1.0, %B : f32, memref, f32, memref + %one = constant 1.0 -> f32 + axpby.n.atomic %alpha, %A, %one, %B : f32, memref, f32, memref ; CHECK: global float* b = Bb + (blck1 + m) * 1; ; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) b, alpha * Ab[(blck1 + m) * 1], memory_order_relaxed, memory_scope_work_group); } @@ -24,13 +26,15 @@ func @axpby_atomic_general(%alpha: f32, %A: memref, %B: memref) { } func @gemm_atomic(%A: memref, %B: memref, %C: memref) { - gemm.n.n.atomic 1.0, %A, %B, 1.0, %C + %one = constant 1.0 -> f32 + gemm.n.n.atomic %one, %A, %B, %one, %C : f32, memref, memref, f32, memref ; CHECK: atomic_fetch_add_explicit((global volatile atomic_float*) (Cb + get_sub_group_local_id()), c[n], memory_order_relaxed, memory_scope_work_group); } func @ger_atomic(%A: memref, %B: memref, %C: memref) { - ger.atomic 1.0, %A, %B, 1.0, %C + %one = constant 1.0 -> f32 + ger.atomic %one, %A, %B, %one, %C : f32, memref, memref, f32, memref ; CHECK: global float* c = Cb + (blck1 + m) * 1; ; CHECK-NEXT: float ab = A[(blck1 + m) * 1] * b; diff --git a/test/codegen/axpby0.ir b/test/codegen/axpby0.ir index 33019300..23ccfefc 100644 --- a/test/codegen/axpby0.ir +++ b/test/codegen/axpby0.ir @@ -3,6 +3,7 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @axpby(%alpha: f32, %A: memref, %B: memref) { - axpby.n %alpha, %A, 0.0, %B : f32, memref, f32, memref -; CHECK: 6.5-77: Incompatible tensor shapes + %zero = constant 0.0 -> f32 + axpby.n %alpha, %A, %zero, %B : f32, memref, f32, memref +; CHECK: 7.5-79: Incompatible tensor shapes } diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir index acecbcb8..4ff3517e 100644 --- a/test/codegen/axpby1.ir +++ b/test/codegen/axpby1.ir @@ -3,25 +3,32 @@ ; RUN: %tinytc-oc < %s func @axpby0(%alpha: f32, %A: memref, %B: memref) { - axpby.n %alpha, %A, 0.0, %B : f32, memref, f32, memref + %z = constant 0.0 -> f32 + axpby.n %alpha, %A, %z, %B : f32, memref, f32, memref } func @axpby1(%alpha: f64, %A: memref>, %B: memref) { - axpby.n %alpha, %A, 0.0, %B : f64, memref>, f32, memref + %z = constant 0.0 -> f32 + axpby.n %alpha, %A, %z, %B : f64, memref>, f32, memref } func @axpby2(%alpha: f32, %A: memref, %B: memref) { - axpby.n %alpha, %A, 0.0, %B : f32, memref, f32, memref + %z = constant 0.0 -> f32 + axpby.n %alpha, %A, %z, %B : f32, memref, f32, memref } func @axpby3(%alpha: f32, %A: memref, %B: memref) { - for %i=0,5 { - %A0 = subview %A[:,:,:,%i] : memref - %B0 = subview %B[:,:,:,%i] : memref - for %j=0,4 { - %A1 = subview %A0[:,:,%j] : memref - %B1 = subview %B0[:,:,%j] : memref - axpby.t %alpha, %A1, 0.0, %B1 : f32, memref, f32, memref + %z = constant 0.0 -> f32 + %lb = constant 0 -> index + %ub = constant 5 -> index + for %i=%lb,%ub { + %A0 = subview %A[0:48,0:48,0:4,%i] : memref + %B0 = subview %B[0:48,0:48,0:4,%i] : memref + %ub1 = constant 4 -> index + for %j=%lb,%ub1 { + %A1 = subview %A0[0:48,0:48,%j] : memref + %B1 = subview %B0[0:48,0:48,%j] : memref + axpby.t %alpha, %A1, %z, %B1 : f32, memref, f32, memref } } } diff --git a/test/codegen/dope_vector_group0.ir b/test/codegen/dope_vector_group0.ir index 83666ec9..acdbf875 100644 --- a/test/codegen/dope_vector_group0.ir +++ b/test/codegen/dope_vector_group0.ir @@ -4,15 +4,19 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @kernel1(%in: group>) { ; CHECK: void kernel1(global float*global* in, global long* in_shape1, global long* in_stride2) - %0 = load %in[5] : group> - ; CHECK-NEXT: global float* x0 = *(in + 5ll) + 0; - ; CHECK-NEXT: long x0_shape1 = in_shape1[5ll]; - ; CHECK-NEXT: long x0_stride2 = in_stride2[5ll]; + %c5 = constant 5 -> index + %0 = load %in[%c5] : group> + ; CHECK-NEXT: long c5 = 5ll; + ; CHECK-NEXT: global float* x0 = *(in + c5) + 0; + ; CHECK-NEXT: long x0_shape1 = in_shape1[c5]; + ; CHECK-NEXT: long x0_stride2 = in_stride2[c5]; } func @kernel2(%in: group, offset: ?>) { ; CHECK: void kernel2(global float*global* in, global long* in_shape0, long in_offset) - %0 = load %in[5] : group, offset: ?> - ; CHECK-NEXT: global float* x0 = *(in + 5ll) + in_offset; - ; CHECK-NEXT: long x0_shape0 = in_shape0[5ll]; + %c5 = constant 5 -> index + %0 = load %in[%c5] : group, offset: ?> + ; CHECK-NEXT: long c5 = 5ll; + ; CHECK-NEXT: global float* x0 = *(in + c5) + in_offset; + ; CHECK-NEXT: long x0_shape0 = in_shape0[c5]; } diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir index e4797079..217f5e97 100644 --- a/test/codegen/expand.ir +++ b/test/codegen/expand.ir @@ -3,112 +3,132 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @t1(%0: memref) { + %z = constant 0 -> index %1 = expand %0[1->2x8] : memref - %2 = load %1[0,0,0,0] : memref + %2 = load %1[%z,%z,%z,%z] : memref +; CHECK-LABEL: void t1( ; CHECK: global float* x1 = x0; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * 64 + 0ll * 512); +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 512); } func @t2(%0: memref) { - %1 = expand %0[1->2x?] : memref - %2 = load %1[0,0,0,0] : memref -; CHECK: global float* x1 = x0; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * 64 + 0ll * 512); + %z = constant 0 -> index + %1 = expand %0[1->2x2x2x2] : memref + %2 = load %1[%z,%z,%z,%z,%z,%z] : memref +; CHECK-LABEL: void t2( +; CHECK: global float* x1 = x0; +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 128 + z * 256 + z * 512); } -func @t3(%0: memref) { - %1 = expand %0[1->?x2x4] : memref - %2 = load %1[0,0,0,0,0] : memref -; CHECK: global float* x1 = x0; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * 64 + 0ll * 128 + 0ll * 512); +func @t3(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +; CHECK-LABEL: void t3( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape1 = x1; +; CHECK-NEXT: long x2_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); } -func @t4(%0: memref) { - %1 = expand %0[1->2x?] : memref - %2 = load %1[0,0,0,0] : memref -; CHECK: global float* x1 = x0; -; CHECK-NEXT: long inferred_size = 1 * 2ll / x0_shape1; -; CHECK-NEXT: long x1_shape2 = inferred_size; -; CHECK-NEXT: long x1_stride3 = x0_stride2; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * 64 + 0ll * x1_stride3); +func @t4(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->2 x %1] : memref + %3 = load %2[%z,%z,%z] : memref +; CHECK-LABEL: void t4( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape2 = x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * 64); } -func @t5(%0: memref) { - %1 = expand %0[1->?x4x8] : memref - %2 = load %1[0,0,0,0,0] : memref -; CHECK: global float* x1 = x0; -; CHECK-NEXT: long inferred_size = 1 * 4ll * 8ll / x0_shape1; -; CHECK-NEXT: long x1_shape1 = inferred_size; -; CHECK-NEXT: long x1_stride2 = 32 * inferred_size; -; CHECK-NEXT: long x1_stride3 = 32 * inferred_size * 4ll; -; CHECK-NEXT: long x1_stride4 = x0_stride2; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * x1_stride2 + 0ll * x1_stride3 + 0ll * x1_stride4); +func @t5(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +; CHECK-LABEL: void t5( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape1 = x1; +; CHECK-NEXT: long x2_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); } -func @t6(%0: memref) { - %1 = group_id - %2 = expand %0[1 -> %1 x ?] : memref - %3 = load %2[0,0,0] : memref -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long inferred_size = 1 * x1 / x0_shape1; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_shape2 = inferred_size; -; CHECK-NEXT: long x2_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + 0ll * 1 + 0ll * 32 + 0ll * x2_stride2); +func @t6(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +; CHECK-LABEL: void t6( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape1 = x1; +; CHECK-NEXT: long x2_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); } -func @t7(%0: memref) { - %1 = group_id - %2 = expand %0[1 -> %1 x ?] : memref - %3 = load %2[0,0,0] : memref -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long inferred_size = 1 * x1 / 16; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_shape2 = inferred_size; -; CHECK-NEXT: long x2_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + 0ll * 1 + 0ll * 32 + 0ll * x2_stride2); +func @t7(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2 x 2] : memref + %4 = load %3[%z,%z,%z,%z] : memref +; CHECK-LABEL: void t7( +; CHECK: global float* x3 = x0; +; CHECK-NEXT: long x3_shape1 = x1; +; CHECK-NEXT: long x3_shape2 = x2; +; CHECK-NEXT: long x3_stride2 = 32 * x1; +; CHECK-NEXT: long x3_stride3 = 32 * x1 * x2; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2 + z * x3_stride3); } -func @t8(%0: memref) { - %1 = group_id - %2 = arith.add %1, 5 : index - %3 = expand %0[1 -> %1 x %2] : memref - %4 = load %3[0,0,0] : memref -; CHECK: global float* x3 = x0; -; CHECK-NEXT: long x3_shape1 = x1; -; CHECK-NEXT: long x3_shape2 = x2; -; CHECK-NEXT: long x3_stride2 = 32 * x1; -; CHECK-NEXT: float x4 = *(x3 + 0ll * 1 + 0ll * 32 + 0ll * x3_stride2); +func @t8(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%2 x 2 x %1] : memref + %4 = load %3[%z,%z,%z,%z] : memref +; CHECK-LABEL: void t8( +; CHECK: global float* x3 = x0; +; CHECK-NEXT: long x3_shape1 = x2; +; CHECK-NEXT: long x3_stride2 = 32 * x2; +; CHECK-NEXT: long x3_shape3 = x1; +; CHECK-NEXT: long x3_stride3 = 32 * x2 * 2ll; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2 + z * x3_stride3); } -func @t9(%0: memref) { - %1 = group_id - %2 = expand %0[1 -> 4 x %1] : memref - %3 = load %2[0,0,0] : memref -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape2 = x1; -; CHECK-NEXT: float x3 = *(x2 + 0ll * 1 + 0ll * 32 + 0ll * 128); +func @t9(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2] : memref + %4 = load %3[%z,%z,%z] : memref +; CHECK-LABEL: void t9( +; CHECK: global float* x3 = x0; +; CHECK-NEXT: long x3_shape1 = x1; +; CHECK-NEXT: long x3_shape2 = x2; +; CHECK-NEXT: long x3_stride2 = 32 * x1; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2); } -func @t10(%0: memref>) { - %1 = expand %0[0->4x8] : memref> - %2 = load %1[0,0,0] : memref> -; CHECK: global float* x1 = x0; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 2 + 0ll * 8 + 0ll * 64); +func @t10(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2] : memref + %4 = load %3[%z,%z,%z] : memref +; CHECK-LABEL: void t10( +; CHECK: global float* x3 = x0; +; CHECK-NEXT: long x3_shape1 = x1; +; CHECK-NEXT: long x3_shape2 = x2; +; CHECK-NEXT: long x3_stride2 = 32 * x1; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2); } func @t11(%0: memref>) { - %1 = expand %0[0->4x?] : memref> - %2 = load %1[0,0,0] : memref> -; CHECK: global float* x1 = x0; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 2 + 0ll * 8 + 0ll * 64); + %z = constant 0 -> index + %1 = expand %0[0->4 x 8] : memref> + %2 = load %1[%z,%z,%z] : memref> +; CHECK-LABEL: void t11( +; CHECK: global float* x1 = x0; +; CHECK-NEXT: float x2 = *(x1 + z * 2 + z * 8 + z * 64); } -func @t12(%0: memref>) { - %1 = expand %0[0->?x4] : memref> - %2 = load %1[0,0,0] : memref> -; CHECK: global float* x1 = x0; -; CHECK-NEXT: long inferred_size = 1 * 4ll / x0_shape0; -; CHECK-NEXT: long x1_shape0 = inferred_size; -; CHECK-NEXT: long x1_stride1 = 2 * inferred_size; -; CHECK-NEXT: long x1_stride2 = x0_stride1; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 2 + 0ll * x1_stride1 + 0ll * x1_stride2); +func @t12(%0: memref>, %1: index) { + %z = constant 0 -> index + %2 = expand %0[0->%1 x 4] : memref> + %3 = load %2[%z,%z,%z] : memref> +; CHECK-LABEL: void t12( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape0 = x1; +; CHECK-NEXT: long x2_stride1 = 2 * x1; +; CHECK-NEXT: long x2_stride2 = x0_stride1; +; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * x2_stride1 + z * x2_stride2); } -func @t13(%0: memref>) { - %1 = expand %0[0->4x?] : memref> - %2 = load %1[0,0,0] : memref> -; CHECK: global float* x1 = x0; -; CHECK-NEXT: long inferred_size = 1 * 4ll / x0_shape0; -; CHECK-NEXT: long x1_shape1 = inferred_size; -; CHECK-NEXT: long x1_stride2 = x0_stride1; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 2 + 0ll * 8 + 0ll * x1_stride2); +func @t13(%0: memref>, %1: index) { + %z = constant 0 -> index + %2 = expand %0[0->4 x %1] : memref> + %3 = load %2[%z,%z,%z] : memref> +; CHECK-LABEL: void t13( +; CHECK: global float* x2 = x0; +; CHECK-NEXT: long x2_shape1 = x1; +; CHECK-NEXT: long x2_stride2 = x0_stride1; +; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * 8 + z * x2_stride2); } diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 3c39da7f..1fda98d9 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -3,10 +3,14 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @for1() { - for %0 = 0,10 { - ; CHECK: for (long x0 = 0ll; x0 < 10ll; ++x0) + %lb0 = constant 0 -> index + %ub0 = constant 10 -> index + for %0 = %lb0,%ub0 { + ; CHECK: for (long x0 = lb0; x0 < ub0; ++x0) } - for %1 = -2,2 : i16 { - ; CHECK: for (short x1 = -2; x1 < 2; ++x1) + %lb1 = constant -2 -> i16 + %ub1 = constant 2 -> i16 + for %1 = %lb1,%ub1 : i16 { + ; CHECK: for (short x1 = lb1; x1 < ub1; ++x1) } } diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir index 2d81e4a0..2873ccbc 100644 --- a/test/codegen/fuse.ir +++ b/test/codegen/fuse.ir @@ -3,26 +3,30 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @t1(%0: memref) { + %z = constant 0 -> index %1 = fuse %0[1,3] : memref - %2 = load %1[0,0,0] : memref -; CHECK: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * 16384); + %2 = load %1[%z,%z,%z] : memref +; CHECK: float x2 = *(x1 + z * 1 + z * 32 + z * 16384); } func @t2(%0: memref) { + %z = constant 0 -> index %1 = fuse %0[1,3] : memref - %2 = load %1[0,0,0] : memref> + %2 = load %1[%z,%z,%z] : memref> ; CHECK: long x1_shape1 = 16 * x0_shape2 * 4; ; CHECK-NEXT: long x1_stride2 = x0_stride4; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * 32 + 0ll * x1_stride2); +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x1_stride2); } func @t3(%0: memref>) { + %z = constant 0 -> index %1 = fuse %0[1,2] : memref> - %2 = load %1[0,0,0] : memref> -; CHECK: float x2 = *(x1 + 0ll * 1 + 0ll * 48 + 0ll * 1536); + %2 = load %1[%z,%z,%z] : memref> +; CHECK: float x2 = *(x1 + z * 1 + z * 48 + z * 1536); } func @t4(%0: memref>) { + %z = constant 0 -> index %1 = fuse %0[0,1] : memref> - %2 = load %1[0,0] : memref> + %2 = load %1[%z,%z] : memref> ; CHECK: long x1_shape0 = 8 * x0_shape1; ; CHECK-NEXT: long x1_stride1 = x0_stride2; -; CHECK-NEXT: float x2 = *(x1 + 0ll * 1 + 0ll * x1_stride1); +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * x1_stride1); } diff --git a/test/codegen/if.ir b/test/codegen/if.ir index b8c8728d..1ba73c9d 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -3,14 +3,16 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @if0(%0: i32) { - %1 = cmp.lt %0, 16 : i32 - %2 = cmp.ge %0, 0 : i32 + %c16 = constant 16 -> i32 + %c0 = constant 0 -> i32 + %1 = cmp.lt %0, %c16 : i32 + %2 = cmp.ge %0, %c0 : i32 %3 = arith.and %1, %2 : i1 if %3 { } else { } -; CHECK: bool x1 = x0 < 16; -; CHECK: bool x2 = x0 >= 0; +; CHECK: bool x1 = x0 < c16; +; CHECK: bool x2 = x0 >= c0; ; CHECK: bool x3 = x1 && x2; ; CHECK: if (x3) { ; CHECK-NEXT: } else { @@ -18,7 +20,8 @@ func @if0(%0: i32) { } func @if1(%0: i32) { - %1 = cmp.lt %0, 16 : i32 + %c16 = constant 16 -> i32 + %1 = cmp.lt %0, %c16 : i32 if %1 { } else { } @@ -28,7 +31,8 @@ func @if1(%0: i32) { } func @if2(%0: i32) { - %1 = cmp.lt %0, 16 : i32 + %c16 = constant 16 -> i32 + %1 = cmp.lt %0, %c16 : i32 if %1 -> () { yield : } else { @@ -40,49 +44,57 @@ func @if2(%0: i32) { } func @if3(%0: i32) { - %1 = cmp.lt %0, 16 : i32 + %c16 = constant 16 -> i32 + %1 = cmp.lt %0, %c16 : i32 %x = if %1 -> (i32) { yield %0 : i32 } else { - yield 16 : i32 + yield %c16 : i32 } ; CHECK: int x; ; CHECK-NEXT: if (x1) { ; CHECK-NEXT: x = x0; ; CHECK-NEXT: } else { -; CHECK-NEXT: x = 16; +; CHECK-NEXT: x = c16; ; CHECK-NEXT: } } func @if4(%0: i32) { - %1 = cmp.lt %0, 16 : i32 + %c16 = constant 16 -> i32 + %1 = cmp.lt %0, %c16 : i32 %x, %y = if %1 -> (i32, f32) { if %1 { } - yield %0, 1.0 : i32, f32 + %one = constant 1.0 -> f32 + yield %0, %one : i32, f32 } else { %z = if %1 -> (f32) { - yield 1.0 : f32 + %one = constant 1.0 -> f32 + yield %one : f32 } else { - yield 0.0 : f32 + %zero = constant 0.0 -> f32 + yield %zero : f32 } - yield 16, %z : i32, f32 + yield %c16, %z : i32, f32 } ; CHECK: int x; ; CHECK-NEXT: float y; ; CHECK-NEXT: if (x1) { ; CHECK-NEXT: if (x1) { ; CHECK-NEXT: } +; CHECK-NEXT: float one = 0x1p+0f; ; CHECK-NEXT: x = x0; -; CHECK-NEXT: y = 0x1p+0f; +; CHECK-NEXT: y = one; ; CHECK-NEXT: } else { ; CHECK-NEXT: float z; ; CHECK-NEXT: if (x1) { -; CHECK-NEXT: z = 0x1p+0f; +; CHECK-NEXT: float one = 0x1p+0f; +; CHECK-NEXT: z = one; ; CHECK-NEXT: } else { -; CHECK-NEXT: z = 0x0p+0f; +; CHECK-NEXT: float zero = 0x0p+0f; +; CHECK-NEXT: z = zero; ; CHECK-NEXT: } -; CHECK-NEXT: x = 16; +; CHECK-NEXT: x = c16; ; CHECK-NEXT: y = z; ; CHECK-NEXT: } } diff --git a/test/codegen/load.ir b/test/codegen/load.ir index df94f0f6..53288b52 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -3,13 +3,14 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @kernel1(%a: memref, %b: memref, %c: group>) { + %c5 = constant 5 -> index %0 = load %a[] : memref %1 = group_id - %2 = load %b[5, %1] : memref + %2 = load %b[%c5, %1] : memref %3 = load %c[%1] : group> ; CHECK: float x0 = *a; ; CHECK-NEXT: long x1 = get_global_id(2); - ; CHECK-NEXT: float x2 = *(b + 5ll * 1 + x1 * 10); + ; CHECK-NEXT: float x2 = *(b + c5 * 1 + x1 * 10); ; CHECK-NEXT: global float* x3 = *(c + x1) + 0; } diff --git a/test/codegen/store.ir b/test/codegen/store.ir index 22b4192c..cca632c5 100644 --- a/test/codegen/store.ir +++ b/test/codegen/store.ir @@ -3,9 +3,10 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @kernel(%a: memref, %b: memref, %c: f32) { + %c5 = constant 5 -> index %1 = group_id store %c, %a[] : memref - store %c, %b[5, %1] : memref + store %c, %b[%c5, %1] : memref ; CHECK: *a = c; - ; CHECK-NEXT: *(b + 5ll * 1 + x1 * 10) = c; + ; CHECK-NEXT: *(b + c5 * 1 + x1 * 10) = c; } diff --git a/test/codegen/subview_return_type.ir b/test/codegen/subview_return_type.ir index 35a2ab7f..a91ac1b4 100644 --- a/test/codegen/subview_return_type.ir +++ b/test/codegen/subview_return_type.ir @@ -3,59 +3,50 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @t1(%0: memref) { +; CHECK-LABEL: void t1( + %z = constant 0 -> index %1 = subview %0[4:8,8:4] : memref - %2 = load %1[0,0] : memref> + %2 = load %1[%z,%z] : memref> } func @t2(%0: memref, %1: index) { +; CHECK-LABEL: void t2( + %z = constant 0 -> index %2 = subview %0[2:4,%1] : memref - %3 = load %2[0] : memref + %3 = load %2[%z] : memref } func @t3(%0: memref, %1: index) { +; CHECK-LABEL: void t3( + %z = constant 0 -> index + %2 = subview %0[2:4,%1:0] : memref + %3 = load %2[%z] : memref +} +func @t4(%0: memref, %1: index) { +; CHECK-LABEL: void t4( + %z = constant 0 -> index %2 = subview %0[2:4,%1:1] : memref - %3 = load %2[0,0] : memref> -} -func @t4(%0: memref, %1: index) { - %2 = subview %0[0:%1] : memref - %3 = load %2[0] : memref -} -func @t5(%0: memref, %1: index) { - %2 = subview %0[2:4, 0:%1, 6:7] : memref - %3 = load %2[0,0,0] : memref> -} -func @t6(%0: memref>, %1: index) { - %2 = subview %0[2:4, 0:%1, 6:7] : memref> - %3 = load %2[0,0,0] : memref> -} -func @t7(%0: memref) { - %2 = subview %0[:] : memref -; CHECK: void t7(global float* x0) { -; CHECK-NEXT: global float* x2 = x0 + 0ll * 1; - %3 = load %2[0] : memref -} -func @t8(%0: memref) { - %2 = subview %0[:] : memref -; CHECK: void t8(global float* x0, long x0_shape0) { -; CHECK-NEXT: global float* x2 = x0 + 0ll * 1; -; CHECK-NEXT: long x2_shape0 = x0_shape0 - 0ll; - %3 = load %2[0] : memref -} -func @t9(%0: memref) { - %2 = subview %0[5:?] : memref -; CHECK: void t9(global float* x0) { -; CHECK-NEXT: global float* x2 = x0 + 5ll * 1; - %3 = load %2[0] : memref -} -func @t10(%0: memref) { - %2 = subview %0[5:?] : memref -; CHECK: void t10(global float* x0, long x0_shape0) { -; CHECK-NEXT: global float* x2 = x0 + 5ll * 1; -; CHECK-NEXT: long x2_shape0 = x0_shape0 - 5ll; - %3 = load %2[0] : memref -} -func @t11(%0: memref, %1: index) { - %2 = subview %0[%1:?] : memref -; CHECK: void t11(global float* x0, long x1) { -; CHECK-NEXT: global float* x2 = x0 + x1 * 1; -; CHECK-NEXT: long x2_shape0 = 16 - x1; - %3 = load %2[0] : memref + %3 = load %2[%z,%z] : memref> +} +func @t5(%0: memref, %1: index) { +; CHECK-LABEL: void t5( + %z = constant 0 -> index + %2 = subview %0[%1:4] : memref + %3 = load %2[%z] : memref +} +func @t6(%0: memref, %1: index) { +; CHECK-LABEL: void t6( + %z = constant 0 -> index + %2 = subview %0[%1:%1] : memref + %3 = load %2[%z] : memref +} +func @t7(%0: memref, %1: index) { +; CHECK-LABEL: void t7( + %z = constant 0 -> index + %2 = subview %0[2:4, %1:%1, 6:7] : memref + %3 = load %2[%z,%z,%z] : memref> +} +func @t8(%0: memref>, %1: index) { +; CHECK-LABEL: void t8( + %z = constant 0 -> index + %2 = subview %0[2:4, %1:%1, 6:7] : memref> + %3 = load %2[%z,%z,%z] : memref> } diff --git a/test/codegen/type_mismatch1.ir b/test/codegen/type_mismatch1.ir index 51a9612f..be9cdc2f 100644 --- a/test/codegen/type_mismatch1.ir +++ b/test/codegen/type_mismatch1.ir @@ -3,9 +3,10 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @kernel(%K0: memref, %x: index, %y: index) { + %z = constant 0 -> index %0 = subview %K0[0:%x] : memref %1 = subview %0[0:%y] : memref - %2 = load %1[0] : memref - %3 = load %1[0] : memref> -; CHECK: 9.13-44: Type of SSA value does not match operand type + %2 = load %1[%z] : memref + %3 = load %1[%z] : memref> +; CHECK: 10.13-45: Type of SSA value does not match operand type } diff --git a/test/opt/check-ir/nesting0.ir b/test/opt/check-ir/nesting0.ir index 3683b9a3..f2627b77 100644 --- a/test/opt/check-ir/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -2,9 +2,11 @@ ; SPDX-License-Identifier: BSD-3-Clause ; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s -func @illegal_nesting(%A: memref, %B: memref, %C: memref) { - foreach %i=1,16 { - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref +func @illegal_nesting(%c: f32, %A: memref, %B: memref, %C: memref) { + %lb = constant 1 -> index + %ub = constant 16 -> index + foreach %i=%lb,%ub { + gemm.n.n %c, %A, %B, %c, %C : f32, memref, memref, f32, memref } -; CHECK: 7.9-99: Collective instruction must not be called from SPMD region +; CHECK: 9.9-97: Collective instruction must not be called from SPMD region } diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index b6ec7dd9..ec252c5d 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -3,9 +3,11 @@ ; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { - foreach %i=1,16 { - foreach %j=1,16 { + %lb = constant 1 -> index + %ub = constant 16 -> index + foreach %i=%lb,%ub { + foreach %j=%lb,%ub { } -; CHECK: 7.9-8.9: Collective instruction must not be called from SPMD region +; CHECK: 9.9-10.9: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index c36375a1..4c896771 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -3,9 +3,11 @@ ; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { + %lb = constant 1 -> index + %ub = constant 16 -> index parallel { - foreach %j=1,16 { + foreach %j=%lb,%ub { } -; CHECK: 7.9-8.9: Collective instruction must not be called from SPMD region +; CHECK: 9.9-10.9: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 4a514e3b..31a1b74b 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -70,7 +70,7 @@ func @respect_manual_barrier(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref) { %B = alloca -> memref - %0 = subview %B[:,0:8] : memref + %0 = subview %B[0:8,0:8] : memref axpby.n %a, %B, %b, %C : f32, memref, f32, memref axpby.n %a, %A, %b, %0 : f32, memref, f32, memref ; CHECK-LABEL: func @war_alias({{.*}} @@ -80,7 +80,8 @@ func @war_alias(%a: f32, %b: f32, %A: memref, %C: memref) { } func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - %0 = cmp.gt %a, 42.0 : f32 + %c42 = constant 42.0 -> f32 + %0 = cmp.gt %a, %c42 : f32 if %0 { axpby.n %a, %A, %b, %B : f32, memref, f32, memref axpby.n %a, %B, %b, %C : f32, memref, f32, memref @@ -101,7 +102,8 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< } func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - %0 = cmp.gt %a, 42.0 : f32 + %c42 = constant 42.0 -> f32 + %0 = cmp.gt %a, %c42 : f32 axpby.n %a, %B, %b, %A : f32, memref, f32, memref if %0 { axpby.n %a, %A, %b, %B : f32, memref, f32, memref @@ -124,51 +126,58 @@ func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref } func @region1() { + %one = constant 1.0 -> f32 + %zero = constant 0.0 -> f32 %0 = alloca -> memref - for %i=0,4 : index { + %lb = constant 0 -> index + %ub = constant 4 -> index + for %i=%lb,%ub : index { %1 = alloca -> memref - for %k=0,4 : index { + for %k=%lb,%ub : index { %2 = alloca -> memref - gemm.n.n 1.0, %0, %1, 0.0, %2 + gemm.n.n %one, %0, %1, %zero, %2 : f32, memref, memref, f32, memref - axpby.n 1.0, %1, 0.0, %0 : f32, memref, f32, memref + axpby.n %one, %1, %zero, %0 : f32, memref, f32, memref } - axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + axpby.n %one, %0, %zero, %1 : f32, memref, f32, memref } ; CHECK-LABEL: func @region1({{.*}} -; CHECK: for %i=0,4 : index { +; CHECK: for %i=%lb,%ub : index { ; CHECK-NEXT: %1 = alloca -> memref -; CHECK-NEXT: for %k=0,4 : index { +; CHECK-NEXT: for %k=%lb,%ub : index { ; CHECK-NEXT: %2 = alloca -> memref ; CHECK-NEXT: barrier.local -; CHECK-NEXT: gemm.n.n 0x1p+0, %0, %1, 0x0p+0, %2{{.*}} +; CHECK-NEXT: gemm.n.n %one, %0, %1, %zero, %2{{.*}} ; CHECK-NEXT: barrier.local -; CHECK-NEXT: axpby.n 0x1p+0, %1, 0x0p+0, %0{{.*}} +; CHECK-NEXT: axpby.n %one, %1, %zero, %0{{.*}} ; CHECK-NEXT: } ; CHECK-NEXT: barrier.local -; CHECK-NEXT: axpby.n 0x1p+0, %0, 0x0p+0, %1{{.*}} +; CHECK-NEXT: axpby.n %one, %0, %zero, %1{{.*}} ; CHECK-NEXT: } } func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref) { + %c0 = constant 0 -> i32 + %c3 = constant 3 -> i32 + %c4 = constant 4 -> i32 parallel { %1 = subgroup_id - %2 = cmp.eq %1, 0 : i32 + %2 = cmp.eq %1, %c0 : i32 if %2 { - %3 = load %A[3,4] : memref - store %3, %A[3,4] : memref + %3 = load %A[%c3,%c4] : memref + store %3, %A[%c3,%c4] : memref } } - %0 = load %A[3,4] : memref + %0 = load %A[%c3,%c4] : memref ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { ; CHECK-NEXT: %1 = subgroup_id -; CHECK-NEXT: %2 = cmp.eq %1, 0 : i32 +; CHECK-NEXT: %2 = cmp.eq %1, %c0 : i32 ; CHECK-NEXT: if %2 { -; CHECK-NEXT: %3 = load %A[3,4] : memref -; CHECK-NEXT: store %3, %A[3,4] : memref +; CHECK-NEXT: %3 = load %A[%c3,%c4] : memref +; CHECK-NEXT: store %3, %A[%c3,%c4] : memref ; CHECK-NEXT: } ; CHECK-NEXT: } ; CHECK-NEXT: barrier.global -; CHECK-NEXT: %0 = load %A[3,4] : memref +; CHECK-NEXT: %0 = load %A[%c3,%c4] : memref } diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index dbceb9ed..7a196e7e 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -11,18 +11,20 @@ func @basic() { func @use1(%A: memref, %C: memref) { ; CHECK-LABEL: func @use1{{.*}} %B = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + %one = constant 1.0 -> f32 + gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } func @use2(%A: memref, %C: memref) { ; CHECK-LABEL: func @use2{{.*}} + %one = constant 1.0 -> f32 %B = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref %B2 = alloca -> memref - gemm.n.n 1.0, %A, %B, 0.0, %C : f32, memref, memref, f32, memref - gemm.n.n 1.0, %A, %B2, 0.0, %C : f32, memref, memref, f32, memref + gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref + gemm.n.n %one, %A, %B2, %one, %C : f32, memref, memref, f32, memref ; CHECK: %B2 = {{.*}} ; CHECK-NEXT: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B @@ -30,26 +32,29 @@ func @use2(%A: memref, %C: memref) { ; CHECK-NEXT: lifetime_stop %B2 } -func @use_alias(%A: memref, %C: memref) { +func @use_alias(%a: f32, %A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} %B = alloca -> memref %0 = fuse %B[1,3] : memref - %1 = subview %0[0:8,:] : memref - gemm.n.n 1.0, %A, %1, 0.0, %C : f32, memref, memref,local>, f32, memref + %1 = subview %0[0:8,0:8] : memref + gemm.n.n %a, %A, %1, %a, %C : f32, memref, memref,local>, f32, memref ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } func @region1() { ; CHECK-LABEL: func @region1{{.*}} + %one = constant 1.0 -> f32 %0 = alloca -> memref - for %i=0,4 : index { + %lb = constant 0 -> index + %ub = constant 4 -> index + for %i=%lb,%ub : index { %1 = alloca -> memref - for %k=0,4 : index { + for %k=%lb,%ub : index { %2 = alloca -> memref - gemm.n.n 1.0, %0, %1, 0.0, %2 + gemm.n.n %one, %0, %1, %one, %2 : f32, memref, memref, f32, memref - axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + axpby.n %one, %0, %one, %1 : f32, memref, f32, memref } } ; CHECK: gemm.n.n{{.*}} diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir index 04f0f81f..be976dc2 100644 --- a/test/opt/work-group-size.ir +++ b/test/opt/work-group-size.ir @@ -10,8 +10,12 @@ func @f32_blas() { ; CHECK: func @f32_blas() subgroup_size(32) work_group_size(128,2) { %0 = alloca -> memref %1 = alloca -> memref - for %i=0,4 { - axpby.n 1.0, %0, 0.0, %1 : f32, memref, f32, memref + %one = constant 1.0 -> f32 + %zero = constant 0.0 -> f32 + %lb = constant 0 -> index + %ub = constant 4 -> index + for %i=%lb,%ub { + axpby.n %one, %0, %zero, %1 : f32, memref, f32, memref } } @@ -20,8 +24,12 @@ func @f64_blas() { %0 = alloca -> memref %1 = alloca -> memref %2 = alloca -> memref - for %i=0,4 { - gemm.n.n 1.0, %0, %1, 0.0, %2 + %one = constant 1.0 -> f64 + %zero = constant 0.0 -> f64 + %lb = constant 0 -> index + %ub = constant 4 -> index + for %i=%lb,%ub { + gemm.n.n %one, %0, %1, %zero, %2 : f64, memref, memref, f64, memref } } From 88151cce2489182cb527f6cb61aaf0b52a209a2a Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 27 Sep 2024 09:31:03 +0200 Subject: [PATCH 029/297] Add complex constants; remove make_imm / create_imm functions Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.yaml | 5 +- docs/api/builder_cxxapi.yaml | 6 +- include/tinytc/tinytc.h | 71 +++++++++++------- include/tinytc/tinytc.hpp | 93 +++++++++++++++-------- src/codegen_tools.cpp | 118 +++++++++++++++--------------- src/inst.cpp | 34 +++++++++ src/node/inst_node.cpp | 24 +++++- src/node/inst_node.hpp | 10 ++- src/parser/parser_impl.yy | 13 +++- src/pass/convert_to_opencl.cpp | 15 +++- src/pass/dump_ir.cpp | 31 +++++--- src/recipe/small_gemm_batched.cpp | 3 +- src/recipe/tall_and_skinny.cpp | 22 +++--- src/value.cpp | 25 ------- 14 files changed, 290 insertions(+), 180 deletions(-) diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index de132721..c554591c 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -55,6 +55,9 @@ Builder C-API: - tinytc_arith_unary_inst_create - tinytc_cast_inst_create - tinytc_cmp_inst_create + - tinyc_constant_inst_create_complex + - tinyc_constant_inst_create_float + - tinyc_constant_inst_create_int - tinytc_expand_inst_create - tinytc_for_inst_create - tinytc_foreach_inst_create @@ -97,8 +100,6 @@ Builder C-API: - tinytc_region_destroy Value: function: - - tinytc_float_imm_create - - tinytc_int_imm_create - tinytc_value_create - tinytc_value_get_name - tinytc_value_set_name diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 6994f28b..98ebeeb7 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -50,6 +50,10 @@ Builder C++-API: - tinytc::make_arith(arithmetic_unary,value const&,location const&) - tinytc::make_cast - tinytc::make_cmp + - tinytc::make_constant(std::complex,tinytc_data_type_t,location const&) + - tinytc::make_constant(double,tinytc_data_type_t,location const&) + - tinytc::make_constant(std::int32_t,tinytc_data_type_t,location const&) + - tinytc::make_constant(std::int64_t,tinytc_data_type_t,location const&) - tinytc::make_expand - tinytc::make_for - tinytc::make_foreach @@ -88,8 +92,6 @@ Builder C++-API: - tinytc::region_builder Value: function: - - tinytc::make_fimm - - tinytc::make_imm - tinytc::make_value class: - tinytc::value diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 8bc0525d..e53b46bc 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -116,33 +116,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, TINYTC_EXPORT tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, const tinytc_location_t *loc); -/** - * @brief Create floating point immediate value - * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, - tinytc_data_type_t type, - const tinytc_location_t *loc); -/** - * @brief Create integer immediate value - * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, - tinytc_data_type_t type, - const tinytc_location_t *loc); - /** * @brief Release value object * @@ -269,6 +242,50 @@ TINYTC_EXPORT tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_value_t b, const tinytc_location_t *loc); +/** + * @brief Create complex constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value_re [in] constant value (real part) + * @param value_im [in] constant value (imaginary part) + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, + double value_re, double value_im, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create floating constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create integer constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + /** * @brief Create alloca instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index c2680ab6..68c555e3 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -479,36 +479,6 @@ inline auto make_value(tinytc_data_type_t ty, location const &loc = {}) -> value return value{val}; } -/** - * @brief Make immediate value - * - * @param imm Float value - * @param type Type of immediate value - * @param loc Source code location - * - * @return Value - */ -inline auto make_fimm(double imm, tinytc_data_type_t type, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_float_imm_create(&val, imm, type, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * @param imm Int value - * @param type Type of immediate value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int64_t imm, tinytc_data_type_t type, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, type, &loc), loc); - return value{val}; -} - //////////////////////////// /////////// Inst /////////// //////////////////////////// @@ -718,6 +688,69 @@ inline inst make_cmp(cmp_condition cond, value const &a, value const &b, locatio return inst(instr); } +/** + * @brief Make complex constant + * + * @param value_re Real part + * @param value_im Imaginary part + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::complex value, tinytc_data_type_t ty, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_constant_inst_create_complex(&instr, value.real(), value.imag(), ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make floating constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(double value, tinytc_data_type_t ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_float(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::int32_t value, tinytc_data_type_t ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::int64_t value, tinytc_data_type_t ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return inst(instr); +} + /** * @brief Make alloca instruction * diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 762be241..b06faf44 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -452,26 +452,25 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co std::int64_t blocks = loop_trip_count / sgs; std::int64_t rem = loop_trip_count % sgs; + auto c_sgs = bb.add(make_constant(sgs, index_ty)); + auto c_sgs_blocks = bb.add(make_constant(sgs * blocks, index_ty)); + auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); + auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); + auto c_rem = bb.add(make_constant(rem, index_ty)); + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); if (blocks > 0) { - auto block_start = - bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), sg_id_index)); - auto block_end = make_imm(sgs * blocks, index_ty); - auto step = make_imm(sgs * num_tiles, index_ty); + auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); bb.for_loop( - scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { - body(bb, block, false, make_imm(sgs, index_ty)); - }, + scalar_type::index, std::move(block_start), c_sgs_blocks, c_sgs_tiles, + [&](region_builder &bb, value const &block) { body(bb, block, false, c_sgs); }, "block"); } if (rem > 0) { - auto condition = - bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_imm(num_tiles - 1, index_ty))); - bb.if_condition(condition, [&](region_builder &bb) { - body(bb, make_imm(blocks * sgs, index_ty), true, make_imm(rem, index_ty)); - }); + auto condition = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); + bb.if_condition(condition, + [&](region_builder &bb) { body(bb, c_sgs_blocks, true, c_rem); }); } } @@ -479,26 +478,26 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_cou int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = get_scalar(bb.context(), scalar_type::index); - auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, make_imm(sgs, index_ty))); - auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, make_imm(sgs, index_ty))); + auto c_sgs = bb.add(make_constant(sgs, index_ty)); + auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); + auto c0 = bb.add(make_constant(0, index_ty)); + auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); + + auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, c_sgs)); + auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, c_sgs)); auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); - auto block_start = bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), sg_id_index)); - auto block_end = bb.add(make_arith(arithmetic::mul, make_imm(sgs, index_ty), blocks)); - auto step = make_imm(sgs * num_tiles, index_ty); + auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); + auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop( - scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { - body(bb, block, false, make_imm(sgs, index_ty)); - }, - "block"); + scalar_type::index, std::move(block_start), std::move(block_end), c_sgs_tiles, + [&](region_builder &bb, value const &block) { body(bb, block, false, c_sgs); }, "block"); - auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, make_imm(0, index_ty))); + auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { - auto condition1 = - bb.add(make_cmp(cmp_condition::eq, sg_id_index, make_imm(num_tiles - 1, index_ty))); + auto condition1 = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); bb.if_condition(condition1, [&](region_builder &bb) { - auto block = bb.add(make_arith(arithmetic::mul, blocks, make_imm(sgs, index_ty))); + auto block = bb.add(make_arith(arithmetic::mul, blocks, c_sgs)); body(bb, block, true, rem); }); }); @@ -534,48 +533,51 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip std::int64_t bs_1 = bs + 1; std::int64_t rem = loop_trip_count % blocks; + auto c_bs = bb.add(make_constant(bs, index_ty)); + auto c_bs_tiles = bb.add(make_constant(bs * num_tiles, index_ty)); + auto c_bs_1 = bb.add(make_constant(bs_1, index_ty)); + auto c_bs_1_rem = bb.add(make_constant(bs_1 * rem, index_ty)); + auto c_bs_1_tiles = bb.add(make_constant(bs_1 * num_tiles, index_ty)); + auto c_rem_mod_tiles = bb.add(make_constant(rem % num_tiles, index_ty)); + auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); + auto c_loop_trip_count = bb.add(make_constant(loop_trip_count, index_ty)); + auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); if (rem > 0) { - auto block_start = - bb.add(make_arith(arithmetic::mul, make_imm(bs_1, index_ty), sg_id_index)); - auto block_end = make_imm(bs_1 * rem, index_ty); - auto step = make_imm(bs_1 * num_tiles, index_ty); + auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); bb.for_loop( - scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { - body(bb, block, make_imm(bs_1, index_ty)); - }, - "block"); + scalar_type::index, std::move(block_start), c_bs_1_rem, c_bs_1_tiles, + [&](region_builder &bb, value const &block) { body(bb, block, c_bs_1); }, "block"); } - auto tmp = - bb.add(make_arith(arithmetic::add, sg_id_index, make_imm(rem % num_tiles, index_ty))); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, make_imm(num_tiles, index_ty))); - auto tmp2 = bb.add(make_arith(arithmetic::mul, make_imm(bs, index_ty), sg_id_1)); - auto block_start = bb.add(make_arith(arithmetic::add, make_imm(bs_1 * rem, index_ty), tmp2)); - auto block_end = make_imm(loop_trip_count, index_ty); - auto step = make_imm(bs * num_tiles, index_ty); + auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, c_tiles)); + auto tmp2 = bb.add(make_arith(arithmetic::mul, c_bs, sg_id_1)); + auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); bb.for_loop( - scalar_type::index, std::move(block_start), std::move(block_end), std::move(step), - [&](region_builder &bb, value const &block) { body(bb, block, make_imm(bs, index_ty)); }, - "block"); + scalar_type::index, std::move(block_start), c_loop_trip_count, c_bs_tiles, + [&](region_builder &bb, value const &block) { body(bb, block, c_bs); }, "block"); } void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = get_scalar(bb.context(), scalar_type::index); - auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, make_imm(1, index_ty))); - auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, make_imm(block_size, index_ty))); - auto blocks2 = bb.add(make_arith(arithmetic::add, make_imm(1, index_ty), blocks1)); - auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, make_imm(1, index_ty))); - auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, make_imm(num_tiles, index_ty))); - auto blocks5 = bb.add(make_arith(arithmetic::add, make_imm(1, index_ty), blocks4)); - auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, make_imm(num_tiles, index_ty))); + auto c1 = bb.add(make_constant(1, index_ty)); + auto c_block_size = bb.add(make_constant(block_size, index_ty)); + auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); + + auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, c1)); + auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, c_block_size)); + auto blocks2 = bb.add(make_arith(arithmetic::add, c1, blocks1)); + auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, c1)); + auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, c_tiles)); + auto blocks5 = bb.add(make_arith(arithmetic::add, c1, blocks4)); + auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, c_tiles)); blocks->name("blocks"); auto bs = bb.add(make_arith(arithmetic::div, loop_trip_count, blocks)); bs->name("bs"); - auto bs_1 = bb.add(make_arith(arithmetic::add, bs, make_imm(1, index_ty))); + auto bs_1 = bb.add(make_arith(arithmetic::add, bs, c1)); bs_1->name("bs_1"); auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); rem->name("rem"); @@ -583,18 +585,18 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); - auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, make_imm(num_tiles, index_ty))); + auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); bb.for_loop( scalar_type::index, std::move(block_start_1), std::move(block_end_1), std::move(step_1), [&](region_builder &bb, value const &block) { body(bb, block, bs_1); }, "block"); - auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, make_imm(num_tiles, index_ty))); + auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, make_imm(num_tiles, index_ty))); + auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, c_tiles)); auto tmp2 = bb.add(make_arith(arithmetic::mul, bs, sg_id_1)); auto tmp3 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); - auto step = bb.add(make_arith(arithmetic::mul, bs, make_imm(num_tiles, index_ty))); + auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); bb.for_loop( scalar_type::index, std::move(block_start), loop_trip_count, std::move(step), [&](region_builder &bb, value const &block) { body(bb, block, bs); }, "block"); diff --git a/src/inst.cpp b/src/inst.cpp index bafa5e2b..f1abf9ae 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -11,6 +11,7 @@ #include "tinytc/types.hpp" #include +#include #include #include #include @@ -146,6 +147,39 @@ tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_conditio }); } +tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, double value_re, + double value_im, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(std::complex(value_re, value_im), ty, + get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(value, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(value, ty, get_optional(loc)).release(); }); +} + tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index b19190b2..671208ea 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -196,14 +196,30 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons result(0) = make_value(scalar_data_type::get(at->context(), scalar_type::i1)); } -constant_inst::constant_inst(std::variant const &value, tinytc_data_type_t ty, - location const &lc) +constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc) : standard_inst{IK::constant}, value_(value) { loc(lc); if (auto st = dyn_cast(ty); st) { - if ((is_floating_type(st->ty()) && std::holds_alternative(value_)) || - (!is_floating_type(st->ty()) && std::holds_alternative(value_))) { + const auto type_ok = [](value_type const &val, scalar_type ty) { + switch (ty) { + case scalar_type::i1: + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return std::holds_alternative(val); + case scalar_type::f32: + case scalar_type::f64: + return std::holds_alternative(val); + case scalar_type::c32: + case scalar_type::c64: + return std::holds_alternative>(val); + } + return false; + }; + if (!type_ok(value_, st->ty())) { throw compilation_error(loc(), status::ir_scalar_mismatch); } } else { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index d6d637e6..e324b34b 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -13,6 +13,7 @@ #include "tinytc/types.hpp" #include +#include #include #include #include @@ -434,14 +435,15 @@ class compare_inst : public standard_inst<2, 1> { class constant_inst : public standard_inst<0, 1> { public: + using value_type = std::variant>; + inline static bool classof(inst_node const &i) { return i.type_id() == IK::constant; } - constant_inst(std::variant const &value, tinytc_data_type_t ty, - location const &lc = {}); + constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc = {}); - auto value() const -> std::variant const & { return value_; } + auto value() const -> value_type const & { return value_; } private: - std::variant value_; + value_type value_; }; class expand_inst : public standard_inst { diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index d687fb1e..3e6c39a5 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -797,7 +797,18 @@ compare_inst: ; constant_inst: - CONSTANT FLOATING_CONSTANT RETURNS data_type { + CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR RETURNS data_type { + try { + $$ = inst { + std::make_unique(std::complex{$re, $im}, $data_type, @constant_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } + | CONSTANT FLOATING_CONSTANT RETURNS data_type { try { $$ = inst { std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 36c5b1e3..4ea73a14 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -419,10 +419,17 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { auto v = declare(*c.result()); auto ty = get_scalar_type(*c.result()); - auto rhs = std::visit( - overloaded{[&](std::int64_t i) { return clir::expr(i, static_cast(size(ty) * 8)); }, - [&](double d) { return clir::expr(d, static_cast(size(ty) * 8)); }}, - c.value()); + auto ty_bits = static_cast(size(ty) * 8); + auto rhs = + std::visit(overloaded{ + [&](std::int64_t i) { return clir::expr(i, ty_bits); }, + [&](double d) { return clir::expr(d, ty_bits); }, + [&](std::complex d) { + return init_vector(to_clir_ty(ty), {clir::expr(d.real(), ty_bits), + clir::expr(d.imag(), ty_bits)}); + }, + }, + c.value()); return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), std::move(rhs))}; } diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 4012b365..5a900812 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -171,18 +171,25 @@ void dump_ir_pass::operator()(compare_inst const &a) { void dump_ir_pass::operator()(constant_inst const &c) { visit(*this, *c.result()); *os_ << " = constant "; - std::visit(overloaded{[&](std::int64_t i) { - if (is_dynamic_value(i)) { - *os_ << "?"; - } else { - *os_ << i; - } - }, - [&](double d) { - auto flags = os_->flags(); - *os_ << std::hexfloat << d; - os_->flags(flags); - }}, + std::visit(overloaded{ + [&](std::int64_t i) { + if (is_dynamic_value(i)) { + *os_ << "?"; + } else { + *os_ << i; + } + }, + [&](double d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << d; + os_->flags(flags); + }, + [&](std::complex d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << "[" << d.real() << "," << d.imag() << "]"; + os_->flags(flags); + }, + }, c.value()); *os_ << " -> "; visit(*this, *c.result()->ty()); diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index ef1ddf1d..8052cf5c 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -100,7 +100,6 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( {1, ldC, strideC}, address_space::global, my_loc()), "C", my_loc()); - auto beta = is_beta_nonzero ? std::move(beta_arg) : make_fimm(0.0, ty_, my_loc()); fb.body( [&](region_builder &bb) { auto gid = bb.add(make_group_id(ctx_, my_loc())); @@ -114,6 +113,8 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( make_subview(B, static_offsets, B_static_sizes, {}, {}, my_loc())); auto c = bb.add( make_subview(C, static_offsets, C_static_sizes, {}, {}, my_loc())); + auto beta = is_beta_nonzero ? std::move(beta_arg) + : bb.add(make_constant(0.0, ty_, my_loc())); bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta, std::move(c), my_loc())); }, diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 4623c9e1..a78a5fd5 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -104,12 +104,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, value &beta, - value &C) { - auto const block_size_imm = make_imm(M_block_size, index_ty, my_loc()); + auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, + bool is_beta_nonzero, value &beta_arg, value &C) { + auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto m = bb.add(make_arith(arithmetic::mul, gid, - make_imm(M_block_size, index_ty, my_loc()), my_loc())); + auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); + auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant(0.0, ty_, my_loc())); auto const static_offsets = std::vector{dynamic, 0}; auto const offsets = std::vector{m}; @@ -140,11 +140,10 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( static_gemm(bb); } else { auto M_val = is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) - : make_imm(M, index_ty); + : bb.add(make_constant(M, index_ty)); auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); auto cond = - bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, - make_imm(M_block_size, index_ty, my_loc()), my_loc())); + bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, my_loc())); bb.ifelse( cond, [&](region_builder &bb) { dynamic_gemm(bb, M_val_sub_m); }, [&](region_builder &bb) { static_gemm(bb); }, {}, my_loc()); @@ -167,8 +166,11 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const wgs = tiling.work_group_size(sgs); fb.work_group_size(wgs[0], wgs[1]); - auto beta = is_beta_nonzero ? beta_arg : make_fimm(0.0, ty_, my_loc()); - fb.body([&](region_builder &bb) { body(bb, alpha, A, B, beta, C); }, my_loc()); + fb.body( + [&](region_builder &bb) { + body(bb, alpha, A, B, is_beta_nonzero, beta_arg, C); + }, + my_loc()); }; auto p = [&] { diff --git a/src/value.cpp b/src/value.cpp index 3ed0ad4e..016311dd 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -16,22 +16,6 @@ using namespace tinytc; -namespace { -template -tinytc_status_t create_imm(tinytc_value_t *vl, T imm, tinytc_data_type_t type, - const tinytc_location_t *lc) { - if (vl == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *vl = std::make_unique(imm, type).release(); - if (lc) { - (*vl)->loc(*lc); - } - }); -} -} // namespace - extern "C" { tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, const tinytc_location_t *lc) { @@ -42,15 +26,6 @@ tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, [&] { *vl = std::make_unique(type, get_optional(lc)).release(); }); } -tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, tinytc_data_type_t type, - const tinytc_location_t *loc) { - return create_imm(vl, imm, type, loc); -} -tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, tinytc_data_type_t type, - const tinytc_location_t *loc) { - return create_imm(vl, imm, type, loc); -} - tinytc_status_t tinytc_value_release(tinytc_value_t obj) { if (obj == nullptr) { return tinytc_status_invalid_arguments; From 665cd2438c821976ca65a49e311f2e8aeb183035 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 27 Sep 2024 09:51:37 +0200 Subject: [PATCH 030/297] remove value class hierarchy Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 27 +------ src/node/function_node.hpp | 3 +- src/node/value_node.hpp | 78 +++---------------- src/parser/parser_impl.yy | 16 ++-- src/pass/convert_to_opencl.cpp | 136 +++++++++++++-------------------- src/pass/convert_to_opencl.hpp | 6 +- src/pass/dump_ir.cpp | 124 ++++++++++++++---------------- src/pass/dump_ir.hpp | 6 +- src/value.cpp | 2 +- 9 files changed, 133 insertions(+), 265 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index b06faf44..c1a1b201 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -432,17 +432,8 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { - visit(overloaded{ - [&](int_imm &c) { - tile_loop_by_sgs_new_constant(bb, c.value(), sgs, num_tiles, std::move(sg_id), - body); - }, - [&](auto &) { - tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, - std::move(sg_id), body); - }, - }, - *loop_trip_count); + tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, std::move(sg_id), + body); } void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, @@ -506,18 +497,8 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_cou void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { - visit( - overloaded{ - //[&](int_imm &c) { - // tile_loop_uniformly_new_constant(bb, c.value(), block_size, num_tiles, - // std::move(sg_id), body); - //}, - [&](auto &) { - tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, - num_tiles, std::move(sg_id), body); - }, - }, - *loop_trip_count); + tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, num_tiles, + std::move(sg_id), body); } void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index 02c5db38..6aa71007 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -25,8 +25,7 @@ struct tinytc_func final { inline tinytc_func(std::string name, std::vector args, tinytc_region_t body, tinytc_location const &lc = {}) : name_(std::move(name)), args_(std::move(args)), body_(tinytc::region{body}), - work_group_size_{0, 0}, subgroup_size_{0} { - loc(lc); + work_group_size_{0, 0}, subgroup_size_{0}, loc_{lc} { body_->kind(tinytc::region_kind::collective); } diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 4d75b1dc..0f54bdf4 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -4,101 +4,41 @@ #ifndef VALUE_NODE_20230309_HPP #define VALUE_NODE_20230309_HPP +#include "location.hpp" #include "node/data_type_node.hpp" #include "reference_counted.hpp" -#include "support/type_list.hpp" -#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include #include #include -namespace tinytc { -enum class VK { float_, int_, val }; -using value_nodes = type_list; -} // namespace tinytc - -struct tinytc_value : tinytc::reference_counted { +struct tinytc_value final : tinytc::reference_counted { public: - using leaves = tinytc::value_nodes; - - inline tinytc_value(tinytc::VK tid, tinytc_data_type_t ty) : tid_(tid), ty_(std::move(ty)) {} - virtual ~tinytc_value() = default; - inline auto type_id() const -> tinytc::VK { return tid_; } + inline tinytc_value(tinytc_data_type_t ty, tinytc::location const &lc = {}) + : ty_{std::move(ty)}, loc_{lc} {} inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } inline tinytc_data_type_t ty() const { return ty_; } - inline void ty(tinytc_data_type_t ty) { ty_ = std::move(ty); } inline auto context() const -> tinytc_compiler_context_t { return ty_->context(); } - virtual auto name() const -> char const * = 0; - virtual void name(std::string name) = 0; - virtual auto has_name() const -> bool = 0; + inline auto name() const -> char const * { return name_.c_str(); } + inline void name(std::string name) { name_ = std::move(name); } + auto has_name() const -> bool { return !name_.empty(); } private: - tinytc::VK tid_; tinytc_data_type_t ty_; tinytc::location loc_; + std::string name_; }; namespace tinytc { using value_node = ::tinytc_value; -class float_imm : public value_node { - public: - inline static bool classof(value_node const &v) { return v.type_id() == VK::float_; } - inline float_imm(double v, tinytc_data_type_t ty, location const &lc = {}) - : value_node(VK::float_, ty), value_(v) { - loc(lc); - } - - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } - - inline double value() const { return value_; } - - private: - double value_; -}; - -class int_imm : public value_node { - public: - inline static bool classof(value_node const &v) { return v.type_id() == VK::int_; } - inline int_imm(std::int64_t v, tinytc_data_type_t ty, location const &lc = {}) - : value_node(VK::int_, ty), value_(v) { - loc(lc); - } - - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } - - inline std::int64_t value() const { return value_; } - - private: - std::int64_t value_; -}; - -class val : public value_node { - public: - inline static bool classof(value_node const &v) { return v.type_id() == VK::val; } - inline val(tinytc_data_type_t ty, location const &lc = {}) : value_node(VK::val, ty) { - loc(lc); - } - - inline auto name() const -> char const * override { return name_.c_str(); } - inline void name(std::string name) override { name_ = std::move(name); } - virtual auto has_name() const -> bool override { return !name_.empty(); } - - private: - std::string name_; -}; - } // namespace tinytc #endif // VALUE_NODE_20230309_HPP diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 3e6c39a5..ad758e1a 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -45,17 +45,11 @@ namespace tinytc { void check_scalar_type(compiler_context const &ctx, value &val, scalar_type const &sty, location &loc1, location &loc2) { - visit(overloaded{[&](int_imm &i) { i.ty(get_scalar(ctx, sty)); }, - [&](float_imm &i) { i.ty(get_scalar(ctx, sty)); }, - [&](auto &) { - if (val->ty() != get_scalar(ctx, sty)) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error( - loc, "Type of SSA value does not match operand type"); - } - }}, - *val); + if (val->ty() != get_scalar(ctx, sty)) { + auto loc = loc1; + loc.end = loc2.end; + throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); + } } void check_type(value &val, tinytc_data_type_t ty, location &loc1, location &loc2) { if (val->ty() != ty) { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 4ea73a14..b8c48f94 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -189,15 +189,7 @@ clir::data_type convert_to_opencl_pass::operator()(scalar_data_type const &s) { } /* Value nodes */ -clir::expr convert_to_opencl_pass::operator()(float_imm const &v) { - auto ty = get_scalar_type(v); - return clir::expr(v.value(), static_cast(size(ty) * 8)); -} -clir::expr convert_to_opencl_pass::operator()(int_imm const &v) { - auto ty = get_scalar_type(v); - return clir::expr(v.value(), static_cast(size(ty) * 8)); -} -clir::expr convert_to_opencl_pass::operator()(val const &v) { +auto convert_to_opencl_pass::val(value_node const &v) -> clir::expr { uintptr_t u = std::bit_cast(&v); for (auto it = declared_vars_.rbegin(); it != declared_vars_.rend(); ++it) { if (auto j = it->find(u); j != it->end()) { @@ -243,8 +235,8 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins auto pA = inst.tA() == transpose::T && at->dim() == 2 ? 1 : 0; - auto alpha = visit(*this, *inst.alpha()); - auto beta = visit(*this, *inst.beta()); + auto alpha = val(*inst.alpha()); + auto beta = val(*inst.beta()); auto const inner_loop = [&](clir::block_builder &bb, clir::expr Ab, clir::expr Bb, clir::expr trip_count, std::size_t num_tiles, clir::var sg_id) { auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); @@ -270,8 +262,8 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins }); }; - auto A = visit(*this, *inst.A()); - auto B = visit(*this, *inst.B()); + auto A = val(*inst.A()); + auto B = val(*inst.B()); if (bt->dim() == 0) { auto bb = clir::block_builder{}; const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, A[0]); @@ -362,9 +354,8 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) }; auto sty = get_scalar_type(*a.a()); auto v = declare(*a.result()); - return {declaration_assignment( - visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), visit(*this, *a.a()), visit(*this, *a.b()), sty))}; + return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), + make(a.operation(), val(*a.a()), val(*a.b()), sty))}; } std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { @@ -383,13 +374,13 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons auto sty = get_scalar_type(*a.a()); auto v = declare(*a.result()); return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), visit(*this, *a.a()), sty))}; + make(a.operation(), val(*a.a()), sty))}; } std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { auto v = declare(*c.result()); auto result_ty = visit(*this, *c.result()->ty()); - auto cst = cast(result_ty, visit(*this, *c.a())); + auto cst = cast(result_ty, val(*c.a())); return {declaration_assignment(std::move(result_ty), std::move(v), std::move(cst))}; } @@ -413,7 +404,7 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c }; auto v = declare(*c.result()); return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), - make(c.cond(), visit(*this, *c.a()), visit(*this, *c.b())))}; + make(c.cond(), val(*c.a()), val(*c.b())))}; } std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { @@ -440,7 +431,7 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) auto static_shape = e.static_expand_shape(); auto dyn_shape = e.expand_shape(); - auto rhs = visit(*this, *e.operand()); + auto rhs = val(*e.operand()); auto clinst = std::vector{}; clinst.emplace_back( clir::declaration_assignment(this->operator()(*m), std::move(result_var), std::move(rhs))); @@ -460,7 +451,7 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) int j = 0; for (auto &s : static_shape) { if (is_dynamic_value(s)) { - eshape_cl.emplace_back(visit(*this, *dyn_shape[j++])); + eshape_cl.emplace_back(val(*dyn_shape[j++])); } else { eshape_cl.emplace_back(clir::expr(s, static_cast(size(scalar_type::index) * 8))); } @@ -491,7 +482,7 @@ std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { auto m = get_memref_type(*f.operand()); auto &dv = get_dope_vector(f.operand().get()); - auto rhs = visit(*this, *f.operand()); + auto rhs = val(*f.operand()); auto shape = std::vector{}; auto stride = std::vector{}; shape.reserve(m->dim()); @@ -528,7 +519,7 @@ std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { std::vector convert_to_opencl_pass::operator()(load_inst const &e) { auto op_val = e.operand(); - auto rhs = visit(*this, *op_val); + auto rhs = val(*op_val); auto clinst = std::vector{}; @@ -537,7 +528,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { throw compilation_error(e.loc(), status::ir_invalid_number_of_indices); } - auto idx = visit(*this, *e.index_list().front()); + auto idx = val(*e.index_list().front()); rhs = rhs + idx; auto &dv = get_dope_vector(e.operand().get()); @@ -560,7 +551,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { } auto &dv = get_dope_vector(e.operand().get()); for (std::int64_t i = 0; i < m.dim(); ++i) { - rhs = rhs + visit(*this, *e.index_list()[i]) * dv.stride(i); + rhs = rhs + val(*e.index_list()[i]) * dv.stride(i); } rhs = clir::dereference(std::move(rhs)); }, @@ -612,14 +603,6 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { auto const ak = g.tA() == transpose::T ? 0 : 1; auto const K = a->shape(ak); - auto const get_fixed = [](value const &v) { - return visit( - overloaded{[&](int_imm const &i) -> std::optional { return i.value(); }, - [&](float_imm const &i) -> std::optional { return i.value(); }, - [](auto const &) -> std::optional { return std::nullopt; }}, - *v); - }; - auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), get_scalar_type(*g.beta()), c->element_ty()}; auto cfg = gemm_configuration{std::move(gemm_ty), @@ -631,8 +614,8 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { {a->stride(0), a->stride(1)}, {b->stride(0), b->stride(1)}, {c->stride(0), c->stride(1)}, - get_fixed(g.alpha()), - get_fixed(g.beta()), + std::nullopt, + std::nullopt, g.atomic()}; auto name = cfg.identifier(); int name_counter = 0; @@ -647,10 +630,9 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { } has_gemm_.emplace(name); return {clir::expression_statement(clir::call( - std::move(name), - {cdv.shape(0), cdv.shape(1), adv.shape(ak), visit(*this, *g.alpha()), visit(*this, *g.A()), - adv.stride(0), adv.stride(1), visit(*this, *g.B()), bdv.stride(0), bdv.stride(1), - visit(*this, *g.beta()), visit(*this, *g.C()), cdv.stride(0), cdv.stride(1)}))}; + std::move(name), {cdv.shape(0), cdv.shape(1), adv.shape(ak), val(*g.alpha()), val(*g.A()), + adv.stride(0), adv.stride(1), val(*g.B()), bdv.stride(0), bdv.stride(1), + val(*g.beta()), val(*g.C()), cdv.stride(0), cdv.stride(1)}))}; } std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { @@ -664,14 +646,7 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { auto const M = c->shape(0); auto const ak = g.tA() == transpose::T ? 0 : 1; auto const K = a->shape(ak); - auto const N = 1; - auto const get_fixed = [](value const &v) { - return visit( - overloaded{[&](int_imm const &i) -> std::optional { return i.value(); }, - [&](float_imm const &i) -> std::optional { return i.value(); }, - [](auto const &) -> std::optional { return std::nullopt; }}, - *v); - }; + constexpr auto N = 1; auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), get_scalar_type(*g.beta()), c->element_ty()}; @@ -684,8 +659,8 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { {a->stride(0), a->stride(1)}, {b->stride(0), 0}, {c->stride(0), 0}, - get_fixed(g.alpha()), - get_fixed(g.beta()), + std::nullopt, + std::nullopt, g.atomic()}; auto name = cfg.identifier("gemv"); int name_counter = 0; @@ -700,10 +675,9 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { } has_gemm_.emplace(name); return {clir::expression_statement( - clir::call(std::move(name), - {cdv.shape(0), 1, adv.shape(ak), visit(*this, *g.alpha()), visit(*this, *g.A()), - adv.stride(0), adv.stride(1), visit(*this, *g.B()), bdv.stride(0), 0, - visit(*this, *g.beta()), visit(*this, *g.C()), cdv.stride(0), 0}))}; + clir::call(std::move(name), {cdv.shape(0), 1, adv.shape(ak), val(*g.alpha()), val(*g.A()), + adv.stride(0), adv.stride(1), val(*g.B()), bdv.stride(0), 0, + val(*g.beta()), val(*g.C()), cdv.stride(0), 0}))}; } std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { @@ -714,14 +688,14 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { auto &bdv = get_dope_vector(g.B().get()); auto &cdv = get_dope_vector(g.C().get()); - auto alpha = visit(*this, *g.alpha()); - auto beta = visit(*this, *g.beta()); + auto alpha = val(*g.alpha()); + auto beta = val(*g.beta()); auto alpha_ty = get_scalar_type(*g.alpha()); auto beta_ty = get_scalar_type(*g.beta()); - auto A = visit(*this, *g.A()); - auto B = visit(*this, *g.B()); - auto C = visit(*this, *g.C()); + auto A = val(*g.A()); + auto B = val(*g.B()); + auto C = val(*g.C()); auto bb = clir::block_builder{}; auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", @@ -780,9 +754,9 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { auto lv = declare(*p.loop_var()); auto lv_ty = visit(*this, *p.loop_var()->ty()); - auto start = clir::declaration_assignment(std::move(lv_ty), lv, visit(*this, *p.from())); - auto condition = lv < visit(*this, *p.to()); - auto step = p.step() ? clir::add_into(lv, visit(*this, *p.step())) : ++lv; + auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(*p.from())); + auto condition = lv < val(*p.to()); + auto step = p.step() ? clir::add_into(lv, val(*p.step())) : ++lv; auto body = run_on_region(p.body()); clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); @@ -794,8 +768,8 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { std::vector convert_to_opencl_pass::operator()(foreach_inst const &p) { auto lv = declare(*p.loop_var()); auto lv_ty = visit(*this, *p.loop_var()->ty()); - auto from = visit(*this, *p.from()); - auto to = visit(*this, *p.to()); + auto from = val(*p.from()); + auto to = val(*p.to()); auto bb = clir::block_builder{}; auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); @@ -817,14 +791,14 @@ std::vector convert_to_opencl_pass::operator()(hadamard_inst const & auto &bdv = get_dope_vector(g.B().get()); auto &cdv = get_dope_vector(g.C().get()); - auto alpha = visit(*this, *g.alpha()); - auto beta = visit(*this, *g.beta()); + auto alpha = val(*g.alpha()); + auto beta = val(*g.beta()); auto alpha_ty = get_scalar_type(*g.alpha()); auto beta_ty = get_scalar_type(*g.beta()); - auto A = visit(*this, *g.A()); - auto B = visit(*this, *g.B()); - auto C = visit(*this, *g.C()); + auto A = val(*g.A()); + auto B = val(*g.B()); + auto C = val(*g.C()); auto bb = clir::block_builder{}; auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); @@ -866,7 +840,7 @@ std::vector convert_to_opencl_pass::operator()(if_inst const &in) { clinst.emplace_back(clir::declaration(visit(*this, *r->ty()), v)); yielded_vars_.back().emplace_back(std::move(v)); } - auto ib = clir::if_selection_builder(visit(*this, *in.condition())); + auto ib = clir::if_selection_builder(val(*in.condition())); ib.set_then(run_on_region(in.then())); if (in.has_otherwise()) { ib.set_otherwise(run_on_region(in.otherwise())); @@ -921,7 +895,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s auto t = get_memref_type(*s.operand()); auto &dv = get_dope_vector(s.operand().get()); - auto rhs = visit(*this, *s.operand()); + auto rhs = val(*s.operand()); int j = 0; auto shape_out = std::vector{}; auto stride_out = std::vector{}; @@ -934,7 +908,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s auto offset_cl = clir::expr{}; if (is_dynamic_value(offset)) { - offset_cl = visit(*this, *dyn_offsets[joffset++]); + offset_cl = val(*dyn_offsets[joffset++]); } else { offset_cl = clir::expr(offset, static_cast(tinytc::size(scalar_type::index) * 8)); @@ -945,7 +919,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s if (size > 0 || is_dynamic_value(size)) { auto size_cl = clir::expr{}; if (is_dynamic_value(size)) { - size_cl = visit(*this, *dyn_sizes[jsize++]); + size_cl = val(*dyn_sizes[jsize++]); } else { size_cl = clir::expr(size, static_cast(tinytc::size(scalar_type::index) * 8)); @@ -978,13 +952,13 @@ std::vector convert_to_opencl_pass::operator()(store_inst const &s) throw compilation_error(s.loc(), status::ir_invalid_number_of_indices); } - auto lhs = visit(*this, *s.operand()); + auto lhs = val(*s.operand()); auto &dv = get_dope_vector(s.operand().get()); for (std::int64_t i = 0; i < ot->dim(); ++i) { - lhs = lhs + visit(*this, *s.index_list()[i]) * dv.stride(i); + lhs = lhs + val(*s.index_list()[i]) * dv.stride(i); } - auto rhs = visit(*this, *s.val()); + auto rhs = val(*s.val()); auto st = assignment(dereference(std::move(lhs)), std::move(rhs)); return {expression_statement(std::move(st))}; } @@ -995,15 +969,15 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) auto &adv = get_dope_vector(inst.A().get()); auto &bdv = get_dope_vector(inst.B().get()); - auto alpha = visit(*this, *inst.alpha()); - auto beta = visit(*this, *inst.beta()); + auto alpha = val(*inst.alpha()); + auto beta = val(*inst.beta()); auto alpha_ty = get_scalar_type(*inst.alpha()); auto beta_ty = get_scalar_type(*inst.beta()); auto zero = clir::expr(0.0, static_cast(size(at->element_ty()) * 8)); - auto A = visit(*this, *inst.A()); - auto B = visit(*this, *inst.B()); + auto A = val(*inst.A()); + auto B = val(*inst.B()); auto bb = clir::block_builder{}; auto acc = bb.declare_assign(to_clir_ty(at->element_ty()), "acc", std::move(zero)); auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); @@ -1083,8 +1057,8 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } std::vector clinst; for (std::int64_t i = 0; i < in.num_operands(); ++i) { - clinst.push_back(clir::expression_statement( - clir::assignment(yielded_vars_.back()[i], visit(*this, *in.op(i))))); + clinst.push_back( + clir::expression_statement(clir::assignment(yielded_vars_.back()[i], val(*in.op(i))))); } return clinst; } diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index f3b56d39..c71103cb 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -65,11 +65,6 @@ class convert_to_opencl_pass { clir::data_type operator()(memref_data_type const &m); clir::data_type operator()(scalar_data_type const &s); - /* Var nodes */ - clir::expr operator()(float_imm const &v); - clir::expr operator()(int_imm const &v); - clir::expr operator()(val const &v); - /* Inst nodes */ std::vector operator()(alloca_inst const &a); std::vector operator()(axpby_inst const &a); @@ -108,6 +103,7 @@ class convert_to_opencl_pass { private: auto run_on_region(region_node const ®) -> clir::stmt; auto run_on_function(function_node const &fn) -> clir::func; + auto val(value_node const &v) -> clir::expr; auto get_dope_vector(value_node *v) -> dope_vector &; void set_dope_vector(value_node *v, dope_vector dv); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 5a900812..5e15c703 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -47,19 +47,7 @@ void dump_ir_pass::operator()(memref_data_type const &d) { void dump_ir_pass::operator()(scalar_data_type const &s) { *os_ << to_string(s.ty()); } /* Value nodes */ -void dump_ir_pass::operator()(float_imm const &v) { - auto flags = os_->flags(); - *os_ << std::hexfloat << v.value(); - os_->flags(flags); -} -void dump_ir_pass::operator()(int_imm const &v) { - if (is_dynamic_value(v.value())) { - *os_ << "?"; - } else { - *os_ << v.value(); - } -} -void dump_ir_pass::operator()(val const &v) { +void dump_ir_pass::dump_val(value_node const &v) { *os_ << "%" << v.name(); auto const slot = tracker_.get_slot(v); if (slot >= 0) { @@ -69,13 +57,13 @@ void dump_ir_pass::operator()(val const &v) { /* Inst nodes */ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { - visit(*this, *g.alpha()); + dump_val(*g.alpha()); *os_ << ", "; - visit(*this, *g.A()); + dump_val(*g.A()); *os_ << ", "; - visit(*this, *g.beta()); + dump_val(*g.beta()); *os_ << ", "; - visit(*this, *g.B()); + dump_val(*g.B()); *os_ << " : "; visit(*this, *g.alpha()->ty()); *os_ << ", "; @@ -87,15 +75,15 @@ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { } void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { - visit(*this, *g.alpha()); + dump_val(*g.alpha()); *os_ << ", "; - visit(*this, *g.A()); + dump_val(*g.A()); *os_ << ", "; - visit(*this, *g.B()); + dump_val(*g.B()); *os_ << ", "; - visit(*this, *g.beta()); + dump_val(*g.beta()); *os_ << ", "; - visit(*this, *g.C()); + dump_val(*g.C()); *os_ << " : "; visit(*this, *g.alpha()->ty()); *os_ << ", "; @@ -109,7 +97,7 @@ void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { } void dump_ir_pass::operator()(alloca_inst const &a) { - visit(*this, *a.result()); + dump_val(*a.result()); *os_ << " = alloca -> "; visit(*this, *a.result()->ty()); } @@ -121,19 +109,19 @@ void dump_ir_pass::operator()(axpby_inst const &a) { } void dump_ir_pass::operator()(arith_inst const &a) { - visit(*this, *a.result()); + dump_val(*a.result()); *os_ << " = arith." << to_string(a.operation()) << " "; - visit(*this, *a.a()); + dump_val(*a.a()); *os_ << ", "; - visit(*this, *a.b()); + dump_val(*a.b()); *os_ << " : "; visit(*this, *a.a()->ty()); } void dump_ir_pass::operator()(arith_unary_inst const &a) { - visit(*this, *a.result()); + dump_val(*a.result()); *os_ << " = arith." << to_string(a.operation()) << " "; - visit(*this, *a.a()); + dump_val(*a.a()); *os_ << " : "; visit(*this, *a.a()->ty()); } @@ -149,9 +137,9 @@ void dump_ir_pass::operator()(barrier_inst const &b) { } void dump_ir_pass::operator()(cast_inst const &c) { - visit(*this, *c.result()); + dump_val(*c.result()); *os_ << " = cast "; - visit(*this, *c.a()); + dump_val(*c.a()); *os_ << " : "; visit(*this, *c.a()->ty()); *os_ << " -> "; @@ -159,17 +147,17 @@ void dump_ir_pass::operator()(cast_inst const &c) { } void dump_ir_pass::operator()(compare_inst const &a) { - visit(*this, *a.result()); + dump_val(*a.result()); *os_ << " = cmp." << to_string(a.cond()) << " "; - visit(*this, *a.a()); + dump_val(*a.a()); *os_ << ", "; - visit(*this, *a.b()); + dump_val(*a.b()); *os_ << " : "; visit(*this, *a.a()->ty()); } void dump_ir_pass::operator()(constant_inst const &c) { - visit(*this, *c.result()); + dump_val(*c.result()); *os_ << " = constant "; std::visit(overloaded{ [&](std::int64_t i) { @@ -196,9 +184,9 @@ void dump_ir_pass::operator()(constant_inst const &c) { } void dump_ir_pass::operator()(expand_inst const &e) { - visit(*this, *e.result()); + dump_val(*e.result()); *os_ << " = expand "; - visit(*this, *e.operand()); + dump_val(*e.operand()); *os_ << "[" << e.expanded_mode() << "->"; auto const &ses = e.static_expand_shape(); auto es = e.expand_shape(); @@ -207,7 +195,7 @@ void dump_ir_pass::operator()(expand_inst const &e) { *os_ << " x "; } if (is_dynamic_value(ses[i])) { - visit(*this, *es[j++]); + dump_val(*es[j++]); } else { *os_ << ses[i]; } @@ -217,38 +205,38 @@ void dump_ir_pass::operator()(expand_inst const &e) { } void dump_ir_pass::operator()(fuse_inst const &f) { - visit(*this, *f.result()); + dump_val(*f.result()); *os_ << " = fuse "; - visit(*this, *f.operand()); + dump_val(*f.operand()); *os_ << "[" << f.from() << "," << f.to() << "]"; *os_ << " : "; visit(*this, *f.operand()->ty()); } void dump_ir_pass::operator()(load_inst const &e) { - visit(*this, *e.result()); + dump_val(*e.result()); *os_ << " = load "; - visit(*this, *e.operand()); + dump_val(*e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { visit(*this, *i); }); + [this](auto const &i) { dump_val(*i); }); *os_ << "] : "; visit(*this, *e.operand()->ty()); } void dump_ir_pass::operator()(group_id_inst const &g) { - visit(*this, *g.result()); + dump_val(*g.result()); *os_ << " = group_id"; } void dump_ir_pass::operator()(group_size_inst const &g) { - visit(*this, *g.result()); + dump_val(*g.result()); *os_ << " = group_size"; } void dump_ir_pass::operator()(lifetime_stop_inst const &l) { *os_ << "lifetime_stop "; - visit(*this, *l.object()); + dump_val(*l.object()); } void dump_ir_pass::operator()(gemm_inst const &g) { @@ -271,14 +259,14 @@ void dump_ir_pass::operator()(ger_inst const &g) { void dump_ir_pass::operator()(for_inst const &p) { *os_ << "for "; - visit(*this, *p.loop_var()); + dump_val(*p.loop_var()); *os_ << "="; - visit(*this, *p.from()); + dump_val(*p.from()); *os_ << ","; - visit(*this, *p.to()); + dump_val(*p.to()); if (p.step()) { *os_ << ","; - visit(*this, *p.step()); + dump_val(*p.step()); } *os_ << " : "; visit(*this, *p.loop_var()->ty()); @@ -288,11 +276,11 @@ void dump_ir_pass::operator()(for_inst const &p) { void dump_ir_pass::operator()(foreach_inst const &p) { *os_ << "foreach "; - visit(*this, *p.loop_var()); + dump_val(*p.loop_var()); *os_ << "="; - visit(*this, *p.from()); + dump_val(*p.from()); *os_ << ","; - visit(*this, *p.to()); + dump_val(*p.to()); *os_ << " : "; visit(*this, *p.loop_var()->ty()); *os_ << " "; @@ -306,7 +294,7 @@ void dump_ir_pass::operator()(hadamard_inst const &g) { void dump_ir_pass::operator()(if_inst const &in) { *os_ << "if "; - visit(*this, *in.condition()); + dump_val(*in.condition()); *os_ << " "; dump_region(in.then()); if (in.has_otherwise()) { @@ -316,7 +304,7 @@ void dump_ir_pass::operator()(if_inst const &in) { } void dump_ir_pass::operator()(num_subgroups_inst const &sg) { - visit(*this, *sg.result()); + dump_val(*sg.result()); *os_ << " = num_subgroups"; } @@ -326,33 +314,33 @@ void dump_ir_pass::operator()(parallel_inst const &p) { } void dump_ir_pass::operator()(size_inst const &s) { - visit(*this, *s.result()); + dump_val(*s.result()); *os_ << " = size "; - visit(*this, *s.operand()); + dump_val(*s.operand()); *os_ << "[" << s.mode() << "]"; *os_ << " : "; visit(*this, *s.operand()->ty()); } void dump_ir_pass::operator()(subgroup_id_inst const &sg) { - visit(*this, *sg.result()); + dump_val(*sg.result()); *os_ << " = subgroup_id"; } void dump_ir_pass::operator()(subgroup_local_id_inst const &sg) { - visit(*this, *sg.result()); + dump_val(*sg.result()); *os_ << " = subgroup_local_id"; } void dump_ir_pass::operator()(subgroup_size_inst const &sg) { - visit(*this, *sg.result()); + dump_val(*sg.result()); *os_ << " = subgroup_size"; } void dump_ir_pass::operator()(subview_inst const &s) { - visit(*this, *s.result()); + dump_val(*s.result()); *os_ << " = subview "; - visit(*this, *s.operand()); + dump_val(*s.operand()); *os_ << "["; auto dyn_offsets = s.offsets(); auto dyn_sizes = s.sizes(); @@ -362,7 +350,7 @@ void dump_ir_pass::operator()(subview_inst const &s) { } auto offset = s.static_offsets()[i]; if (is_dynamic_value(offset)) { - visit(*this, *dyn_offsets[joffset++]); + dump_val(*dyn_offsets[joffset++]); } else { *os_ << offset; } @@ -370,7 +358,7 @@ void dump_ir_pass::operator()(subview_inst const &s) { if (size > 0 || is_dynamic_value(size)) { *os_ << ":"; if (is_dynamic_value(size)) { - visit(*this, *dyn_sizes[jsize++]); + dump_val(*dyn_sizes[jsize++]); } else { *os_ << size; } @@ -385,12 +373,12 @@ void dump_ir_pass::operator()(subview_inst const &s) { void dump_ir_pass::operator()(store_inst const &e) { *os_ << "store "; - visit(*this, *e.val()); + dump_val(*e.val()); *os_ << ", "; - visit(*this, *e.operand()); + dump_val(*e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { visit(*this, *i); }); + [this](auto const &i) { dump_val(*i); }); *os_ << "] : "; visit(*this, *e.operand()->ty()); } @@ -404,7 +392,7 @@ void dump_ir_pass::operator()(sum_inst const &a) { void dump_ir_pass::operator()(yield_inst const &y) { *os_ << "yield "; if (y.num_operands() > 0) { - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i); }, ", "); + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { dump_val(*i); }, ", "); *os_ << " : "; do_with_infix( y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }, ", "); @@ -437,7 +425,7 @@ void dump_ir_pass::run_on_function(function_node const &fn) { do_with_infix( fn.args().begin(), fn.args().end(), [this](auto const &a) { - visit(*this, *a); + dump_val(*a); *os_ << ": "; visit(*this, *a->ty()); }, diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index a5a037c6..a6f070a9 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -28,11 +28,6 @@ class dump_ir_pass { void operator()(memref_data_type const &m); void operator()(scalar_data_type const &s); - /* Var nodes */ - void operator()(float_imm const &v); - void operator()(int_imm const &v); - void operator()(val const &v); - /* Inst nodes */ void operator()(alloca_inst const &a); void operator()(axpby_inst const &a); @@ -74,6 +69,7 @@ class dump_ir_pass { void dump_region(region_node const ®); void dump_blas_a2(blas_a2_inst const &g); void dump_blas_a3(blas_a3_inst const &g); + void dump_val(value_node const &v); template void do_with_infix(Iterator begin, Iterator end, Action a, std::string const &infix = ",") { diff --git a/src/value.cpp b/src/value.cpp index 016311dd..251bfca9 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -23,7 +23,7 @@ tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *vl = std::make_unique(type, get_optional(lc)).release(); }); + [&] { *vl = std::make_unique(type, get_optional(lc)).release(); }); } tinytc_status_t tinytc_value_release(tinytc_value_t obj) { From 2cd2276df0416b20103dc74d7423d7d5a70c2e22 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 30 Sep 2024 14:41:41 +0200 Subject: [PATCH 031/297] Move region to stack of tinytc_inst Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 117 ++++--- docs/api/builder_capi.yaml | 22 +- docs/api/builder_cxxapi.rst | 128 ++++--- docs/api/builder_cxxapi.yaml | 16 +- docs/api/core_cxxapi.rst | 7 + docs/api/core_cxxapi.yaml | 1 + include/tinytc/tinytc.h | 135 +++++--- include/tinytc/tinytc.hpp | 556 +++++++++++++++--------------- src/analysis/cfg.cpp | 27 +- src/codegen_tools.cpp | 41 ++- src/codegen_tools.hpp | 4 +- src/func.cpp | 33 +- src/inst.cpp | 76 ++-- src/node/function_node.hpp | 39 +-- src/node/inst_node.cpp | 57 +-- src/node/inst_node.hpp | 78 ++--- src/node/program_node.cpp | 6 - src/node/program_node.hpp | 33 +- src/node/region_node.cpp | 25 +- src/node/region_node.hpp | 31 +- src/parser/parse_context.cpp | 10 + src/parser/parse_context.hpp | 16 +- src/parser/parser_impl.yy | 201 ++++++----- src/pass/check_ir.cpp | 2 +- src/pass/convert_to_opencl.cpp | 12 +- src/pass/dump_cfg.cpp | 2 +- src/pass/dump_cfg.hpp | 2 +- src/pass/dump_ir.cpp | 4 +- src/pass/insert_lifetime_stop.cpp | 2 +- src/pass/slot_tracker.cpp | 2 +- src/passes.hpp | 8 +- src/prog.cpp | 6 +- src/recipe/small_gemm_batched.cpp | 89 ++--- src/recipe/tall_and_skinny.cpp | 63 ++-- src/region.cpp | 38 +- src/scalar_type.cpp | 15 + src/scalar_type.hpp | 1 + src/support/ilist.hpp | 7 +- src/support/ilist_base.hpp | 61 ++-- src/support/util.hpp | 95 ++++- src/support/walk.cpp | 2 +- src/support/walk.hpp | 8 +- src/value.cpp | 7 + test/codegen/if.ir | 2 - test/opt/check-ir/nesting1.ir | 2 +- test/opt/check-ir/nesting3.ir | 2 +- 46 files changed, 1203 insertions(+), 888 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 45f5ef17..c2375961 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -270,31 +270,38 @@ Function * Functions - * :ref:`tinytc_function_create` + * :ref:`tinytc_func_create` - * :ref:`tinytc_function_set_subgroup_size` + * :ref:`tinytc_func_set_subgroup_size` - * :ref:`tinytc_function_set_work_group_size` + * :ref:`tinytc_func_set_work_group_size` + + * :ref:`tinytc_func_get_body` * :ref:`tinytc_func_destroy` Function Functions ------------------ -tinytc_function_create -...................... +tinytc_func_create +.................. -.. doxygenfunction:: tinytc_function_create +.. doxygenfunction:: tinytc_func_create -tinytc_function_set_subgroup_size -................................. +tinytc_func_set_subgroup_size +............................. -.. doxygenfunction:: tinytc_function_set_subgroup_size +.. doxygenfunction:: tinytc_func_set_subgroup_size -tinytc_function_set_work_group_size -................................... +tinytc_func_set_work_group_size +............................... -.. doxygenfunction:: tinytc_function_set_work_group_size +.. doxygenfunction:: tinytc_func_set_work_group_size + +tinytc_func_get_body +.................... + +.. doxygenfunction:: tinytc_func_get_body tinytc_func_destroy ................... @@ -318,6 +325,12 @@ Instruction * :ref:`tinytc_cmp_inst_create` + * :ref:`tinytc_constant_inst_create_complex` + + * :ref:`tinytc_constant_inst_create_float` + + * :ref:`tinytc_constant_inst_create_int` + * :ref:`tinytc_expand_inst_create` * :ref:`tinytc_for_inst_create` @@ -362,6 +375,10 @@ Instruction * :ref:`tinytc_yield_inst_create` + * :ref:`tinytc_inst_get_region` + + * :ref:`tinytc_inst_get_regions` + * :ref:`tinytc_inst_get_value` * :ref:`tinytc_inst_get_values` @@ -401,6 +418,21 @@ tinytc_cmp_inst_create .. doxygenfunction:: tinytc_cmp_inst_create +tinytc_constant_inst_create_complex +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_complex + +tinytc_constant_inst_create_float +................................. + +.. doxygenfunction:: tinytc_constant_inst_create_float + +tinytc_constant_inst_create_int +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_int + tinytc_expand_inst_create ......................... @@ -511,6 +543,16 @@ tinytc_yield_inst_create .. doxygenfunction:: tinytc_yield_inst_create +tinytc_inst_get_region +...................... + +.. doxygenfunction:: tinytc_inst_get_region + +tinytc_inst_get_regions +....................... + +.. doxygenfunction:: tinytc_inst_get_regions + tinytc_inst_get_value ..................... @@ -531,7 +573,7 @@ Program * Functions - * :ref:`tinytc_program_create` + * :ref:`tinytc_prog_create` * :ref:`tinytc_prog_add_function` @@ -550,10 +592,10 @@ Program Program Functions ----------------- -tinytc_program_create -..................... +tinytc_prog_create +.................. -.. doxygenfunction:: tinytc_program_create +.. doxygenfunction:: tinytc_prog_create tinytc_prog_add_function ........................ @@ -595,45 +637,43 @@ Region * Functions - * :ref:`tinytc_region_create` - * :ref:`tinytc_region_add_instruction` - * :ref:`tinytc_region_destroy` + * :ref:`tinytc_region_get_parameter` + + * :ref:`tinytc_region_get_parameters` Region Functions ---------------- -tinytc_region_create -.................... - -.. doxygenfunction:: tinytc_region_create - tinytc_region_add_instruction ............................. .. doxygenfunction:: tinytc_region_add_instruction -tinytc_region_destroy -..................... +tinytc_region_get_parameter +........................... + +.. doxygenfunction:: tinytc_region_get_parameter + +tinytc_region_get_parameters +............................ -.. doxygenfunction:: tinytc_region_destroy +.. doxygenfunction:: tinytc_region_get_parameters Value ===== * Functions - * :ref:`tinytc_float_imm_create` - - * :ref:`tinytc_int_imm_create` - * :ref:`tinytc_value_create` * :ref:`tinytc_value_get_name` * :ref:`tinytc_value_set_name` + * :ref:`tinytc_value_set_name_n` + * :ref:`tinytc_value_release` * :ref:`tinytc_value_retain` @@ -641,16 +681,6 @@ Value Value Functions --------------- -tinytc_float_imm_create -....................... - -.. doxygenfunction:: tinytc_float_imm_create - -tinytc_int_imm_create -..................... - -.. doxygenfunction:: tinytc_int_imm_create - tinytc_value_create ................... @@ -666,6 +696,11 @@ tinytc_value_set_name .. doxygenfunction:: tinytc_value_set_name +tinytc_value_set_name_n +....................... + +.. doxygenfunction:: tinytc_value_set_name_n + tinytc_value_release .................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index c554591c..9ca7ec40 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -43,9 +43,10 @@ Builder C-API: - tinytc_scalar_type_get Function: function: - - tinytc_function_create - - tinytc_function_set_subgroup_size - - tinytc_function_set_work_group_size + - tinytc_func_create + - tinytc_func_set_subgroup_size + - tinytc_func_set_work_group_size + - tinytc_func_get_body - tinytc_func_destroy Instruction: function: @@ -55,9 +56,9 @@ Builder C-API: - tinytc_arith_unary_inst_create - tinytc_cast_inst_create - tinytc_cmp_inst_create - - tinyc_constant_inst_create_complex - - tinyc_constant_inst_create_float - - tinyc_constant_inst_create_int + - tinytc_constant_inst_create_complex + - tinytc_constant_inst_create_float + - tinytc_constant_inst_create_int - tinytc_expand_inst_create - tinytc_for_inst_create - tinytc_foreach_inst_create @@ -80,12 +81,14 @@ Builder C-API: - tinytc_subview_inst_create - tinytc_sum_inst_create - tinytc_yield_inst_create + - tinytc_inst_get_region + - tinytc_inst_get_regions - tinytc_inst_get_value - tinytc_inst_get_values - tinytc_inst_destroy Program: function: - - tinytc_program_create + - tinytc_prog_create - tinytc_prog_add_function - tinytc_prog_dump - tinytc_prog_get_compiler_context @@ -95,13 +98,14 @@ Builder C-API: - tinytc_prog_retain Region: function: - - tinytc_region_create - tinytc_region_add_instruction - - tinytc_region_destroy + - tinytc_region_get_parameter + - tinytc_region_get_parameters Value: function: - tinytc_value_create - tinytc_value_get_name - tinytc_value_set_name + - tinytc_value_set_name_n - tinytc_value_release - tinytc_value_retain diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 3888ccd7..e7e72d3e 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -217,35 +217,19 @@ Function * Functions - * :ref:`make_function` - - * :ref:`set_work_group_size` - - * :ref:`set_subgroup_size` + * :ref:`make_func` * Classes * :ref:`func` - * :ref:`function_builder` - Function Functions ------------------ -make_function -............. - -.. doxygenfunction:: tinytc::make_function - -set_work_group_size -................... - -.. doxygenfunction:: tinytc::set_work_group_size - -set_subgroup_size -................. +make_func +......... -.. doxygenfunction:: tinytc::set_subgroup_size +.. doxygenfunction:: tinytc::make_func Function Classes ---------------- @@ -255,11 +239,6 @@ func .. doxygenclass:: tinytc::func -function_builder -................ - -.. doxygenclass:: tinytc::function_builder - Instruction =========== @@ -277,6 +256,14 @@ Instruction * :ref:`make_cmp` + * :ref:`make_constant(std::complex\,tinytc_data_type_t,location const&)` + + * :ref:`make_constant(double,tinytc_data_type_t,location const&)` + + * :ref:`make_constant(std::int32_t,tinytc_data_type_t,location const&)` + + * :ref:`make_constant(std::int64_t,tinytc_data_type_t,location const&)` + * :ref:`make_expand` * :ref:`make_for` @@ -358,6 +345,26 @@ make_cmp .. doxygenfunction:: tinytc::make_cmp +make_constant(std::complex,tinytc_data_type_t,location const&) +...................................................................... + +.. doxygenfunction:: tinytc::make_constant(std::complex,tinytc_data_type_t,location const&) + +make_constant(double,tinytc_data_type_t,location const&) +........................................................ + +.. doxygenfunction:: tinytc::make_constant(double,tinytc_data_type_t,location const&) + +make_constant(std::int32_t,tinytc_data_type_t,location const&) +.............................................................. + +.. doxygenfunction:: tinytc::make_constant(std::int32_t,tinytc_data_type_t,location const&) + +make_constant(std::int64_t,tinytc_data_type_t,location const&) +.............................................................. + +.. doxygenfunction:: tinytc::make_constant(std::int64_t,tinytc_data_type_t,location const&) + make_expand ........... @@ -481,21 +488,19 @@ Program * Functions - * :ref:`make_program` + * :ref:`make_prog` * Classes * :ref:`prog` - * :ref:`program_builder` - Program Functions ----------------- -make_program -............ +make_prog +......... -.. doxygenfunction:: tinytc::make_program +.. doxygenfunction:: tinytc::make_prog Program Classes --------------- @@ -505,39 +510,48 @@ prog .. doxygenclass:: tinytc::prog -program_builder -............... - -.. doxygenclass:: tinytc::program_builder - Region ====== * Functions - * :ref:`make_region` + * :ref:`add_instruction` -* Classes + * :ref:`get_num_parameters` + + * :ref:`get_parameter` - * :ref:`region` + * :ref:`get_parameters` + +* Classes * :ref:`region_builder` Region Functions ---------------- -make_region -........... +add_instruction +............... -.. doxygenfunction:: tinytc::make_region +.. doxygenfunction:: tinytc::add_instruction -Region Classes --------------- +get_num_parameters +.................. -region -...... +.. doxygenfunction:: tinytc::get_num_parameters -.. doxygenclass:: tinytc::region +get_parameter +............. + +.. doxygenfunction:: tinytc::get_parameter + +get_parameters +.............. + +.. doxygenfunction:: tinytc::get_parameters + +Region Classes +-------------- region_builder .............. @@ -549,12 +563,12 @@ Value * Functions - * :ref:`make_fimm` - - * :ref:`make_imm` + * :ref:`get_name` * :ref:`make_value` + * :ref:`set_name` + * Classes * :ref:`value` @@ -562,21 +576,21 @@ Value Value Functions --------------- -make_fimm -......... - -.. doxygenfunction:: tinytc::make_fimm - -make_imm +get_name ........ -.. doxygenfunction:: tinytc::make_imm +.. doxygenfunction:: tinytc::get_name make_value .......... .. doxygenfunction:: tinytc::make_value +set_name +........ + +.. doxygenfunction:: tinytc::set_name + Value Classes ------------- diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 98ebeeb7..3f5bd53e 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -36,12 +36,9 @@ Builder C++-API: - tinytc::to_scalar_type_v Function: function: - - tinytc::make_function - - tinytc::set_work_group_size - - tinytc::set_subgroup_size + - tinytc::make_func class: - tinytc::func - - tinytc::function_builder Instruction: function: - tinytc::make_alloca @@ -80,18 +77,21 @@ Builder C++-API: - tinytc::inst Program: function: - - tinytc::make_program + - tinytc::make_prog class: - tinytc::prog - - tinytc::program_builder Region: function: - - tinytc::make_region + - tinytc::add_instruction + - tinytc::get_num_parameters + - tinytc::get_parameter + - tinytc::get_parameters class: - - tinytc::region - tinytc::region_builder Value: function: + - tinytc::get_name - tinytc::make_value + - tinytc::set_name class: - tinytc::value diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index f7c60e0c..76284917 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -24,6 +24,8 @@ Common * Classes + * :ref:`array_view` + * :ref:`shared_handle` * :ref:`unique_handle` @@ -66,6 +68,11 @@ CHECK_STATUS_LOC Common Classes -------------- +array_view +.......... + +.. doxygenclass:: tinytc::array_view + shared_handle ............. diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 13355dd6..f5e7e09c 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -10,6 +10,7 @@ Core C++-API: - tinytc::CHECK_STATUS - tinytc::CHECK_STATUS_LOC class: + - tinytc::array_view - tinytc::shared_handle - tinytc::unique_handle typedef: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index e53b46bc..016ddd5e 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -146,6 +146,18 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_retain(tinytc_value_t vl); */ TINYTC_EXPORT tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name); +/** + * @brief Set name of value with explicit number of characters + * + * @param vl [inout] value object + * @param name_length [in] number of characters + * @param name [in] name; not necessarily null-terminated + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, uint32_t name_length, + char const *name); + /** * @brief Get name of value * @@ -538,13 +550,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *in * @endcode * * @param instr [out] pointer to the inst object created - * @param body [in,pass_ownership] loop body * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, - tinytc_region_t body, const tinytc_location_t *loc); /** @@ -690,18 +700,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @endcode * * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable * @param from [in] loop begion * @param to [in] loop bound * @param step [in][optional] loop step; can be nullptr - * @param body [in,pass_ownership] loop body + * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_value_t step, tinytc_region_t body, +TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, + tinytc_value_t to, tinytc_value_t step, + tinytc_data_type_t loop_var_type, const tinytc_location_t *loc); /** @@ -716,18 +725,16 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt * @endcode * * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable * @param from [in] loop begion * @param to [in] loop bound - * @param body [in,pass_ownership] loop body + * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, - tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, +TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t from, + tinytc_value_t to, + tinytc_data_type_t loop_var_type, const tinytc_location_t *loc); /** @@ -741,8 +748,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, * * @param instr [out] pointer to the inst object created * @param condition [in] condition - * @param then [in,pass_ownership] region taken if condition is true - * @param otherwise [in,pass_ownership][optional] region taken if condition is false; can be nullptr * @param return_type_list_size [in] length of return type array * @param return_type_list [in][range(0, return_type_list_size)] return type array; can be nullptr * if return_type_list_size is 0 @@ -751,7 +756,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); @@ -811,20 +815,38 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list); -//////////////////////////// -////////// Region ////////// -//////////////////////////// +/** + * @brief Get child region of instruction + * + * @param instr [in] inst object + * @param region_no [in] region index + * @param result [out] result value + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_region(tinytc_inst_t instr, uint32_t region_no, + tinytc_region_t *result); /** - * @brief Create region + * @brief Get child regions of instruction * - * @param reg [out] pointer to the region object created - * @param loc [in][optional] Source code location; can be nullptr + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param instr [in] inst object + * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_region_create(tinytc_region_t *reg, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, + uint32_t *result_list_size, + tinytc_region_t *result_list); + +//////////////////////////// +////////// Region ////////// +//////////////////////////// /** * @brief Append instruction to region @@ -841,11 +863,33 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction); /** - * @brief Delete region object + * @brief Get region parameter * - * @param reg [inout] region object + * @param reg [in] region object + * @param param_no [in] parameter index + * @param result [out] result value + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, uint32_t param_no, + tinytc_value_t *result); + +/** + * @brief Get region parameters + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param reg [in] region object + * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 + * + * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT void tinytc_region_destroy(tinytc_region_t reg); +TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, + uint32_t *result_list_size, + tinytc_value_t *result_list); //////////////////////////// /////////// Func /////////// @@ -857,18 +901,20 @@ TINYTC_EXPORT void tinytc_region_destroy(tinytc_region_t reg); * Function takes ownership of region. * * @param fun [out] pointer to the func object created + * @param name_length [in] length of function_name * @param name [in] function name - * @param arg_list_size [in] length of argument array - * @param arg_list [in][range(0,arg_list_size)] argument array; can be nullptr if arg_list_size is 0 - * @param body [in,pass_ownership] function body + * @param num_params [in] number of parameters + * @param param_type_list [in][range(0,num_params)] parameter data types; can be nullptr if + * num_params is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, - tinytc_value_t *arg_list, tinytc_region_t body, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, + char const *name, uint32_t num_params, + tinytc_data_type_t *param_type_list, + const tinytc_location_t *loc); + /** * @brief Set work-group size * @@ -878,8 +924,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char co * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, - int32_t y); +TINYTC_EXPORT tinytc_status_t tinytc_func_set_work_group_size(tinytc_func_t fun, int32_t x, + int32_t y); /** * @brief Set subgroup size * @@ -888,7 +934,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs); +TINYTC_EXPORT tinytc_status_t tinytc_func_set_subgroup_size(tinytc_func_t fun, int32_t sgs); + +/** + * @brief Get function body + * + * @param fun [in] function object + * @param body [out] pointer to body region + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body); /** * @brief Delete function object @@ -912,9 +968,8 @@ TINYTC_EXPORT void tinytc_func_destroy(tinytc_func_t fun); * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc); /** * @brief Append function to program diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 68c555e3..c311c7c4 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -7,12 +7,12 @@ #include "tinytc/tinytc.h" #include "tinytc/types.hpp" +#include #include #include #include #include #include -#include #include #include #include @@ -281,6 +281,82 @@ template class unique_handle { T obj_; }; +//////////////////////////// +//////// Array view //////// +//////////////////////////// + +/** + * @brief Stores a view on an array (pointer + size) + * + * @tparam T array element type + */ +template class array_view { + public: + using const_iterator = T const *; + + /** + * @brief Empty array view + */ + array_view() = default; + + /** + * @brief Single element view + * + * @param single the single element + */ + array_view(T const &single) : data_{&single}, size_{1} {} + + /** + * @brief ctor + * + * @param data base pointer + * @param size array size + */ + array_view(T const *data, std::size_t size) : data_{data}, size_{size} {} + + /** + * @brief ctor + * + * @param begin begin pointer + * @param end end pointer (not included) + */ + array_view(T const *begin, T const *end) : data_{begin}, size_{end - begin} {} + + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + array_view(std::vector const &vec) + : data_{!vec.empty() ? vec.data() : nullptr}, size_{vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + array_view(std::array const &arr) : data_{arr.data()}, size_{arr.size()} {} + + //! Begin iterator + auto begin() const -> const_iterator { return data_; } + //! End iterator + auto end() const -> const_iterator { return data_ + size_; } + //! Returns true if view is empty + auto empty() const -> bool { return size_ == 0; } + //! Returns array size + auto size() const -> std::size_t { return size_; } + //! Access first element; must not call when array size is 0 + auto front() const -> T const & { return data_[0]; } + //! Access last element; must not call when array size is 0 + auto back() const -> T const & { return data_[size_ - 1]; } + + private: + T const *data_ = nullptr; + std::size_t size_ = 0; +}; + //////////////////////////// ///// Compiler context ///// //////////////////////////// @@ -454,8 +530,8 @@ class value : public shared_handle { * * @param name Name */ - inline void name(std::string const &name) { - CHECK_STATUS(tinytc_value_set_name(obj_, name.c_str())); + inline void name(std::string_view name) { + CHECK_STATUS(tinytc_value_set_name_n(obj_, name.size(), name.data())); } }; @@ -479,6 +555,29 @@ inline auto make_value(tinytc_data_type_t ty, location const &loc = {}) -> value return value{val}; } +/** + * @brief Get name + * + * @param val value object + * + * @return Name as C-string + */ +inline auto get_name(tinytc_value_t val) -> char const * { + char const *name; + CHECK_STATUS(tinytc_value_get_name(val, &name)); + return name; +} + +/** + * @brief Set value name + * + * @param val value object + * @param name Name + */ +inline void set_name(tinytc_value_t val, std::string_view name) { + CHECK_STATUS(tinytc_value_set_name_n(val, name.size(), name.data())); +} + //////////////////////////// /////////// Inst /////////// //////////////////////////// @@ -560,6 +659,17 @@ class inst : public unique_handle { return value{result}; } + /** + * @brief Get number of result values + * + * @return Number of result values + */ + inline auto get_num_values() const -> std::uint32_t { + std::uint32_t result_list_size = 0; + CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, nullptr)); + return result_list_size; + } + /** * @brief Get result values * @@ -567,51 +677,99 @@ class inst : public unique_handle { */ inline auto get_values() const -> std::vector { static_assert(internal::value_reinterpret_allowed); - std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, nullptr)); + std::uint32_t result_list_size = get_num_values(); auto values = std::vector(result_list_size); tinytc_value_t *result_list = reinterpret_cast(values.data()); CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, result_list)); return values; } -}; -//////////////////////////// -////////// Region ////////// -//////////////////////////// - -namespace internal { -template <> struct unique_handle_traits { - static void destroy(tinytc_region_t handle) { return tinytc_region_destroy(handle); } -}; -} // namespace internal + /** + * @brief Get child region + * + * @param region_no region index + * + * @return Region + */ + inline auto get_region(std::uint32_t region_no) const -> tinytc_region_t { + tinytc_region_t result; + CHECK_STATUS(tinytc_inst_get_region(obj_, region_no, &result)); + return result; + } -//! @brief Reference-counting wrapper for tinytc_region_t -class region : public unique_handle { - public: - using unique_handle::unique_handle; + /** + * @brief Get number of child regions + * + * @return Number of child regions + */ + inline auto get_num_regions() const -> std::uint32_t { + std::uint32_t result_list_size = 0; + CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, nullptr)); + return result_list_size; + } /** - * @brief Append instruction to region + * @brief Get child regions * - * @param instruction instruction; region takes ownership + * @return Vector of regions */ - inline void add_instruction(inst instruction) { - CHECK_STATUS(tinytc_region_add_instruction(get(), instruction.release())); + inline auto get_regions() const -> std::vector { + std::uint32_t result_list_size = get_num_regions(); + auto regions = std::vector(result_list_size); + CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, regions.data())); + return regions; } }; +//////////////////////////// +////////// Region ////////// +//////////////////////////// + /** - * @brief Make region + * @brief Append instruction to region * - * @param loc Source code location + * @param reg region object + * @param instruction instruction object + */ +inline void add_instruction(tinytc_region_t reg, inst instruction) { + CHECK_STATUS(tinytc_region_add_instruction(reg, instruction.release())); +} + +/** + * @brief Get region parameter + * + * @param reg Region object + * @param region_no Region index + * + * @return Parameter + */ +inline auto get_parameter(tinytc_region_t reg, std::uint32_t region_no) -> tinytc_value_t { + tinytc_value_t result; + CHECK_STATUS(tinytc_region_get_parameter(reg, region_no, &result)); + return result; +} + +/** + * @brief Get number of child regions * - * @return Region + * @return Number of child regions */ -inline region make_region(location const &loc = {}) { - tinytc_region_t reg; - CHECK_STATUS_LOC(tinytc_region_create(®, &loc), loc); - return region{reg}; +inline auto get_num_parameters(tinytc_region_t reg) -> std::uint32_t { + std::uint32_t result_list_size = 0; + CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, nullptr)); + return result_list_size; +} + +/** + * @brief Get parameters + * + * @return Vector of parameters + */ +inline auto get_parameters(tinytc_region_t reg) -> std::vector { + std::uint32_t result_list_size = get_num_parameters(reg); + auto params = std::vector(result_list_size); + CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, params.data())); + return params; } //////////////////////////// @@ -691,8 +849,7 @@ inline inst make_cmp(cmp_condition cond, value const &a, value const &b, locatio /** * @brief Make complex constant * - * @param value_re Real part - * @param value_im Imaginary part + * @param value Complex constant * @param ty Data type * @param loc Source code location * @@ -999,14 +1156,13 @@ inline inst make_num_subgroups(compiler_context const &ctx, location const &loc /** * @brief Make parallel region * - * @param body Loop body * @param loc Source code location * * @return Instruction */ -inline inst make_parallel(region body, location const &loc = {}) { +inline inst make_parallel(location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, body.release(), &loc), loc); + CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, &loc), loc); return inst(instr); } @@ -1165,40 +1321,36 @@ inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const /** * @brief Make for loop instruction * - * @param loop_var Loop variable * @param from Loop variable start * @param to Loop variable bound * @param step Loop variable step - * @param body Loop body + * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ -inline inst make_for(value const &loop_var, value const &from, value const &to, value const &step, - region body, location const &loc = {}) { +inline inst make_for(value const &from, value const &to, value const &step, + tinytc_data_type_t loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var.get(), from.get(), to.get(), - step.get(), body.release(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_for_inst_create(&instr, from.get(), to.get(), step.get(), loop_var_type, &loc), loc); return inst(instr); } /** * @brief Make foreach loop instruction * - * @param loop_var Loop variable * @param from Loop variable start * @param to Loop variable bound - * @param body Loop body + * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ -inline inst make_foreach(value const &loop_var, value const &from, value const &to, region body, +inline inst make_foreach(value const &from, value const &to, tinytc_data_type_t loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, loop_var.get(), from.get(), to.get(), - body.release(), &loc), + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from.get(), to.get(), loop_var_type, &loc), loc); return inst(instr); } @@ -1207,14 +1359,12 @@ inline inst make_foreach(value const &loop_var, value const &from, value const & * @brief Make if condition instruction * * @param condition Condition value (of type bool) - * @param then Then region - * @param otherwise Else region * @param return_type_list Types of returned values * @param loc Source code location * * @return Instruction */ -inline inst make_if(value const &condition, region then, region otherwise = region{}, +inline inst make_if(value const &condition, std::vector const &return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; @@ -1223,7 +1373,7 @@ inline inst make_if(value const &condition, region then, region otherwise = regi throw std::out_of_range("return type list too long"); } CHECK_STATUS_LOC( - tinytc_if_inst_create(&instr, condition.get(), then.release(), otherwise.release(), len, + tinytc_if_inst_create(&instr, condition.get(), len, const_cast(return_type_list.data()), &loc), loc); return inst(instr); @@ -1264,52 +1414,43 @@ template <> struct unique_handle_traits { class func : public unique_handle { public: using unique_handle::unique_handle; + + void set_work_group_size(std::int32_t x, std::int32_t y) { + CHECK_STATUS(tinytc_func_set_work_group_size(obj_, x, y)); + } + + void set_subgroup_size(std::int32_t sgs) { + CHECK_STATUS(tinytc_func_set_subgroup_size(obj_, sgs)); + } + + auto get_body() -> tinytc_region_t { + tinytc_region_t body; + CHECK_STATUS(tinytc_func_get_body(obj_, &body)); + return body; + } }; /** * @brief Make function * * @param name Function name - * @param arg_list Argument list - * @param body Function body + * @param param_type_list List of parameter types * @param loc Source code location * * @return Function */ -inline func make_function(char const *name, std::vector &arg_list, region body, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline func make_func(std::string_view name, std::vector const ¶m_type_list, + location const &loc = {}) { tinytc_func_t fun; - auto len = arg_list.size(); + auto len = param_type_list.size(); if (len > std::numeric_limits::max()) { - throw std::out_of_range("argument list too long"); + throw std::out_of_range("param list too long"); } - tinytc_value_t *al = reinterpret_cast(arg_list.data()); - CHECK_STATUS_LOC(tinytc_function_create(&fun, name, len, al, body.release(), &loc), loc); + tinytc_data_type_t *pl = const_cast(param_type_list.data()); + CHECK_STATUS_LOC(tinytc_func_create(&fun, name.size(), name.data(), len, pl, &loc), loc); return func(fun); } -/** - * @brief Set work-group size (x,y) - * - * @param fun Function object; must have been created with "make_function" - * @param x x - * @param y y - */ -inline void set_work_group_size(func &fun, std::int32_t x, std::int32_t y) { - CHECK_STATUS(tinytc_function_set_work_group_size(fun.get(), x, y)); -} - -/** - * @brief Set subgroup size - * - * @param fun Function object; must have been created with "make_function" - * @param sgs Subgroup size - */ -inline void set_subgroup_size(func &fun, std::int32_t sgs) { - CHECK_STATUS(tinytc_function_set_subgroup_size(fun.get(), sgs)); -} - //////////////////////////// /////////// Prog /////////// //////////////////////////// @@ -1384,9 +1525,9 @@ class prog : public shared_handle { * * @return Program */ -inline prog make_program(compiler_context const &ctx, location const &loc = {}) { +inline prog make_prog(compiler_context const &ctx, location const &loc = {}) { tinytc_prog_t prg; - CHECK_STATUS_LOC(tinytc_program_create(&prg, ctx.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_prog_create(&prg, ctx.get(), &loc), loc); return prog{prg}; } @@ -1400,18 +1541,9 @@ class region_builder { /** * @brief ctor * - * @param ctx compiler context - * @param loc Source code location - */ - region_builder(compiler_context const &ctx, location const &loc = {}) - : ctx_(ctx), reg_{make_region(loc)} {} - - /** - * @brief Returns built product - * - * @return Region + * @param reg region object */ - inline auto get_product() && -> region { return std::move(reg_); } + region_builder(tinytc_region_t reg) : reg_{reg} {} /** * @brief Add instruction @@ -1421,12 +1553,12 @@ class region_builder { * * @return Value returned by instruction; may be empty */ - [[maybe_unused]] inline auto add(inst i, std::string const &name = "") -> value { + [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> value { auto result = i.get_value(); if (result && name.size() > 0) { result.name(name); } - reg_.add_instruction(std::move(i)); + add_instruction(reg_, std::move(i)); return result; } @@ -1438,86 +1570,86 @@ class region_builder { * * @return Values returned by instruction */ - [[maybe_unused]] inline auto - add_multivalued(inst i, std::string const &name = "") -> std::vector { + [[maybe_unused]] inline auto add_multivalued(inst i, + std::string_view name = "") -> std::vector { auto results = i.get_values(); if (name.size() > 0) { int counter = 0; + auto name_str = std::string{name}; for (auto &result : results) { - result.name(name + std::to_string(counter++)); + result.name(name_str + std::to_string(counter++)); } } - reg_.add_instruction(std::move(i)); + add_instruction(reg_, std::move(i)); return results; } /** - * @brief Build for-loop with functor f(region_builder&, value) -> void + * @brief Build for-loop with functor f(region_builder&, tinytc_value_t) -> void * * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type - * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound + * @param loop_var_ty Type of loop variable * @param f Functor - * @param name Loop variable name + * @param loop_var_name Loop variable name * @param loc Source code location */ template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, F &&f, - std::string const &name = "", location const &loc = {}) { - for_loop(std::move(loop_var_ty), std::move(from), std::move(to), value{nullptr}, - std::forward(f), name, loc); + void for_loop(value const &from, value const &to, tinytc_data_type_t loop_var_ty, F &&f, + std::string_view loop_var_name = "", location const &loc = {}) { + for_loop(std::move(from), std::move(to), value{nullptr}, std::move(loop_var_ty), + std::forward(f), std::move(loop_var_name), loc); } /** - * @brief Build for-loop with functor f(region_builder&, value) -> void + * @brief Build for-loop with functor f(region_builder&, tinytc_value_t) -> void * * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type - * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound * @param step Loop variable step + * @param loop_var_ty Type of loop variable * @param f Functor - * @param name Loop variable name + * @param loop_var_name Loop variable name * @param loc Source code location */ template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, value const &step, - F &&f, std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(get_scalar(ctx_, loop_var_ty)); - if (name.size() > 0) { - loop_var.name(name); - } - auto bb = region_builder{ctx_}; + void for_loop(value const &from, value const &to, value const &step, + tinytc_data_type_t loop_var_ty, F &&f, std::string_view loop_var_name = "", + location const &loc = {}) { + auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); + auto reg = fi.get_region(0); + auto loop_var = get_parameter(reg, 0); + set_name(loop_var, loop_var_name); + add_instruction(reg_, std::move(fi)); + auto bb = region_builder{reg}; f(bb, loop_var); - add(::tinytc::make_for(std::move(loop_var), from, to, step, std::move(bb).get_product(), - loc)); } /** - * @brief Build foreach-loop with functor f(region_builder&) -> void + * @brief Build foreach-loop with functor f(region_builder&, tinytc_value_t) -> void * * @tparam F Functor type - * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound + * @param loop_var_ty Type of loop variable * @param f functor - * @param name Loop variable name + * @param loop_var_name Loop variable name * @param loc Source code location */ template - void foreach (scalar_type loop_var_ty, value const &from, value const &to, F && f, - std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(get_scalar(ctx_, loop_var_ty)); - if (name.size() > 0) { - loop_var.name(name); - } - auto bb = region_builder{ctx_}; - f(bb); - add(::tinytc::make_foreach(std::move(loop_var), from, to, std::move(bb).get_product(), - loc)); + void foreach (value const &from, value const &to, tinytc_data_type_t loop_var_ty, F && f, + std::string const &loop_var_name = "", location const &loc = {}) { + auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); + auto reg = fi.get_region(0); + auto loop_var = get_parameter(reg, 0); + set_name(loop_var, loop_var_name); + add_instruction(reg_, std::move(fi)); + auto bb = region_builder{reg}; + f(bb, loop_var); } /** @@ -1535,10 +1667,12 @@ class region_builder { auto if_condition(value const &condition, F &&then, std::vector const &return_type_list = {}, location const &loc = {}) -> std::vector { - auto bb = region_builder{ctx_}; + auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); + auto r0 = ii.get_region(0); + auto results = add_multivalued(std::move(ii)); + auto bb = region_builder{r0}; then(bb); - return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb).get_product(), - region{}, return_type_list, loc)); + return results; } /** * @brief Build if/else with functors then(region_builder&) -> void and @@ -1558,151 +1692,19 @@ class region_builder { auto ifelse(value const &condition, F &&then, G &&otherwise, std::vector const &return_type_list = {}, location const &loc = {}) -> std::vector { - auto bb1 = region_builder{ctx_}; - then(bb1); - auto bb2 = region_builder{ctx_}; - otherwise(bb2); - return add_multivalued(::tinytc::make_if(std::move(condition), std::move(bb1).get_product(), - std::move(bb2).get_product(), return_type_list, - loc)); - } - - inline auto context() -> compiler_context const & { return ctx_; } - - private: - compiler_context ctx_; - region reg_; -}; - -//! Builder for functions -class function_builder { - public: - /** - * @brief creates function \@name - * - * @param ctx compiler context - * @param name Function name - * @param loc Source code location - * - */ - inline function_builder(compiler_context const &ctx, std::string name, location const &loc = {}) - : ctx_(ctx), name_(std::move(name)), body_{nullptr}, loc_(loc) {} - - /** - * @brief Returns built product - * - * @return Function - */ - inline func get_product() && { - auto fun = make_function(name_.c_str(), arguments_, std::move(body_), loc_); - if (x_ > 0 && y_ > 0) { - set_work_group_size(fun, x_, y_); - } - if (sgs_ > 0) { - set_subgroup_size(fun, sgs_); - } - return fun; - } - - /** - * @brief @code %name: %ty @endcode - * - * @param ty Argument type - * @param name Argument name - * @param loc Source code location - * - * @return Value - */ - inline value argument(tinytc_data_type_t ty, std::string const &name = "", - location const &loc = {}) { - auto v = make_value(ty, loc); - if (name.size() > 0) { - v.name(name); - } - arguments_.emplace_back(std::move(v)); - return arguments_.back(); - } - - /** - * @brief @code work_group_size(%x, %y) @endcode - * - * @param x x - * @param y y - */ - inline void work_group_size(std::int32_t x, std::int32_t y) { - x_ = x; - y_ = y; - } - /** - * @brief @code subgroup_size(%subgroup_size) @endcode - * - * @param subgroup_size Subgroup size - */ - inline void subgroup_size(std::int32_t subgroup_size) { sgs_ = subgroup_size; } - - /** - * @brief Build function body with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param f Functor - * @param loc Source code location - */ - template void body(F &&f, location const &loc = {}) { - auto bb = region_builder{ctx_, loc}; - f(bb); - body_ = std::move(bb).get_product(); - } - - private: - compiler_context ctx_; - std::string name_; - region body_; - location loc_; - std::vector arguments_; - std::int32_t x_ = 0, y_ = 0, sgs_ = 0; -}; - -//! Builder for programs -class program_builder { - public: - /** - * @brief ctor - * - * @param ctx Compiler context - * @param loc Source code location - * - */ - program_builder(compiler_context const &ctx, location const &loc = {}) - : prg_{make_program(ctx, loc)} {} - - /** - * @brief create function \@name with functor f(function_builder&) -> void - * - * @tparam F Functor type - * @param name Function name - * @param f Functor - * @param loc Source code location - */ - template void create(std::string name, F &&f, location const &loc = {}) { - auto fb = function_builder(prg_.get_compiler_context(), std::move(name), loc); - f(fb); - add(std::move(fb).get_product()); + auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); + auto r0 = ii.get_region(0); + auto r1 = ii.get_region(1); + auto results = add_multivalued(std::move(ii)); + auto bb0 = region_builder{r0}; + then(bb0); + auto bb1 = region_builder{r1}; + otherwise(bb1); + return results; } - /** - * @brief Add function - * - * @param f function - */ - inline void add(func f) { prg_.add_function(std::move(f)); } - /** - * @brief Returns built product - * - * @return Program - */ - inline prog get_product() && { return std::move(prg_); } private: - prog prg_; + tinytc_region_t reg_; }; //////////////////////////// diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp index 3a8aa98b..8f64aad7 100644 --- a/src/analysis/cfg.cpp +++ b/src/analysis/cfg.cpp @@ -38,23 +38,28 @@ auto get_control_flow_graph(region_node &topreg) -> control_flow_graph { auto pred_nodes = std::queue{}; const auto visit_inst = [&](inst_node *node) { + bool empty_child_regions = true; if (node->num_child_regions() > 0) { for (auto &subreg : node->child_regions()) { auto [substart, subexits] = - add_region_ref(*subreg, std::max(kind_max, subreg->kind()), add_region_ref); - cfg.add_edge(node, substart); - if (isa(*node)) { - for (; !subexits.empty(); subexits.pop()) { - cfg.add_edge(subexits.front(), node); - } - pred_nodes.push(node); - } else { - for (; !subexits.empty(); subexits.pop()) { - pred_nodes.push(subexits.front()); + add_region_ref(subreg, std::max(kind_max, subreg.kind()), add_region_ref); + if (substart != nullptr && !subexits.empty()) { + empty_child_regions = false; + cfg.add_edge(node, substart); + if (isa(*node)) { + for (; !subexits.empty(); subexits.pop()) { + cfg.add_edge(subexits.front(), node); + } + pred_nodes.push(node); + } else { + for (; !subexits.empty(); subexits.pop()) { + pred_nodes.push(subexits.front()); + } } } } - } else { + } + if (empty_child_regions) { pred_nodes.push(node); } }; diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index c1a1b201..2ee0c846 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -439,7 +439,7 @@ void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { - auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); std::int64_t blocks = loop_trip_count / sgs; std::int64_t rem = loop_trip_count % sgs; @@ -453,22 +453,23 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co if (blocks > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); bb.for_loop( - scalar_type::index, std::move(block_start), c_sgs_blocks, c_sgs_tiles, - [&](region_builder &bb, value const &block) { body(bb, block, false, c_sgs); }, + std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs.get()); }, "block"); } if (rem > 0) { auto condition = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); - bb.if_condition(condition, - [&](region_builder &bb) { body(bb, c_sgs_blocks, true, c_rem); }); + bb.if_condition(condition, [&](region_builder &bb) { + body(bb, c_sgs_blocks.get(), true, c_rem.get()); + }); } } void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body) { - auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); auto c_sgs = bb.add(make_constant(sgs, index_ty)); auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); auto c0 = bb.add(make_constant(0, index_ty)); @@ -481,15 +482,16 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_cou auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop( - scalar_type::index, std::move(block_start), std::move(block_end), c_sgs_tiles, - [&](region_builder &bb, value const &block) { body(bb, block, false, c_sgs); }, "block"); + std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs.get()); }, + "block"); auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { auto condition1 = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); bb.if_condition(condition1, [&](region_builder &bb) { auto block = bb.add(make_arith(arithmetic::mul, blocks, c_sgs)); - body(bb, block, true, rem); + body(bb, block.get(), true, rem.get()); }); }); } @@ -504,7 +506,7 @@ void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, i void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { - auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); // Find minimum number of blocks such that the block sizes are smaller or equal block_size std::int64_t blocks = 1 + (loop_trip_count - 1) / block_size; // Increase the number of blocks if such that the number of blocks is a multiple @@ -527,8 +529,9 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip if (rem > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); bb.for_loop( - scalar_type::index, std::move(block_start), c_bs_1_rem, c_bs_1_tiles, - [&](region_builder &bb, value const &block) { body(bb, block, c_bs_1); }, "block"); + std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs_1.get()); }, + "block"); } auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); @@ -536,14 +539,14 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto tmp2 = bb.add(make_arith(arithmetic::mul, c_bs, sg_id_1)); auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); bb.for_loop( - scalar_type::index, std::move(block_start), c_loop_trip_count, c_bs_tiles, - [&](region_builder &bb, value const &block) { body(bb, block, c_bs); }, "block"); + std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs.get()); }, "block"); } void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, int block_size, int num_tiles, value const &sg_id, uniform_loop_body_builder_new const &body) { - auto index_ty = get_scalar(bb.context(), scalar_type::index); + auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); auto c1 = bb.add(make_constant(1, index_ty)); auto c_block_size = bb.add(make_constant(block_size, index_ty)); auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); @@ -568,8 +571,8 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); bb.for_loop( - scalar_type::index, std::move(block_start_1), std::move(block_end_1), std::move(step_1), - [&](region_builder &bb, value const &block) { body(bb, block, bs_1); }, "block"); + std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs_1.get()); }, "block"); auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); @@ -579,8 +582,8 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); bb.for_loop( - scalar_type::index, std::move(block_start), loop_trip_count, std::move(step), - [&](region_builder &bb, value const &block) { body(bb, block, bs); }, "block"); + std::move(block_start), loop_trip_count, std::move(step), index_ty, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs.get()); }, "block"); } } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 84725a5a..f4a63c4c 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -123,9 +123,9 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, clir::expr beta, core_config const &core_cfg); using sgs_loop_body_builder_new = - std::function; + std::function; using uniform_loop_body_builder_new = - std::function; + std::function; void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, value const &sg_id, sgs_loop_body_builder_new const &body); diff --git a/src/func.cpp b/src/func.cpp index ee94efae..eacda1e7 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -19,31 +19,40 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_function_create(tinytc_func_t *fun, char const *name, uint32_t arg_list_size, - tinytc_value_t *arg_list, tinytc_region_t body, - const tinytc_location_t *loc) { - if (fun == nullptr || (arg_list_size > 0 && arg_list == nullptr) || body == nullptr) { +tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, char const *name, + uint32_t num_params, tinytc_data_type_t *param_type_list, + const tinytc_location_t *loc) { + if (fun == nullptr || (num_params > 0 && param_type_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto arg_vec = std::vector(); - arg_vec.reserve(arg_list_size); - for (uint32_t i = 0; i < arg_list_size; ++i) { - arg_vec.emplace_back(value(arg_list[i], true)); - } - *fun = std::make_unique(std::string(name), std::move(arg_vec), body, + *fun = std::make_unique(std::string(name, name_length), + array_view(param_type_list, num_params), get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { +tinytc_status_t tinytc_func_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { + if (fun == nullptr) { + return tinytc_status_invalid_arguments; + } return exception_to_status_code([&] { fun->work_group_size({x, y}); }); } -tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { +tinytc_status_t tinytc_func_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { + if (fun == nullptr) { + return tinytc_status_invalid_arguments; + } return exception_to_status_code([&] { fun->subgroup_size(sgs); }); } +tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body) { + if (fun == nullptr || body == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *body = &fun->body(); }); +} + void tinytc_func_destroy(tinytc_func_t obj) { delete obj; } } diff --git a/src/inst.cpp b/src/inst.cpp index f1abf9ae..d6337a08 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -348,14 +348,12 @@ tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } -tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, tinytc_region_t body, - const tinytc_location_t *loc) { - if (instr == nullptr || body == nullptr) { +tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { + if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *instr = std::make_unique(region{body}, get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); } tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, @@ -469,41 +467,37 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t }); } -tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { +tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, tinytc_value_t to, + tinytc_value_t step, tinytc_data_type_t loop_var_type, + const tinytc_location_t *loc) { + if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), value(to, true), - value(step, true), region{body}, get_optional(loc)) - .release(); + *instr = std::make_unique(value(from, true), value(to, true), value(step, true), + loop_var_type, get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { +tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t from, + tinytc_value_t to, tinytc_data_type_t loop_var_type, + const tinytc_location_t *loc) { + if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(loop_var, true), value(from, true), - value(to, true), region{body}, get_optional(loc)) + *instr = std::make_unique(value(from, true), value(to, true), loop_var_type, + get_optional(loc)) .release(); }); } tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, tinytc_data_type_t *return_type_list, const tinytc_location_t *loc) { - if (instr == nullptr || condition == nullptr || then == nullptr || + if (instr == nullptr || condition == nullptr || (return_type_list_size > 0 && return_type_list == nullptr)) { return tinytc_status_invalid_arguments; } @@ -513,8 +507,7 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi for (uint32_t i = 0; i < return_type_list_size; ++i) { rt.emplace_back(return_type_list[i]); } - *instr = std::make_unique(value(condition, true), region{then}, region{otherwise}, - std::move(rt), get_optional(loc)) + *instr = std::make_unique(value(condition, true), std::move(rt), get_optional(loc)) .release(); }); } @@ -565,4 +558,35 @@ tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *resu *result_list_size = num; }); } + +tinytc_status_t tinytc_inst_get_region(tinytc_inst_t instr, uint32_t region_no, + tinytc_region_t *result) { + if (instr == nullptr || result == nullptr || region_no >= instr->num_child_regions()) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *result = &instr->child_region(region_no); }); +} + +tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, uint32_t *result_list_size, + tinytc_region_t *result_list) { + if (instr == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = instr->num_child_regions(); + if (num_results > std::numeric_limits::max()) { + throw std::out_of_range("too many results"); + } + auto const num = static_cast(num_results); + if (*result_list_size > 0) { + auto results = instr->child_regions_begin(); + auto const limit = std::min(num, *result_list_size); + for (uint32_t i = 0; i < limit; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); +} } diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index 6aa71007..bf507854 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -13,44 +13,26 @@ #include #include #include -#include - -namespace tinytc { -using value_range = iterator_range_wrapper; -using const_value_range = iterator_range_wrapper; -} // namespace tinytc struct tinytc_func final { public: - inline tinytc_func(std::string name, std::vector args, tinytc_region_t body, + inline tinytc_func(std::string name, tinytc::array_view params, tinytc_location const &lc = {}) - : name_(std::move(name)), args_(std::move(args)), body_(tinytc::region{body}), - work_group_size_{0, 0}, subgroup_size_{0}, loc_{lc} { - body_->kind(tinytc::region_kind::collective); + : name_(std::move(name)), body_(std::move(params)), work_group_size_{0, 0}, + subgroup_size_{0}, loc_{lc} { + body_.kind(tinytc::region_kind::collective); } inline auto loc() const noexcept -> tinytc_location const & { return loc_; } inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } - inline auto arg_begin() -> tinytc::value * { return args_.size() > 0 ? args_.data() : nullptr; } - inline auto arg_end() -> tinytc::value * { - return args_.size() > 0 ? args_.data() + args_.size() : nullptr; - } - inline auto args() -> tinytc::value_range { - return tinytc::value_range{arg_begin(), arg_end()}; - } - inline auto arg_begin() const -> tinytc::value const * { - return args_.size() > 0 ? args_.data() : nullptr; - } - inline auto arg_end() const -> tinytc::value const * { - return args_.size() > 0 ? args_.data() + args_.size() : nullptr; - } - inline auto args() const -> tinytc::const_value_range { - return tinytc::const_value_range{arg_begin(), arg_end()}; - } + inline auto params() { return body_.params(); } + inline auto params() const { return body_.params(); } + inline auto num_params() const noexcept { return body_.num_params(); } inline auto name() const -> std::string_view { return name_; } - inline auto body() const -> tinytc_region & { return *body_; } + inline auto body() -> tinytc_region & { return body_; } + inline auto body() const -> tinytc_region const & { return body_; } inline auto work_group_size() const -> std::array { return work_group_size_; } inline void work_group_size(std::array const &work_group_size) { @@ -61,8 +43,7 @@ struct tinytc_func final { private: std::string name_; - std::vector args_; - tinytc::region body_; + tinytc_region body_; std::array work_group_size_; std::int32_t subgroup_size_; tinytc_location loc_; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 671208ea..d30c6dae 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -54,16 +54,15 @@ blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, va op(op_C) = std::move(C); } -loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step0, region body0, +loop_inst::loop_inst(IK tid, value from0, value to0, value step0, tinytc_data_type_t loop_var_type, location const &lc) - : standard_inst{tid, step0 ? 4 : 3} { - op(op_loop_var) = std::move(loop_var0); + : standard_inst{tid, step0 ? 3 : 2} { op(op_from) = std::move(from0); op(op_to) = std::move(to0); op(op_step) = std::move(step0); - child_region(0) = std::move(body0); - + body().add_param(loop_var_type); loc(lc); + auto lvt = get_scalar_type(loc(), loop_var()); auto fromt = get_scalar_type(loc(), from()); auto tot = get_scalar_type(loc(), to()); @@ -73,7 +72,8 @@ loop_inst::loop_inst(IK tid, value loop_var0, value from0, value to0, value step step_ok = lvt->ty() == stept->ty(); } - if (lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || !step_ok) { + if (!is_integer_type(lvt->ty()) || lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || + !step_ok) { throw compilation_error(loc(), status::ir_scalar_mismatch); } } @@ -202,22 +202,9 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc if (auto st = dyn_cast(ty); st) { const auto type_ok = [](value_type const &val, scalar_type ty) { - switch (ty) { - case scalar_type::i1: - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return std::holds_alternative(val); - case scalar_type::f32: - case scalar_type::f64: - return std::holds_alternative(val); - case scalar_type::c32: - case scalar_type::c64: - return std::holds_alternative>(val); - } - return false; + return (is_integer_type(ty) && std::holds_alternative(val)) || + (is_floating_type(ty) && std::holds_alternative(val)) || + (is_complex_type(ty) && std::holds_alternative>(val)); }; if (!type_ok(value_, st->ty())) { throw compilation_error(loc(), status::ir_scalar_mismatch); @@ -426,15 +413,10 @@ ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool } } -foreach_inst::foreach_inst(value loop_var, value from, value to, region body, location const &loc) - : loop_inst{IK::foreach_loop, - std::move(loop_var), - std::move(from), - std::move(to), - {}, - std::move(body), - loc} { - child_region(0)->kind(region_kind::spmd); +foreach_inst::foreach_inst(value from, value to, tinytc_data_type_t loop_var_type, + location const &loc) + : loop_inst{IK::foreach_loop, std::move(from), std::move(to), {}, loop_var_type, loc} { + child_region(0).kind(region_kind::spmd); } hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, @@ -462,23 +444,20 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu } } -if_inst::if_inst(value condition, region then0, region otherwise0, - std::vector const &return_types, location const &lc) - : standard_inst{IK::if_, 1, static_cast(return_types.size()), otherwise0 ? 2 : 1} { +if_inst::if_inst(value condition, std::vector const &return_types, + location const &lc) + : standard_inst{IK::if_, 1, static_cast(return_types.size())} { op(0) = std::move(condition); - child_region(child_region_then) = std::move(then0); - child_region(child_region_otherwise) = std::move(otherwise0); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { result(i) = make_value(return_types[i]); } } -parallel_inst::parallel_inst(region body, location const &lc) : standard_inst{IK::parallel} { - child_region(0) = std::move(body); +parallel_inst::parallel_inst(location const &lc) : standard_inst{IK::parallel} { loc(lc); - child_region(0)->kind(region_kind::spmd); + child_region(0).kind(region_kind::spmd); } size_inst::size_inst(value op0, std::int64_t mode, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index e324b34b..e135721f 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -6,6 +6,7 @@ #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/region_node.hpp" #include "support/ilist.hpp" #include "support/type_list.hpp" #include "support/util.hpp" @@ -85,8 +86,9 @@ using inst_nodes = using value_range = iterator_range_wrapper; using const_value_range = iterator_range_wrapper; -using region_range = iterator_range_wrapper; -using const_region_range = iterator_range_wrapper; + +using region_range = iterator_range_wrapper; +using const_region_range = iterator_range_wrapper; } // namespace tinytc @@ -138,25 +140,23 @@ struct tinytc_inst : tinytc::ilist_node_with_parent inline auto num_results() const -> std::int64_t { return result_end_ - result_begin_; } // Iterator over regions - inline auto child_regions_begin() -> tinytc::region * { return child_regions_begin_; } - inline auto child_regions_end() -> tinytc::region * { return child_regions_end_; } + inline auto child_regions_begin() -> tinytc_region_t { return child_regions_begin_; } + inline auto child_regions_end() -> tinytc_region_t { return child_regions_end_; } inline auto child_regions() -> tinytc::region_range { return tinytc::region_range{child_regions_begin(), child_regions_end()}; } - inline auto child_regions_begin() const -> tinytc::region const * { + inline auto child_regions_begin() const -> const_tinytc_region_t { return child_regions_begin_; } - inline auto child_regions_end() const -> tinytc::region const * { return child_regions_end_; } + inline auto child_regions_end() const -> const_tinytc_region_t { return child_regions_end_; } inline auto child_regions() const -> tinytc::const_region_range { return tinytc::const_region_range{child_regions_begin(), child_regions_end()}; } - inline auto child_region(std::size_t pos) -> tinytc::region & { - return child_regions_begin_[pos]; - } - inline auto child_region(std::size_t pos) const -> tinytc::region const & { + auto child_region(std::size_t pos) -> tinytc_region & { return child_regions_begin_[pos]; } + auto child_region(std::size_t pos) const -> tinytc_region const & { return child_regions_begin_[pos]; } - inline auto num_child_regions() const -> std::int64_t { + auto num_child_regions() const -> std::int64_t { return child_regions_end_ - child_regions_begin_; } @@ -215,7 +215,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent result_begin_ = begin; result_end_ = end; } - inline auto child_regions_range(tinytc::region *begin, tinytc::region *end) { + inline auto child_regions_range(tinytc_region_t begin, tinytc_region_t end) { child_regions_begin_ = begin; child_regions_end_ = end; } @@ -225,7 +225,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent tinytc::location loc_; tinytc::value *op_begin_ = nullptr, *op_end_ = nullptr, *result_begin_ = nullptr, *result_end_ = nullptr; - tinytc::region *child_regions_begin_ = nullptr, *child_regions_end_ = nullptr; + tinytc_region_t child_regions_begin_ = nullptr, child_regions_end_ = nullptr; }; namespace tinytc { @@ -284,7 +284,7 @@ class standard_inst : public inst_node { private: object_container ops_; object_container results_; - object_container child_regions_; + object_container child_regions_; }; class blas_a2_inst : public standard_inst<4, 0> { @@ -326,20 +326,20 @@ class blas_a3_inst : public standard_inst<5, 0> { bool atomic_; }; -class loop_inst : public standard_inst<4, 0, 1> { +class loop_inst : public standard_inst<3, 0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; } - enum op_number { op_loop_var = 0, op_from = 1, op_to = 2, op_step = 3 }; - loop_inst(IK tid, value loop_var, value from, value to, value step, region body, + enum op_number { op_from = 0, op_to = 1, op_step = 2 }; + loop_inst(IK tid, value from, value to, value step, tinytc_data_type_t loop_var_type, location const &loc = {}); - inline auto loop_var() const -> value const & { return op(op_loop_var); } inline auto from() const -> value const & { return op(op_from); } inline auto to() const -> value const & { return op(op_to); } inline auto step() const -> value const & { return op(op_step); } - inline auto body() -> tinytc_region & { return *child_region(0); } - inline auto body() const -> tinytc_region const & { return *child_region(0); } + inline auto body() -> tinytc_region & { return child_region(0); } + inline auto body() const -> tinytc_region const & { return child_region(0); } + inline auto loop_var() const -> value const & { return body().param(0); } }; class alloca_inst : public standard_inst<0, 1> { @@ -553,23 +553,19 @@ class ger_inst : public blas_a3_inst { class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } - inline for_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : for_inst{std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), loc} {} - inline for_inst(value loop_var, value from, value to, value step, region body, + inline for_inst(value from, value to, tinytc_data_type_t loop_var_type, + location const &loc = {}) + : for_inst{std::move(from), std::move(to), {}, loop_var_type, loc} {} + inline for_inst(value from, value to, value step, tinytc_data_type_t loop_var_type, location const &loc = {}) - : loop_inst{IK::for_loop, - std::move(loop_var), - std::move(from), - std::move(to), - std::move(step), - std::move(body), - loc} {} + : loop_inst{IK::for_loop, std::move(from), std::move(to), + std::move(step), loop_var_type, loc} {} }; class foreach_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } - foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}); + foreach_inst(value from, value to, tinytc_data_type_t loop_var_type, location const &loc = {}); }; class hadamard_inst : public blas_a3_inst { @@ -583,16 +579,16 @@ class if_inst : public standard_inst<1, dynamic, 2> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::if_; } enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; - if_inst(value condition, region then, region otherwise = {}, - std::vector const &return_types = {}, location const &lc = {}); + if_inst(value condition, std::vector const &return_types = {}, + location const &lc = {}); inline auto condition() const -> value const & { return op(0); } - inline auto then() -> tinytc_region & { return *child_region(child_region_then); } - inline auto then() const -> tinytc_region const & { return *child_region(child_region_then); } - inline auto has_otherwise() const -> bool { return bool(child_region(child_region_otherwise)); } - inline auto otherwise() -> tinytc_region & { return *child_region(child_region_otherwise); } + inline auto then() -> tinytc_region & { return child_region(child_region_then); } + inline auto then() const -> tinytc_region const & { return child_region(child_region_then); } + inline auto otherwise() -> tinytc_region & { return child_region(child_region_otherwise); } inline auto otherwise() const -> tinytc_region const & { - return *child_region(child_region_otherwise); + return child_region(child_region_otherwise); } + inline bool is_otherwise_empty() const { return otherwise().insts().empty(); } }; class num_subgroups_inst : public standard_inst<0, 1> { @@ -608,10 +604,10 @@ class num_subgroups_inst : public standard_inst<0, 1> { class parallel_inst : public standard_inst<0, 0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } - parallel_inst(region body, location const &lc = {}); + parallel_inst(location const &lc = {}); - inline auto body() -> tinytc_region & { return *child_region(0); } - inline auto body() const -> tinytc_region const & { return *child_region(0); } + inline auto body() -> tinytc_region & { return child_region(0); } + inline auto body() const -> tinytc_region const & { return child_region(0); } }; class size_inst : public standard_inst<1, 1> { diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp index 66b56f61..aa4c94bf 100644 --- a/src/node/program_node.cpp +++ b/src/node/program_node.cpp @@ -15,11 +15,5 @@ tinytc_prog::tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc : ctx_{std::move(ctx)} { loc(lc); } - -tinytc_prog::~tinytc_prog() { - for (auto &f : functions()) { - tinytc_func_destroy(f); - } -} } diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 56944fff..d843c24b 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -11,43 +11,30 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include #include -namespace tinytc { -using func_range = iterator_range_wrapper; -using const_func_range = iterator_range_wrapper; -} // namespace tinytc - struct tinytc_prog final : tinytc::reference_counted { public: + using iterator = tinytc::indirect_iterator::iterator>; + using const_iterator = tinytc::indirect_iterator::const_iterator>; + tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc = {}); - ~tinytc_prog(); inline auto get_context() const -> tinytc_compiler_context_t { return ctx_.get(); } inline auto loc() const noexcept -> tinytc_location const & { return loc_; } inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } - inline auto begin() -> tinytc_func_t * { return funcs_.size() > 0 ? funcs_.data() : nullptr; } - inline auto end() -> tinytc_func_t * { - return funcs_.size() > 0 ? funcs_.data() + funcs_.size() : nullptr; - } - inline auto functions() -> tinytc::func_range { return tinytc::func_range{begin(), end()}; } - inline auto begin() const -> const_tinytc_func_t * { - return funcs_.size() > 0 ? const_cast(funcs_.data()) : nullptr; - } - inline auto end() const -> const_tinytc_func_t * { - return funcs_.size() > 0 ? const_cast(funcs_.data()) + funcs_.size() - : nullptr; - } - inline auto functions() const -> tinytc::const_func_range { - return tinytc::const_func_range{begin(), end()}; - } - inline void push_back(tinytc_func_t fun) { funcs_.push_back(fun); } + inline auto begin() -> iterator { return iterator{funcs_.begin()}; } + inline auto end() -> iterator { return iterator{funcs_.end()}; } + inline auto begin() const -> const_iterator { return const_iterator{funcs_.begin()}; } + inline auto end() const -> const_iterator { return const_iterator{funcs_.end()}; } + inline void push_back(tinytc::func fun) { funcs_.push_back(std::move(fun)); } private: tinytc::compiler_context ctx_; - std::vector funcs_; + std::vector funcs_; tinytc_location loc_; }; diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 5fd92be8..0e844aec 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -2,13 +2,36 @@ // SPDX-License-Identifier: BSD-3-Clause #include "node/region_node.hpp" +#include "node/inst_node.hpp" namespace tinytc { -auto ilist_traits::get_parent_region() -> tinytc_region * { +auto ilist_callbacks::get_parent_region() -> tinytc_region * { return reinterpret_cast(reinterpret_cast(this) - tinytc_region::inst_list_offset()); } +void ilist_callbacks::node_added(tinytc_inst_t node) { + node->parent(get_parent_region()); +} +void ilist_callbacks::node_removed(tinytc_inst_t node) { tinytc_inst_destroy(node); } + } // namespace tinytc +using namespace tinytc; + +tinytc_region::tinytc_region(array_view param_types, location const &lc) + : kind_(region_kind::mixed) { + loc(lc); + + params_.reserve(param_types.size()); + for (auto ¶m_ty : param_types) { + params_.push_back(make_value(param_ty)); + } +} +tinytc_region::~tinytc_region() {} + +void tinytc_region::add_param(tinytc_data_type_t param_type) { + params_.push_back(make_value(param_type)); +} + diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index 279e6ced..d3d895bf 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -4,24 +4,26 @@ #ifndef REGION_NODE_20230908_HPP #define REGION_NODE_20230908_HPP -#include "node/inst_node.hpp" #include "support/ilist.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" #include #include +#include namespace tinytc { //! Instruction classification enum class region_kind { mixed = 0x0, collective = 0x1, spmd = 0x2 }; -template <> struct ilist_traits { +template <> struct ilist_callbacks { auto get_parent_region() -> tinytc_region *; - void node_added(inst_node *node) { node->parent(get_parent_region()); } - void node_removed(inst_node *node) { tinytc_inst_destroy(node); } + void node_added(tinytc_inst_t node); + void node_removed(tinytc_inst_t node); }; } // namespace tinytc @@ -31,9 +33,9 @@ struct tinytc_region final { using iterator = tinytc::ilist::iterator; using const_iterator = tinytc::ilist::const_iterator; - inline tinytc_region(tinytc::location const &lc = {}) : kind_(tinytc::region_kind::mixed) { - loc(lc); - } + tinytc_region(tinytc::array_view param_types = {}, + tinytc::location const &lc = {}); + ~tinytc_region(); inline auto kind() const noexcept -> tinytc::region_kind { return kind_; } inline void kind(tinytc::region_kind kind) noexcept { kind_ = kind; } @@ -49,14 +51,27 @@ struct tinytc_region final { inline auto insts() const -> tinytc::ilist const & { return insts_; } inline auto empty() const -> bool { return insts_.empty(); } + inline auto param_begin() { return params_.begin(); } + inline auto param_end() { return params_.end(); } + inline auto params() { return tinytc::iterator_range_wrapper{param_begin(), param_end()}; } + inline auto param_begin() const { return params_.begin(); } + inline auto param_end() const { return params_.end(); } + inline auto param(std::int64_t pos) const -> tinytc::value const & { return params_[pos]; } + inline auto params() const { + return tinytc::iterator_range_wrapper{param_begin(), param_end()}; + } + inline auto num_params() const noexcept -> std::int64_t { return params_.size(); } + void add_param(tinytc_data_type_t param_type); + private: static auto inst_list_offset() -> std::size_t { static_assert(std::is_standard_layout_v, "offsetof not guaranteed to work"); return offsetof(tinytc_region, insts_); } - friend struct tinytc::ilist_traits; + friend struct tinytc::ilist_callbacks; tinytc::region_kind kind_; + std::vector params_; tinytc::ilist insts_; tinytc::location loc_; }; diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index a52fb739..1f3c6ff8 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -12,10 +12,20 @@ namespace tinytc { +parse_context::parse_context(compiler_context compiler_ctx) : compiler_ctx_(compiler_ctx) {} + void parse_context::push_scope() { id_map_.push_back({}); } void parse_context::pop_scope() { id_map_.pop_back(); } +void parse_context::push_region(tinytc_region_t r) { regions_.push(r); } +void parse_context::pop_region() { regions_.pop(); } +auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } +auto parse_context::has_regions() -> bool { return !regions_.empty(); } + void parse_context::val(std::string const &id, value val, location const &l) { + if (id_map_.empty()) { + throw parser::syntax_error(l, "No active variable scope"); + } for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { if (auto other = it->find(id); other != it->end()) { auto oss = std::ostringstream{}; diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 1a930db3..ecad2fe3 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -7,6 +7,7 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include #include @@ -16,14 +17,10 @@ namespace tinytc { class parse_context { public: - inline parse_context(compiler_context compiler_ctx) : compiler_ctx_(compiler_ctx) { - id_map_.push_back({}); - } + parse_context(compiler_context compiler_ctx); inline auto program() { return program_; } inline void program(prog p) { program_ = std::move(p); } - void push_scope(); - void pop_scope(); void val(std::string const &id, value val, location const &l); value val(std::string const &id, location const &l); @@ -31,9 +28,18 @@ class parse_context { auto cctx() -> compiler_context const & { return compiler_ctx_; } + void push_scope(); + void pop_scope(); + + void push_region(tinytc_region_t r); + void pop_region(); + auto top_region() -> tinytc_region_t; + auto has_regions() -> bool; + private: compiler_context compiler_ctx_; std::vector> id_map_; + std::stack regions_; prog program_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index ad758e1a..172a1bd2 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -6,26 +6,29 @@ %code requires { #include "node/function_node.hpp" + #include "node/inst_node.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include #include #include + #include #include #include + #include namespace tinytc { class parse_context; class lexer; using int_or_val = std::variant; + using unique_ptr_to_if_inst = std::unique_ptr; } } %code { #include "error.hpp" #include "node/data_type_node.hpp" - #include "node/inst_node.hpp" #include "node/program_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" @@ -37,12 +40,11 @@ #include #include #include - #include #include - #include #include namespace tinytc { + void check_scalar_type(compiler_context const &ctx, value &val, scalar_type const &sty, location &loc1, location &loc2) { if (val->ty() != get_scalar(ctx, sty)) { @@ -148,8 +150,8 @@ %nterm prog %nterm > func_list %nterm func -%nterm > arguments -%nterm <::tinytc::value> argument +%nterm ,std::vector>> parameters +%nterm > parameter %nterm >> attributes %nterm > attribute %nterm data_type @@ -163,9 +165,7 @@ %nterm group_type %nterm group_offset %nterm memref_or_group_type -%nterm region %nterm <::tinytc::value> var -%nterm > instructions %nterm instruction %nterm axpby_inst %nterm atomic @@ -186,10 +186,9 @@ %nterm > optional_returned_values %nterm > optional_scalar_type_list %nterm > scalar_type_list -%nterm else_region %nterm sum_inst %nterm yield_inst -%nterm for_loop_var_type +%nterm for_loop_var_type %nterm var_definition %nterm > identifier_list %nterm valued_inst @@ -236,34 +235,52 @@ func_list: | func_list func { $$ = std::move($1); $$.emplace_back(std::move($func)); } func: - FUNC { - ctx.push_scope(); - } GLOBAL_IDENTIFIER LPAREN arguments RPAREN attributes region { + FUNC GLOBAL_IDENTIFIER LPAREN parameters RPAREN attributes { auto loc = @FUNC; loc.end = @RPAREN.end; - auto func_node = std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), - $region.release(), loc) - .release(); - for (auto &attr : $attributes) { - attr(*func_node); + try { + auto func_node = + std::make_unique($GLOBAL_IDENTIFIER, $parameters.second, loc); + for (auto &attr : $attributes) { + attr(*func_node); + } + ctx.push_scope(); + auto name_it = $parameters.first.begin(); + for (auto &p : func_node->params()) { + p.name(*name_it); + ctx.val(*name_it, p, @parameters); + ++name_it; + } + ctx.push_region(&func_node->body()); + $$ = func{func_node.release()}; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; } - $func = func{func_node}; + }[prototype] region { + ctx.pop_region(); ctx.pop_scope(); + $$ = std::move($prototype); } ; -arguments: +parameters: %empty {} - | argument { $$.emplace_back(std::move($argument)); } - | arguments COMMA argument { $$ = std::move($1); $$.emplace_back(std::move($argument)); } + | parameter { + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } + | parameters COMMA parameter { + $$.first = std::move($1.first); + $$.second = std::move($1.second); + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } ; -argument: +parameter: LOCAL_IDENTIFIER COLON data_type { - auto v = make_value(std::move($data_type)); - v.name($LOCAL_IDENTIFIER); - ctx.val($LOCAL_IDENTIFIER, v, @LOCAL_IDENTIFIER); - $$ = std::move(v); + $$ = std::make_pair($LOCAL_IDENTIFIER, $data_type); } ; @@ -373,27 +390,23 @@ memref_or_group_type: | group_type ; -region: - LBRACE { - ctx.push_scope(); - } instructions RBRACE { - $$ = region{std::make_unique(@region).release()}; - for (auto& i : $instructions) { - $$.add_instruction(std::move(i)); - } - ctx.pop_scope(); - } -; - var: LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } ; +region: + LBRACE { ctx.push_scope(); } instructions { ctx.pop_scope(); } RBRACE {} +; + instructions: %empty {} | instructions instruction { - $$ = std::move($1); - $$.emplace_back(std::move($instruction)); + if (!ctx.has_regions()) { + error(@instruction, "Internal error: missing region"); + YYERROR; + } + tinytc_region_t reg = ctx.top_region(); + reg->insts().push_back($instruction.release()); } ; @@ -553,28 +566,30 @@ ger_inst: ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] - EQUALS var[from] COMMA var[to] optional_step - for_loop_var_type { - check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); + FOR LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] optional_step for_loop_var_type { + check_type($from, $for_loop_var_type, @from, @for_loop_var_type); + check_type($to, $for_loop_var_type, @to, @for_loop_var_type); if ($optional_step) { - check_scalar_type(ctx.cctx(), $optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); + check_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); } - auto v = make_value(get_scalar(ctx.cctx(), $for_loop_var_type)); - v.name($loop_var); - ctx.val($loop_var, std::move(v), @loop_var); - } region { try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - $optional_step, std::move($region), @for_inst) - .release() - }; + location loc = @FOR; + loc.end = @for_loop_var_type.end; + auto inode = std::make_unique($from, $to, $optional_step, $for_loop_var_type, loc); + ctx.push_scope(); + auto loop_var = inode->loop_var(); + loop_var->name($loop_var); + ctx.val($loop_var, std::move(loop_var), @loop_var); + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; } + }[loop_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($loop_header); } ; @@ -583,30 +598,34 @@ optional_step: | COMMA var { $$ = $var; } foreach_inst: - FOREACH LOCAL_IDENTIFIER[loop_var] - EQUALS var[from] COMMA var[to] for_loop_var_type { - check_scalar_type(ctx.cctx(), $from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type(ctx.cctx(), $to, $for_loop_var_type, @to, @for_loop_var_type); - auto v = make_value(get_scalar(ctx.cctx(), $for_loop_var_type)); - v.name($loop_var); - ctx.val($loop_var, std::move(v), @loop_var); - } region { + FOREACH LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] for_loop_var_type { + check_type($from, $for_loop_var_type, @from, @for_loop_var_type); + check_type($to, $for_loop_var_type, @to, @for_loop_var_type); try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - std::move($region), @foreach_inst) - .release() - }; + location loc = @FOREACH; + loc.end = @for_loop_var_type.end; + auto inode = + std::make_unique($from, $to, $for_loop_var_type, loc); + ctx.push_scope(); + auto loop_var = inode->loop_var(); + loop_var->name($loop_var); + ctx.val($loop_var, std::move(loop_var), @loop_var); + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; } + }[loop_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($loop_header); } ; for_loop_var_type: - %empty { $$ = scalar_type::index; } - | COLON INTEGER_TYPE { $$ = $INTEGER_TYPE; } + %empty { $$ = get_scalar(ctx.cctx(), scalar_type::index); } + | COLON INTEGER_TYPE { $$ = get_scalar(ctx.cctx(), $INTEGER_TYPE); } ; var_definition: @@ -943,19 +962,31 @@ group_size_inst: ; if_inst: - IF var[condition] optional_returned_values region else_region { + IF var[condition] optional_returned_values { check_scalar_type(ctx.cctx(), $condition, scalar_type::i1, @condition, @condition); - $$ = inst{std::make_unique(std::move($condition), std::move($region), - std::move($else_region), - std::move($optional_returned_values)) - .release()}; - $$->loc(@if_inst); + try { + auto loc = @IF; + loc.end = @optional_returned_values.end; + auto inode = std::make_unique(std::move($condition), + std::move($optional_returned_values), loc); + ctx.push_region(&inode->then()); + $$ = std::move(inode); + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + }[header] region { + ctx.pop_region(); + ctx.push_region(&$header->otherwise()); + } else_region { + ctx.pop_region(); + $$ = inst{$header.release()}; } ; else_region: - %empty { $$ = {}; } - | ELSE region{ $$ = std::move($region); } + %empty {} + | ELSE region {} ; optional_returned_values: @@ -980,8 +1011,18 @@ num_subgroups_inst: ; parallel_inst: - PARALLEL region { - $$ = inst{std::make_unique(std::move($region), @parallel_inst) .release()}; + PARALLEL { + try { + auto inode = std::make_unique(@PARALLEL); + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + }[header] region { + ctx.pop_region(); + $$ = std::move($header); } ; diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index c310ad07..873d805d 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -14,7 +14,7 @@ namespace tinytc { void check_ir_pass::run_on_function(function_node &fn) { walk(fn, [this](inst_node const &i, walk_stage const &stage) { const bool child_region_is_spmd_region = - i.num_child_regions() > 0 && i.child_region(0)->kind() == region_kind::spmd; + i.num_child_regions() > 0 && i.child_region(0).kind() == region_kind::spmd; if (stage.is_before_all_regions()) { if (i.kind() == inst_execution_kind::collective && inside_spmd_region_) { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index b8c48f94..a674e8a7 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -842,7 +842,7 @@ std::vector convert_to_opencl_pass::operator()(if_inst const &in) { } auto ib = clir::if_selection_builder(val(*in.condition())); ib.set_then(run_on_region(in.then())); - if (in.has_otherwise()) { + if (!in.is_otherwise_empty()) { ib.set_otherwise(run_on_region(in.otherwise())); } yielded_vars_.pop_back(); @@ -1093,7 +1093,7 @@ auto convert_to_opencl_pass::run_on_function(function_node const &fn) -> clir::f // Create prototype auto fb = clir::kernel_builder(std::string(fn.name())); - for (auto const &v : fn.args()) { + for (auto const &v : fn.params()) { fb.argument(visit(*this, *v->ty()), declare(*v)); auto dv = visit( overloaded{[&fb, &v](memref_data_type const &) -> std::optional { @@ -1136,13 +1136,13 @@ auto convert_to_opencl_pass::run_on_function(function_node const &fn) -> clir::f /* Program nodes */ auto convert_to_opencl_pass::run_on_program(program_node const &p) -> clir::prog { reserved_names_.clear(); - for (auto const &fn : p.functions()) { - reserved_names_.insert(std::string(fn->name())); + for (auto const &fn : p) { + reserved_names_.insert(std::string(fn.name())); } prog_builder_ = clir::program_builder{}; - for (auto const &fn : p.functions()) { - prog_builder_.add(run_on_function(*fn)); + for (auto const &fn : p) { + prog_builder_.add(run_on_function(fn)); } return prog_builder_.get_product(); } diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp index 59593876..ad858f74 100644 --- a/src/pass/dump_cfg.cpp +++ b/src/pass/dump_cfg.cpp @@ -16,7 +16,7 @@ namespace tinytc { dump_cfg_pass::dump_cfg_pass(std::ostream &os) : os_(&os) {} -void dump_cfg_pass::run_on_function(function_node const &fn) { +void dump_cfg_pass::run_on_function(function_node &fn) { auto dump_ir = dump_ir_pass(*os_, 0); *os_ << "digraph " << fn.name() << " {" << std::endl; diff --git a/src/pass/dump_cfg.hpp b/src/pass/dump_cfg.hpp index 08cfeea3..65d765d8 100644 --- a/src/pass/dump_cfg.hpp +++ b/src/pass/dump_cfg.hpp @@ -14,7 +14,7 @@ class dump_cfg_pass { public: dump_cfg_pass(std::ostream &os); - void run_on_function(function_node const &fn); + void run_on_function(function_node &fn); private: std::ostream *os_; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 5e15c703..6b17c5a3 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -297,7 +297,7 @@ void dump_ir_pass::operator()(if_inst const &in) { dump_val(*in.condition()); *os_ << " "; dump_region(in.then()); - if (in.has_otherwise()) { + if (!in.is_otherwise_empty()) { *os_ << " else "; dump_region(in.otherwise()); } @@ -423,7 +423,7 @@ void dump_ir_pass::run_on_function(function_node const &fn) { std::string infix = ",\n "; infix += std::string(fn.name().size(), ' '); do_with_infix( - fn.args().begin(), fn.args().end(), + fn.params().begin(), fn.params().end(), [this](auto const &a) { dump_val(*a); *os_ << ": "; diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 6953ca0c..596e225b 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -34,7 +34,7 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const while (prev_it != reg.begin()) { auto &i = *(--prev_it); for (auto &subreg : i.child_regions()) { - rgn_ops.merge(run_on_region(*subreg, aa)); + rgn_ops.merge(run_on_region(subreg, aa)); } for (auto &v : i.operands()) { if (isa(*v->ty())) { diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index 749ded13..5356799a 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -18,7 +18,7 @@ void slot_tracker::set_slot(value_node const &v) { void slot_tracker::run_on_function(function_node &fn) { slot_ = 0; - for (auto const &arg : fn.args()) { + for (auto const &arg : fn.params()) { set_slot(*arg); } walk(fn, [this](inst_node const &i) { diff --git a/src/passes.hpp b/src/passes.hpp index 3f3cfd7e..8a8223f1 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -10,14 +10,14 @@ namespace tinytc { template void run_function_pass(FunctionPass &&pass, tinytc_prog &p) { - for (auto &fun : p.functions()) { - pass.run_on_function(*fun); + for (auto &fun : p) { + pass.run_on_function(fun); } } template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { - for (auto const &fun : p.functions()) { - pass.run_on_function(*fun); + for (auto const &fun : p) { + pass.run_on_function(fun); } } diff --git a/src/prog.cpp b/src/prog.cpp index d0ad3815..9b9d3e6c 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -25,8 +25,8 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { +tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } @@ -40,7 +40,7 @@ tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { if (prg == nullptr || fun == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { prg->push_back(fun); }); + return exception_to_status_code([&] { prg->push_back(tinytc::func{fun}); }); } tinytc_status_t tinytc_prog_release(tinytc_prog_t obj) { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 8052cf5c..4f96399b 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -85,52 +85,53 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(ty_, "alpha"); - auto A = fb.argument(get_memref(ctx_, enum_cast(ty), - {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, address_space::global, my_loc()), - "A", my_loc()); - auto B = fb.argument(get_memref(ctx_, enum_cast(ty), - {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, address_space::global, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(ty_, "beta"); - auto C = fb.argument(get_memref(ctx_, enum_cast(ty), {M, N, dynamic}, - {1, ldC, strideC}, address_space::global, my_loc()), - "C", my_loc()); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto A_ty = + get_memref(ctx_, enum_cast(ty), {selA(M, K), selA(K, M), dynamic}, + {1, ldA, strideA}, address_space::global, my_loc()); + auto B_ty = + get_memref(ctx_, enum_cast(ty), {selB(K, N), selB(N, K), dynamic}, + {1, ldB, strideB}, address_space::global, my_loc()); + auto C_ty = get_memref(ctx_, enum_cast(ty), {M, N, dynamic}, + {1, ldC, strideC}, address_space::global, my_loc()); + auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); + auto fn_body = f.get_body(); + auto alpha = get_parameter(fn_body, 0); + set_name(alpha, "alpha"); + auto A = get_parameter(fn_body, 1); + set_name(A, "A"); + auto B = get_parameter(fn_body, 2); + set_name(B, "B"); + auto beta = get_parameter(fn_body, 3); + set_name(beta, "beta"); + auto C = get_parameter(fn_body, 4); + set_name(C, "C"); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto const static_offsets = std::vector{0, 0}; - auto const A_static_sizes = std::vector{M, K}; - auto const B_static_sizes = std::vector{K, N}; - auto const C_static_sizes = std::vector{M, N}; - auto a = bb.add( - make_subview(A, static_offsets, A_static_sizes, {}, {}, my_loc())); - auto b = bb.add( - make_subview(B, static_offsets, B_static_sizes, {}, {}, my_loc())); - auto c = bb.add( - make_subview(C, static_offsets, C_static_sizes, {}, {}, my_loc())); - auto beta = is_beta_nonzero ? std::move(beta_arg) - : bb.add(make_constant(0.0, ty_, my_loc())); - bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta, - std::move(c), my_loc())); - }, - my_loc()); - }; - auto p = [&] { - auto pb = program_builder{ctx_, my_loc()}; - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); + auto bb = region_builder{fn_body}; + + auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto const static_offsets = std::vector{0, 0}; + auto const A_static_sizes = std::vector{M, K}; + auto const B_static_sizes = std::vector{K, N}; + auto const C_static_sizes = std::vector{M, N}; + auto a = bb.add( + make_subview(value{A, true}, static_offsets, A_static_sizes, {}, {}, my_loc())); + auto b = bb.add( + make_subview(value{B, true}, static_offsets, B_static_sizes, {}, {}, my_loc())); + auto c = bb.add( + make_subview(value{C, true}, static_offsets, C_static_sizes, {}, {}, my_loc())); + auto beta_val = + is_beta_nonzero ? value{beta, true} : bb.add(make_constant(0.0, ty_, my_loc())); + bb.add(make_gemm(tA_, tB_, false, value{alpha, true}, std::move(a), std::move(b), + beta_val, std::move(c), my_loc())); - return std::move(pb).get_product(); - }(); + return f; + }; + auto p = make_prog(ctx_, my_loc()); + p.add_function( + kernel(small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), true)); + p.add_function(kernel( + small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), false)); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); *recipe = std::make_unique(std::move(p), source(src), diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index a78a5fd5..1b6dc732 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -104,8 +104,9 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, - bool is_beta_nonzero, value &beta_arg, value &C) { + auto const body = [&](region_builder &bb, value const &alpha, value const &A, + value const &B, bool is_beta_nonzero, value const &beta_arg, + value const &C) { auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); @@ -150,39 +151,39 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( } }; - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(ty_, "alpha", my_loc()); - auto A = fb.argument(get_memref(ctx_, enum_cast(ty), {M, K}, {1, ldA}, - address_space::global, my_loc()), - "A", my_loc()); - auto B = fb.argument(get_memref(ctx_, enum_cast(ty), {K, N}, {1, ldB}, - address_space::global, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(ty_, "beta", my_loc()); - auto C = fb.argument(get_memref(ctx_, enum_cast(ty), {M, N}, {1, ldC}, - address_space::global, my_loc()), - "C", my_loc()); - fb.subgroup_size(sgs); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto A_ty = get_memref(ctx_, enum_cast(ty), {M, K}, {1, ldA}, + address_space::global, my_loc()); + auto B_ty = get_memref(ctx_, enum_cast(ty), {K, N}, {1, ldB}, + address_space::global, my_loc()); + auto C_ty = get_memref(ctx_, enum_cast(ty), {M, N}, {1, ldC}, + address_space::global, my_loc()); + auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); + auto fn_body = f.get_body(); + auto alpha = get_parameter(fn_body, 0); + set_name(alpha, "alpha"); + auto A = get_parameter(fn_body, 1); + set_name(A, "A"); + auto B = get_parameter(fn_body, 2); + set_name(B, "B"); + auto beta = get_parameter(fn_body, 3); + set_name(beta, "beta"); + auto C = get_parameter(fn_body, 4); + set_name(C, "C"); + f.set_subgroup_size(sgs); auto const wgs = tiling.work_group_size(sgs); - fb.work_group_size(wgs[0], wgs[1]); + f.set_work_group_size(wgs[0], wgs[1]); - fb.body( - [&](region_builder &bb) { - body(bb, alpha, A, B, is_beta_nonzero, beta_arg, C); - }, - my_loc()); + auto bb = region_builder{fn_body}; + body(bb, value{alpha, true}, value{A, true}, value{B, true}, is_beta_nonzero, + value{beta, true}, value{C, true}); + return f; }; - auto p = [&] { - auto pb = program_builder{ctx_, my_loc()}; - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - return std::move(pb).get_product(); - }(); + auto p = make_prog(ctx_, my_loc()); + p.add_function(kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), true)); + p.add_function( + kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), false)); tinytc_source_t src; CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); *recipe = std::make_unique(std::move(p), source(src), diff --git a/src/region.cpp b/src/region.cpp index b8e22f16..69d5a6ce 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -15,20 +15,42 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_region_create(tinytc_region_t *reg, const tinytc_location_t *loc) { - if (reg == nullptr) { +tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction) { + if (reg == nullptr || instruction == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *reg = std::make_unique(get_optional(loc)).release(); }); + return exception_to_status_code([&] { reg->insts().push_back(instruction); }); } -tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction) { - if (reg == nullptr || instruction == nullptr) { +tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, uint32_t param_no, + tinytc_value_t *result) { + if (reg == nullptr || result == nullptr || param_no >= reg->num_params()) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { reg->insts().push_back(instruction); }); + return exception_to_status_code([&] { *result = reg->param(param_no).get(); }); } -void tinytc_region_destroy(tinytc_region_t obj) { delete obj; } +tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, uint32_t *result_list_size, + tinytc_value_t *result_list) { + + if (reg == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = reg->num_params(); + if (num_results > std::numeric_limits::max()) { + throw std::out_of_range("too many results"); + } + auto const num = static_cast(num_results); + if (*result_list_size > 0) { + auto results = reg->param_begin(); + auto const limit = std::min(num, *result_list_size); + for (uint32_t i = 0; i < limit; ++i) { + result_list[i] = results[i].get(); + } + } + *result_list_size = num; + }); +} } diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index c5407de5..ac879872 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -32,6 +32,21 @@ bool is_complex_type(scalar_type ty) { return false; } +bool is_integer_type(scalar_type ty) { + switch (ty) { + case scalar_type::i1: + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return true; + default: + break; + } + return false; +} + scalar_type element_type(scalar_type ty) { switch (ty) { case scalar_type::c32: diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index 2b1c6d5a..7380a957 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -13,6 +13,7 @@ namespace tinytc { bool is_floating_type(scalar_type ty); bool is_complex_type(scalar_type ty); +bool is_integer_type(scalar_type ty); scalar_type element_type(scalar_type ty); clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); diff --git a/src/support/ilist.hpp b/src/support/ilist.hpp index 3406904f..a953c92a 100644 --- a/src/support/ilist.hpp +++ b/src/support/ilist.hpp @@ -8,9 +8,12 @@ namespace tinytc { -template struct ilist_traits; +template struct ilist_callbacks { + void node_added(NodeT *) {} + void node_removed(NodeT *) {} +}; -template > +template > class ilist : public ilist_base { public: ilist() = default; diff --git a/src/support/ilist_base.hpp b/src/support/ilist_base.hpp index b2d9cc4b..675c958b 100644 --- a/src/support/ilist_base.hpp +++ b/src/support/ilist_base.hpp @@ -15,18 +15,20 @@ template class ilist_iterator; template class ilist_node { public: - auto prev() const -> NodeT * { return prev_; } - void prev(NodeT *prev) { prev_ = prev; } - auto next() const -> NodeT * { return next_; } - void next(NodeT *next) { next_ = next; } + using node_type = NodeT; + + auto prev() const -> ilist_node * { return prev_; } + void prev(ilist_node *prev) { prev_ = prev; } + auto next() const -> ilist_node * { return next_; } + void next(ilist_node *next) { next_ = next; } auto sentinel() const -> bool { return sentinel_; } void set_sentinel() { sentinel_ = true; } - auto iterator() -> ilist_iterator { return {this}; } + auto iterator() -> ilist_iterator, false> { return {this}; } private: - NodeT *prev_ = nullptr, *next_ = nullptr; + ilist_node *prev_ = nullptr, *next_ = nullptr; bool sentinel_ = false; }; @@ -40,11 +42,12 @@ class ilist_node_with_parent : public ilist_node { ParentT *parent_ = nullptr; }; -template class ilist_iterator { +template class ilist_iterator { public: - using base_type = std::conditional_t, ilist_node>; + using base_type = std::conditional_t; using base_pointer = base_type *; - using value_type = std::conditional_t; + using node_type = typename IListNodeT::node_type; + using value_type = std::conditional_t; using pointer = value_type *; using reference = value_type &; using difference_type = std::ptrdiff_t; @@ -55,22 +58,23 @@ template class ilist_iterator { auto operator*() const -> reference { return *static_cast(pos_); } auto operator->() const -> pointer { return get(); } auto get() const -> pointer { return static_cast(pos_); } + auto get_base() const -> base_pointer { return pos_; } auto &operator++() { - pos_ = static_cast(pos_->next()); + pos_ = pos_->next(); return *this; } auto operator++(int) { auto old_pos = pos_; - pos_ = static_cast(pos_->next()); + pos_ = pos_->next(); return ilist_iterator{old_pos}; } auto &operator--() { - pos_ = static_cast(pos_->prev()); + pos_ = pos_->prev(); return *this; } auto operator--(int) { auto old_pos = pos_; - pos_ = static_cast(pos_->prev()); + pos_ = pos_->prev(); return ilist_iterator{old_pos}; } auto operator==(ilist_iterator const &other) const -> bool { return pos_ == other.pos_; } @@ -80,12 +84,7 @@ template class ilist_iterator { base_pointer pos_; }; -template struct ilist_dummy_callback { - void node_added(NodeT *) {} - void node_removed(NodeT *) {} -}; - -template > +template requires requires(IListCallback &cb, NodeT *node) { std::is_base_of_v, NodeT>; cb.node_added(node); @@ -93,21 +92,23 @@ requires requires(IListCallback &cb, NodeT *node) { } class ilist_base : protected IListCallback { public: + using base_type = ilist_node; + using base_pointer = base_type *; using value_type = NodeT; using size_type = std::size_t; using difference_type = std::ptrdiff_t; using pointer = value_type *; using reference = value_type &; using const_reference = const value_type &; - using iterator = ilist_iterator; - using const_iterator = ilist_iterator; + using iterator = ilist_iterator; + using const_iterator = ilist_iterator; static_assert(std::bidirectional_iterator); ilist_base() { sentinel_.set_sentinel(); // let's go in a circle - yay! - sentinel_.prev(static_cast(&sentinel_)); - sentinel_.next(static_cast(&sentinel_)); + sentinel_.prev(&sentinel_); + sentinel_.next(&sentinel_); } ~ilist_base() { clear(); } @@ -139,7 +140,7 @@ class ilist_base : protected IListCallback { // let s = sentinel // |0|: s{prev->s,next->s} // |1|: n0{prev->s,next->s}, s{prev->n0,next->n0} - pointer prev = it->prev(); + base_pointer prev = it.get_base()->prev(); prev->next(node); node->prev(prev); node->next(it.get()); @@ -169,17 +170,17 @@ class ilist_base : protected IListCallback { // |0|: s{prev->s,next->s} // |1|: n0{prev->s,next->s}, s{prev->n0,next->n0} // |2|: n0{prev->s,next->n1}, n1{prev->n0,next->s}, s{prev->n1,next->n0} - pointer prev = it->prev(); - pointer next = it->prev(); + base_pointer prev = it.get_base()->prev(); + base_pointer next = it.get_base()->prev(); prev->prev(next); next->prev(prev); - it->prev(nullptr); - it->next(nullptr); + it.get_base()->prev(nullptr); + it.get_base()->next(nullptr); // |0| (it -> s) : s{prev->s,next->s} // |1| (it -> n0): s{prev->s,next->s} // |2| (it -> n0): n1{prev->s,next->s}, s{prev->n1,next->n1} // |2| (it -> n1): n0{prev->s,next->s}, s{prev->n0,next->n0} - this->node_removed(it.get()); + // this->node_removed(&*it); return iterator{next}; } auto erase(iterator begin, iterator end) -> iterator { @@ -190,7 +191,7 @@ class ilist_base : protected IListCallback { } private: - ilist_node sentinel_; + base_type sentinel_; }; } // namespace tinytc diff --git a/src/support/util.hpp b/src/support/util.hpp index f6c91bba..e4414d18 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -15,16 +15,6 @@ template auto enum_cast(V val) { return T{std::underlying_type_t(val)}; } -template class iterator_range_wrapper { - public: - iterator_range_wrapper(ItT begin, ItT end) : begin_(std::move(begin)), end_(std::move(end)) {} - ItT begin() const { return begin_; } - ItT end() const { return end_; } - - private: - ItT begin_, end_; -}; - constexpr auto fnv1a0() -> std::uint64_t { return 0xcbf29ce484222325; } constexpr auto fnv1a_step(std::uint64_t hash, char ch) -> std::uint64_t { return (hash ^ ch) * 0x00000100000001b3; @@ -48,6 +38,91 @@ constexpr auto fnv1a(Head &&head, Tail &&...tail) -> std::uint64_t { return fnv1a_step(fnv1a_step(fnv1a0(), std::forward(tail)...), std::forward(head)); } +template class iterator_range_wrapper { + public: + iterator_range_wrapper(ItT begin, ItT end) : begin_(std::move(begin)), end_(std::move(end)) {} + ItT begin() const { return begin_; } + ItT end() const { return end_; } + + private: + ItT begin_, end_; +}; + +template class indirect_iterator : public IteratorT { + public: + using value_type = std::decay_t()))>; + using pointer = value_type *; + using reference = value_type &; + + auto operator*() const -> reference { return *(this->IteratorT::operator*()); } + auto operator->() const -> pointer { return &*(this->IteratorT::operator*()); } + auto operator[](std::size_t n) const -> reference { return *(this->IteratorT::operator[](n)); } +}; + +template class pointer_iterator { + public: + using value_type = T; + using pointer = value_type *; + using reference = value_type &; + using difference_type = std::ptrdiff_t; + + pointer_iterator() : ptr_{nullptr} {} + pointer_iterator(pointer ptr) : ptr_{std::move(ptr)} {} + + auto operator*() const -> reference { return *ptr_; } + auto operator->() const -> pointer { return ptr_; } + auto operator[](std::size_t n) const -> reference { return ptr_[n]; } + auto operator++() -> pointer_iterator & { + ++ptr_; + return *this; + } + auto operator++(int) -> pointer_iterator { + auto tmp = ptr_++; + return pointer_iterator{tmp}; + } + auto operator--() -> pointer_iterator & { + --ptr_; + return *this; + } + auto operator--(int) -> pointer_iterator { + auto tmp = ptr_--; + return pointer_iterator{tmp}; + } + auto operator-(pointer_iterator const &other) const -> difference_type { + return other.ptr_ - ptr_; + } + auto operator+=(std::ptrdiff_t n) -> pointer_iterator & { + ptr_ += n; + return *this; + } + auto operator-=(std::ptrdiff_t n) -> pointer_iterator & { + ptr_ -= n; + return *this; + } + auto operator==(pointer_iterator const &other) const -> bool { return ptr_ == other.ptr_; } + auto operator<=>(pointer_iterator const &other) const -> bool { return ptr_ <=> other.ptr_; } + + private: + pointer ptr_; +}; + +template +auto operator+(pointer_iterator const &p, std::ptrdiff_t n) -> pointer_iterator { + auto q = pointer_iterator{p}; + return q += n; +} + +template +auto operator+(std::ptrdiff_t n, pointer_iterator const &p) -> pointer_iterator { + return p + n; +} + +template +auto operator-(pointer_iterator const &p, std::ptrdiff_t n) -> pointer_iterator { + auto q = pointer_iterator{p}; + return q -= n; +} + } // namespace tinytc #endif // UTIL_20240201_HPP diff --git a/src/support/walk.cpp b/src/support/walk.cpp index a1a1762e..a1709627 100644 --- a/src/support/walk.cpp +++ b/src/support/walk.cpp @@ -14,7 +14,7 @@ void walk(inst_node &i, std::function void walk(inst_node &i, std::function(j, callback); } } @@ -48,13 +48,13 @@ template void walk(inst_node &i, std::function callback) { for (auto ® : i.child_regions()) { if constexpr (Order == walk_order::pre_order) { - callback(*reg); + callback(reg); } - for (auto &j : *reg) { + for (auto &j : reg) { walk(j, callback); } if constexpr (Order == walk_order::post_order) { - callback(*reg); + callback(reg); } } } diff --git a/src/value.cpp b/src/value.cpp index 251bfca9..a6c964e2 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -52,6 +52,13 @@ tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name) { return exception_to_status_code([&] { vl->name(std::string(name)); }); } +tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, uint32_t name_length, char const *name) { + if (vl == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { vl->name(std::string(name, name_length)); }); +} + tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name) { if (vl == nullptr || name == nullptr) { return tinytc_status_invalid_arguments; diff --git a/test/codegen/if.ir b/test/codegen/if.ir index 1ba73c9d..4accd465 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -15,7 +15,6 @@ func @if0(%0: i32) { ; CHECK: bool x2 = x0 >= c0; ; CHECK: bool x3 = x1 && x2; ; CHECK: if (x3) { -; CHECK-NEXT: } else { ; CHECK-NEXT: } } @@ -26,7 +25,6 @@ func @if1(%0: i32) { } else { } ; CHECK: if (x1) { -; CHECK-NEXT: } else { ; CHECK-NEXT: } } diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index ec252c5d..8247c30b 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -8,6 +8,6 @@ func @illegal_nesting() { foreach %i=%lb,%ub { foreach %j=%lb,%ub { } -; CHECK: 9.9-10.9: Collective instruction must not be called from SPMD region +; CHECK: 9.9-26: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index 4c896771..ab1558c4 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -8,6 +8,6 @@ func @illegal_nesting() { parallel { foreach %j=%lb,%ub { } -; CHECK: 9.9-10.9: Collective instruction must not be called from SPMD region +; CHECK: 9.9-26: Collective instruction must not be called from SPMD region } } From 6e3cd7b6446c07b468f174803e032a9c6442e1a3 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 30 Sep 2024 17:37:08 +0200 Subject: [PATCH 032/297] Change of ownership: value shared -> value owned by instruction that returns it Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 21 --- docs/api/builder_capi.yaml | 3 - docs/api/builder_cxxapi.rst | 35 +--- docs/api/builder_cxxapi.yaml | 7 +- include/tinytc/tinytc.h | 58 ++---- include/tinytc/tinytc.hpp | 287 +++++++++++------------------- src/CMakeLists.txt | 1 + src/analysis/alias.cpp | 14 +- src/codegen_tools.cpp | 42 +++-- src/codegen_tools.hpp | 21 +-- src/func.cpp | 2 +- src/inst.cpp | 167 ++++++----------- src/node/inst_node.cpp | 128 +++++++------ src/node/inst_node.hpp | 234 ++++++++++++------------ src/node/region_node.cpp | 15 +- src/node/region_node.hpp | 8 +- src/node/value_node.cpp | 7 + src/node/value_node.hpp | 8 +- src/parser/parse_context.cpp | 8 +- src/parser/parse_context.hpp | 6 +- src/parser/parser_impl.yy | 42 ++--- src/pass/convert_to_opencl.cpp | 259 ++++++++++++++------------- src/pass/convert_to_opencl.hpp | 4 +- src/pass/dump_ir.cpp | 144 +++++++-------- src/pass/insert_barrier.cpp | 12 +- src/pass/insert_lifetime_stop.cpp | 10 +- src/pass/slot_tracker.cpp | 4 +- src/pass/stack.cpp | 6 +- src/pass/work_group_size.cpp | 6 +- src/recipe/small_gemm_batched.cpp | 27 ++- src/recipe/tall_and_skinny.cpp | 15 +- src/region.cpp | 4 +- src/value.cpp | 28 --- 33 files changed, 713 insertions(+), 920 deletions(-) create mode 100644 src/node/value_node.cpp diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index c2375961..b30c770b 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -666,26 +666,15 @@ Value * Functions - * :ref:`tinytc_value_create` - * :ref:`tinytc_value_get_name` * :ref:`tinytc_value_set_name` * :ref:`tinytc_value_set_name_n` - * :ref:`tinytc_value_release` - - * :ref:`tinytc_value_retain` - Value Functions --------------- -tinytc_value_create -................... - -.. doxygenfunction:: tinytc_value_create - tinytc_value_get_name ..................... @@ -701,13 +690,3 @@ tinytc_value_set_name_n .. doxygenfunction:: tinytc_value_set_name_n -tinytc_value_release -.................... - -.. doxygenfunction:: tinytc_value_release - -tinytc_value_retain -................... - -.. doxygenfunction:: tinytc_value_retain - diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 9ca7ec40..3b7218ee 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -103,9 +103,6 @@ Builder C-API: - tinytc_region_get_parameters Value: function: - - tinytc_value_create - tinytc_value_get_name - tinytc_value_set_name - tinytc_value_set_name_n - - tinytc_value_release - - tinytc_value_retain diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index e7e72d3e..d494f17a 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -248,9 +248,9 @@ Instruction * :ref:`make_axpby` - * :ref:`make_arith(arithmetic,value const&,value const&,location const&)` + * :ref:`make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&)` - * :ref:`make_arith(arithmetic_unary,value const&,location const&)` + * :ref:`make_arith(arithmetic_unary,tinytc_value_t,location const&)` * :ref:`make_cast` @@ -325,15 +325,15 @@ make_axpby .. doxygenfunction:: tinytc::make_axpby -make_arith(arithmetic,value const&,value const&,location const&) -................................................................ +make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) +.................................................................... -.. doxygenfunction:: tinytc::make_arith(arithmetic,value const&,value const&,location const&) +.. doxygenfunction:: tinytc::make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) -make_arith(arithmetic_unary,value const&,location const&) -......................................................... +make_arith(arithmetic_unary,tinytc_value_t,location const&) +........................................................... -.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value const&,location const&) +.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,tinytc_value_t,location const&) make_cast ......... @@ -565,14 +565,8 @@ Value * :ref:`get_name` - * :ref:`make_value` - * :ref:`set_name` -* Classes - - * :ref:`value` - Value Functions --------------- @@ -581,21 +575,8 @@ get_name .. doxygenfunction:: tinytc::get_name -make_value -.......... - -.. doxygenfunction:: tinytc::make_value - set_name ........ .. doxygenfunction:: tinytc::set_name -Value Classes -------------- - -value -..... - -.. doxygenclass:: tinytc::value - diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 3f5bd53e..27123482 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -43,8 +43,8 @@ Builder C++-API: function: - tinytc::make_alloca - tinytc::make_axpby - - tinytc::make_arith(arithmetic,value const&,value const&,location const&) - - tinytc::make_arith(arithmetic_unary,value const&,location const&) + - tinytc::make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) + - tinytc::make_arith(arithmetic_unary,tinytc_value_t,location const&) - tinytc::make_cast - tinytc::make_cmp - tinytc::make_constant(std::complex,tinytc_data_type_t,location const&) @@ -91,7 +91,4 @@ Builder C++-API: Value: function: - tinytc::get_name - - tinytc::make_value - tinytc::set_name - class: - - tinytc::value diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 016ddd5e..027c9648 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -104,38 +104,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, /////////// Value ////////// //////////////////////////// -/** - * @brief Create value - * - * @param vl [out] pointer to the value object created - * @param type [in] data type object - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, - const tinytc_location_t *loc); - -/** - * @brief Release value object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param vl [inout] value object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_release(tinytc_value_t vl); - -/** - * @brief Increase reference count of value object by 1 - * - * @param vl [inout] value object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_retain(tinytc_value_t vl); - /** * @brief Set name of value * @@ -355,8 +323,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tin */ TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create( tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, - uint32_t static_expand_shape_size, int64_t *static_expand_shape, uint32_t expand_shape_size, - tinytc_value_t *expand_shape, const tinytc_location_t *loc); + uint32_t static_expand_shape_size, const int64_t *static_expand_shape, + uint32_t expand_shape_size, const tinytc_value_t *expand_shape, const tinytc_location_t *loc); /** * @brief Create fuse instruction @@ -391,7 +359,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tiny */ TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t index_list_size, - tinytc_value_t *index_list, + const tinytc_value_t *index_list, const tinytc_location_t *loc); /** * @brief Create group_id instruction @@ -639,9 +607,10 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *in * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( - tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, int64_t *static_offset_list, - int64_t *static_size_list, uint32_t offset_list_size, tinytc_value_t *offset_list, - uint32_t size_list_size, tinytc_value_t *size_list, const tinytc_location_t *loc); + tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, + const int64_t *static_offset_list, const int64_t *static_size_list, uint32_t offset_list_size, + const tinytc_value_t *offset_list, uint32_t size_list_size, const tinytc_value_t *size_list, + const tinytc_location_t *loc); /** * @brief Create store instruction @@ -660,7 +629,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( */ TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t a, uint32_t index_list_size, - tinytc_value_t *index_list, + const tinytc_value_t *index_list, const tinytc_location_t *loc); /** @@ -757,7 +726,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, t */ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, uint32_t return_type_list_size, - tinytc_data_type_t *return_type_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); /** @@ -777,7 +746,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc */ TINYTC_EXPORT tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, - tinytc_value_t *yield_list, + const tinytc_value_t *yield_list, const tinytc_location_t *loc); /** @@ -795,8 +764,7 @@ TINYTC_EXPORT void tinytc_inst_destroy(tinytc_inst_t instr); * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, - tinytc_value_t *result); +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(tinytc_inst_t instr, tinytc_value_t *result); /** * @brief Get values produced by instruction @@ -811,7 +779,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list); @@ -912,7 +880,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, */ TINYTC_EXPORT tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, char const *name, uint32_t num_params, - tinytc_data_type_t *param_type_list, + const tinytc_data_type_t *param_type_list, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index c311c7c4..ca4454d3 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -339,6 +339,14 @@ template class array_view { template array_view(std::array const &arr) : data_{arr.data()}, size_{arr.size()} {} + /** + * @brief Convert initializer list to array view (array_view must be rvalue) + * + * @param arr initializer list + */ + array_view(std::initializer_list const &arr) + : data_{arr.begin() != arr.end() ? arr.begin() : nullptr}, size_{arr.size()} {} + //! Begin iterator auto begin() const -> const_iterator { return data_; } //! End iterator @@ -351,6 +359,12 @@ template class array_view { auto front() const -> T const & { return data_[0]; } //! Access last element; must not call when array size is 0 auto back() const -> T const & { return data_[size_ - 1]; } + //! Get data pointer + auto data() const -> T const * { return data_; } + //! Access operator + auto operator[](std::size_t n) const -> T const & { return data_[n]; } + //! Convert to vector + operator std::vector() const { return std::vector(data_, data_ + size_); } private: T const *data_ = nullptr; @@ -465,8 +479,8 @@ inline tinytc_data_type_t get_scalar(compiler_context const &ctx, scalar_type sc * @return Data type */ inline tinytc_data_type_t get_memref(compiler_context const &ctx, scalar_type scalar_ty, - std::vector const &shape, - std::vector const &stride = {}, + array_view shape, + array_view stride = {}, const address_space addrspace = address_space::global, location const &loc = {}) { tinytc_data_type_t mt; @@ -499,62 +513,6 @@ inline tinytc_data_type_t get_group(compiler_context const &ctx, tinytc_data_typ /////////// Value ////////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_retain(handle); - } - static auto release(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_value_t -class value : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get name - * - * @return Name - */ - inline auto get_name() const -> char const * { - char const *name; - CHECK_STATUS(tinytc_value_get_name(obj_, &name)); - return name; - } - /** - * @brief Set name - * - * @param name Name - */ - inline void name(std::string_view name) { - CHECK_STATUS(tinytc_value_set_name_n(obj_, name.size(), name.data())); - } -}; - -namespace internal { -//! Is reinterpret_cast(&v) allowed, where v has type value -constexpr bool value_reinterpret_allowed = - std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t); -} // namespace internal - -/** - * @brief Make value - * - * @param ty Data type - * @param loc Source code location - * - * @return Value - */ -inline auto make_value(tinytc_data_type_t ty, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_value_create(&val, ty, &loc), loc); - return value{val}; -} - /** * @brief Get name * @@ -653,10 +611,10 @@ class inst : public unique_handle { * * @return Value; may be empty */ - inline auto get_value() const -> value { + inline auto get_value() const -> tinytc_value_t { tinytc_value_t result; CHECK_STATUS(tinytc_inst_get_value(obj_, &result)); - return value{result}; + return result; } /** @@ -675,12 +633,10 @@ class inst : public unique_handle { * * @return Vector of values */ - inline auto get_values() const -> std::vector { - static_assert(internal::value_reinterpret_allowed); + inline auto get_values() const -> std::vector { std::uint32_t result_list_size = get_num_values(); - auto values = std::vector(result_list_size); - tinytc_value_t *result_list = reinterpret_cast(values.data()); - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, result_list)); + auto values = std::vector(result_list_size); + CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, values.data())); return values; } @@ -786,11 +742,11 @@ inline auto get_parameters(tinytc_region_t reg) -> std::vector { * * @return Instruction */ -inline inst make_arith(arithmetic op, value const &a, value const &b, location const &loc = {}) { +inline inst make_arith(arithmetic op, tinytc_value_t a, tinytc_value_t b, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_inst_create(&instr, static_cast(op), a.get(), - b.get(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_arith_inst_create(&instr, static_cast(op), a, b, &loc), loc); return inst(instr); } @@ -803,11 +759,11 @@ inline inst make_arith(arithmetic op, value const &a, value const &b, location c * * @return Instruction */ -inline inst make_arith(arithmetic_unary op, value const &a, location const &loc = {}) { +inline inst make_arith(arithmetic_unary op, tinytc_value_t a, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_unary_inst_create( - &instr, static_cast(op), a.get(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_arith_unary_inst_create(&instr, static_cast(op), a, &loc), + loc); return inst(instr); } @@ -820,11 +776,10 @@ inline inst make_arith(arithmetic_unary op, value const &a, location const &loc * * @return Instruction */ -inline inst make_cast(value const &a, scalar_type to_ty, location const &loc = {}) { +inline inst make_cast(tinytc_value_t a, scalar_type to_ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( - tinytc_cast_inst_create(&instr, a.get(), static_cast(to_ty), &loc), - loc); + tinytc_cast_inst_create(&instr, a, static_cast(to_ty), &loc), loc); return inst(instr); } @@ -838,11 +793,11 @@ inline inst make_cast(value const &a, scalar_type to_ty, location const &loc = { * * @return Instruction */ -inline inst make_cmp(cmp_condition cond, value const &a, value const &b, location const &loc = {}) { +inline inst make_cmp(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_cmp_inst_create(&instr, static_cast(cond), - a.get(), b.get(), &loc), - loc); + CHECK_STATUS_LOC( + tinytc_cmp_inst_create(&instr, static_cast(cond), a, b, &loc), loc); return inst(instr); } @@ -935,11 +890,11 @@ inline inst make_alloca(tinytc_data_type_t ty, location const &loc = {}) { * * @return Instruction */ -inline inst make_axpby(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { +inline inst make_axpby(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t beta, tinytc_value_t B, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_axpby_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), + alpha, A, beta, B, &loc), loc); return inst(instr); } @@ -955,10 +910,9 @@ inline inst make_axpby(transpose tA, bool atomic, value const &alpha, value cons * * @return Instruction */ -inline inst make_expand(value const &a, std::int64_t expanded_mode, - std::vector const &static_expand_shape, - std::vector const &expand_shape, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_expand(tinytc_value_t a, std::int64_t expanded_mode, + array_view static_expand_shape, + array_view expand_shape, location const &loc = {}) { tinytc_inst_t instr; auto static_len = static_expand_shape.size(); if (static_len > std::numeric_limits::max()) { @@ -968,11 +922,9 @@ inline inst make_expand(value const &a, std::int64_t expanded_mode, if (len > std::numeric_limits::max()) { throw std::out_of_range("expand shape too large"); } - tinytc_value_t *eshape = - const_cast(reinterpret_cast(expand_shape.data())); - CHECK_STATUS_LOC(tinytc_expand_inst_create( - &instr, a.get(), expanded_mode, static_len, - const_cast(static_expand_shape.data()), len, eshape, &loc), + CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a, expanded_mode, static_len, + static_expand_shape.data(), len, expand_shape.data(), + &loc), loc); return inst(instr); } @@ -987,10 +939,10 @@ inline inst make_expand(value const &a, std::int64_t expanded_mode, * * @return Instruction */ -inline inst make_fuse(value const &a, std::int64_t from, std::int64_t to, +inline inst make_fuse(tinytc_value_t a, std::int64_t from, std::int64_t to, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a.get(), from, to, &loc), loc); + CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a, from, to, &loc), loc); return inst(instr); } @@ -1003,17 +955,14 @@ inline inst make_fuse(value const &a, std::int64_t from, std::int64_t to, * * @return Instruction */ -inline inst make_load(value const &a, std::vector const &index_list, +inline inst make_load(tinytc_value_t a, array_view index_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a.get(), len, il, &loc), loc); + CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, index_list.data(), &loc), loc); return inst(instr); } @@ -1060,13 +1009,13 @@ inline inst make_group_size(compiler_context const &ctx, location const &loc = { * * @return Instruction */ -inline inst make_gemm(transpose tA, transpose tB, bool atomic, value const &alpha, value const &A, - value const &B, value const &beta, value const &C, location const &loc = {}) { +inline inst make_gemm(transpose tA, transpose tB, bool atomic, tinytc_value_t alpha, + tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemm_inst_create(&instr, static_cast(tA), - static_cast(tB), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), + static_cast(tB), atomic, alpha, A, + B, beta, C, &loc), loc); return inst(instr); } @@ -1085,12 +1034,12 @@ inline inst make_gemm(transpose tA, transpose tB, bool atomic, value const &alph * * @return Instruction */ -inline inst make_gemv(transpose tA, bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_gemv(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemv_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), + alpha, A, B, beta, C, &loc), loc); return inst(instr); } @@ -1108,12 +1057,10 @@ inline inst make_gemv(transpose tA, bool atomic, value const &alpha, value const * * @return Instruction */ -inline inst make_ger(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_ger(bool atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); + CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); return inst(instr); } @@ -1130,12 +1077,10 @@ inline inst make_ger(bool atomic, value const &alpha, value const &A, value cons * * @return Instruction */ -inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_hadamard(bool atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); + CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); return inst(instr); } @@ -1175,9 +1120,9 @@ inline inst make_parallel(location const &loc = {}) { * * @return Instruction */ -inline inst make_size(value const &a, std::int64_t mode, location const &loc = {}) { +inline inst make_size(tinytc_value_t a, std::int64_t mode, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a.get(), mode, &loc), loc); + CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a, mode, &loc), loc); return inst(instr); } @@ -1237,11 +1182,10 @@ inline inst make_subgroup_size(compiler_context const &ctx, location const &loc * * @return Instruction */ -inline inst make_subview(value const &a, std::vector const &static_offset_list, - std::vector const &static_size_list, - std::vector const &offset_list, std::vector const &size_list, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_subview(tinytc_value_t a, array_view static_offset_list, + array_view static_size_list, + array_view offset_list, + array_view size_list, location const &loc = {}) { tinytc_inst_t instr; if (static_offset_list.size() != static_size_list.size()) { throw std::invalid_argument( @@ -1259,16 +1203,10 @@ inline inst make_subview(value const &a, std::vector const &static if (size_len > std::numeric_limits::max()) { throw std::out_of_range("dynamic size list too long"); } - tinytc_value_t *ol = - const_cast(reinterpret_cast(offset_list.data())); - tinytc_value_t *sl = - const_cast(reinterpret_cast(size_list.data())); - CHECK_STATUS_LOC( - tinytc_subview_inst_create(&instr, a.get(), static_len, - const_cast(static_offset_list.data()), - const_cast(static_size_list.data()), offset_len, - ol, size_len, sl, &loc), - loc); + CHECK_STATUS_LOC(tinytc_subview_inst_create( + &instr, a, static_len, static_offset_list.data(), static_size_list.data(), + offset_len, offset_list.data(), size_len, size_list.data(), &loc), + loc); return inst(instr); } @@ -1282,17 +1220,14 @@ inline inst make_subview(value const &a, std::vector const &static * * @return Instruction */ -inline inst make_store(value const &val, value const &a, std::vector const &index_list, +inline inst make_store(tinytc_value_t val, tinytc_value_t a, array_view index_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val.get(), a.get(), len, il, &loc), loc); + CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val, a, len, index_list.data(), &loc), loc); return inst(instr); } @@ -1309,11 +1244,11 @@ inline inst make_store(value const &val, value const &a, std::vector cons * * @return Instruction */ -inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { +inline inst make_sum(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t beta, tinytc_value_t B, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_sum_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), + alpha, A, beta, B, &loc), loc); return inst(instr); } @@ -1329,11 +1264,10 @@ inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const * * @return Instruction */ -inline inst make_for(value const &from, value const &to, value const &step, +inline inst make_for(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, tinytc_data_type_t loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_for_inst_create(&instr, from.get(), to.get(), step.get(), loop_var_type, &loc), loc); + CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, loop_var_type, &loc), loc); return inst(instr); } @@ -1347,11 +1281,10 @@ inline inst make_for(value const &from, value const &to, value const &step, * * @return Instruction */ -inline inst make_foreach(value const &from, value const &to, tinytc_data_type_t loop_var_type, +inline inst make_foreach(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from.get(), to.get(), loop_var_type, &loc), - loc); + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from, to, loop_var_type, &loc), loc); return inst(instr); } @@ -1364,18 +1297,15 @@ inline inst make_foreach(value const &from, value const &to, tinytc_data_type_t * * @return Instruction */ -inline inst make_if(value const &condition, - std::vector const &return_type_list = {}, +inline inst make_if(tinytc_value_t condition, array_view return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; auto len = return_type_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("return type list too long"); } - CHECK_STATUS_LOC( - tinytc_if_inst_create(&instr, condition.get(), len, - const_cast(return_type_list.data()), &loc), - loc); + CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition, len, return_type_list.data(), &loc), + loc); return inst(instr); } @@ -1387,16 +1317,13 @@ inline inst make_if(value const &condition, * * @return Instruction */ -inline inst make_yield(std::vector const &yield_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_yield(array_view yield_list, location const &loc = {}) { tinytc_inst_t instr; auto len = yield_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("yield list too long"); } - tinytc_value_t *yl = - const_cast(reinterpret_cast(yield_list.data())); - CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yl, &loc), loc); + CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yield_list.data(), &loc), loc); return inst(instr); } @@ -1439,15 +1366,15 @@ class func : public unique_handle { * * @return Function */ -inline func make_func(std::string_view name, std::vector const ¶m_type_list, +inline func make_func(std::string_view name, array_view param_type_list, location const &loc = {}) { tinytc_func_t fun; auto len = param_type_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("param list too long"); } - tinytc_data_type_t *pl = const_cast(param_type_list.data()); - CHECK_STATUS_LOC(tinytc_func_create(&fun, name.size(), name.data(), len, pl, &loc), loc); + CHECK_STATUS_LOC( + tinytc_func_create(&fun, name.size(), name.data(), len, param_type_list.data(), &loc), loc); return func(fun); } @@ -1553,10 +1480,10 @@ class region_builder { * * @return Value returned by instruction; may be empty */ - [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> value { + [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> tinytc_value_t { auto result = i.get_value(); if (result && name.size() > 0) { - result.name(name); + set_name(result, name); } add_instruction(reg_, std::move(i)); return result; @@ -1570,14 +1497,14 @@ class region_builder { * * @return Values returned by instruction */ - [[maybe_unused]] inline auto add_multivalued(inst i, - std::string_view name = "") -> std::vector { + [[maybe_unused]] inline auto + add_multivalued(inst i, std::string_view name = "") -> std::vector { auto results = i.get_values(); if (name.size() > 0) { int counter = 0; auto name_str = std::string{name}; for (auto &result : results) { - result.name(name_str + std::to_string(counter++)); + set_name(result, name_str + std::to_string(counter++)); } } add_instruction(reg_, std::move(i)); @@ -1598,9 +1525,9 @@ class region_builder { * @param loc Source code location */ template - void for_loop(value const &from, value const &to, tinytc_data_type_t loop_var_ty, F &&f, + void for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_ty, F &&f, std::string_view loop_var_name = "", location const &loc = {}) { - for_loop(std::move(from), std::move(to), value{nullptr}, std::move(loop_var_ty), + for_loop(std::move(from), std::move(to), nullptr, std::move(loop_var_ty), std::forward(f), std::move(loop_var_name), loc); } /** @@ -1618,7 +1545,7 @@ class region_builder { * @param loc Source code location */ template - void for_loop(value const &from, value const &to, value const &step, + void for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, tinytc_data_type_t loop_var_ty, F &&f, std::string_view loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); @@ -1641,7 +1568,7 @@ class region_builder { * @param loc Source code location */ template - void foreach (value const &from, value const &to, tinytc_data_type_t loop_var_ty, F && f, + void foreach (tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_ty, F && f, std::string const &loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); auto reg = fi.get_region(0); @@ -1664,9 +1591,9 @@ class region_builder { * @return Returned values */ template - auto if_condition(value const &condition, F &&then, - std::vector const &return_type_list = {}, - location const &loc = {}) -> std::vector { + auto if_condition(tinytc_value_t condition, F &&then, + array_view return_type_list = {}, + location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); auto r0 = ii.get_region(0); auto results = add_multivalued(std::move(ii)); @@ -1689,9 +1616,9 @@ class region_builder { * @return Returned values */ template - auto ifelse(value const &condition, F &&then, G &&otherwise, - std::vector const &return_type_list = {}, - location const &loc = {}) -> std::vector { + auto ifelse(tinytc_value_t condition, F &&then, G &&otherwise, + array_view return_type_list = {}, + location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); auto r0 = ii.get_region(0); auto r1 = ii.get_region(1); @@ -1781,7 +1708,7 @@ class core_info : public shared_handle { * @return Core info */ inline auto make_core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, - std::vector sgs) -> core_info { + array_view sgs) -> core_info { tinytc_core_info_t info; CHECK_STATUS(tinytc_core_info_generic_create(&info, register_space, max_work_group_size, sgs.size(), sgs.data())); @@ -1814,7 +1741,7 @@ inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_ */ inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, std::int32_t num_threads_per_eu, - std::vector sgs) -> core_info { + array_view sgs) -> core_info { tinytc_core_info_t info; CHECK_STATUS(tinytc_core_info_intel_create(&info, ip_version, num_eus_per_subslice, num_threads_per_eu, sgs.size(), sgs.data())); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ec1f1dc6..583eec8a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -37,6 +37,7 @@ set(SOURCES node/inst_node.cpp node/region_node.cpp node/program_node.cpp + node/value_node.cpp parser/parse_context.cpp parser.cpp pass/check_ir.cpp diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp index 39097e63..a492257d 100644 --- a/src/analysis/alias.cpp +++ b/src/analysis/alias.cpp @@ -39,31 +39,31 @@ void alias_analysis_visitor::operator()(alloca_inst const &a) { if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } - allocs_[a.result().get()] = + allocs_[a.result()] = aa_results::allocation{a.stack_ptr(), a.stack_ptr() + t->size_in_bytes()}; } } void alias_analysis_visitor::operator()(expand_inst const &e) { - value_node const *source = e.operand().get(); + value_node const *source = &e.operand(); while (alias_.find(source) != alias_.end()) { source = alias_[source]; } - alias_[e.result().get()] = source; + alias_[e.result()] = source; } void alias_analysis_visitor::operator()(fuse_inst const &f) { - value_node const *source = f.operand().get(); + value_node const *source = &f.operand(); while (alias_.find(source) != alias_.end()) { source = alias_[source]; } - alias_[f.result().get()] = source; + alias_[f.result()] = source; } void alias_analysis_visitor::operator()(subview_inst const &s) { - value_node const *source = s.operand().get(); + value_node const *source = &s.operand(); while (alias_.find(source) != alias_.end()) { source = alias_[source]; } - alias_[s.result().get()] = source; + alias_[s.result()] = source; } auto alias_analysis::run_on_function(function_node &fn) -> aa_results { diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 2ee0c846..05c08e03 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -430,14 +430,15 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, } } -void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, - value const &sg_id, sgs_loop_body_builder_new const &body) { +void tile_loop_by_sgs_new(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, + int num_tiles, tinytc_value_t sg_id, + sgs_loop_body_builder_new const &body) { tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, std::move(sg_id), body); } void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, - int num_tiles, value const &sg_id, + int num_tiles, tinytc_value_t sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); std::int64_t blocks = loop_trip_count / sgs; @@ -454,20 +455,19 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); bb.for_loop( std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs.get()); }, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs); }, "block"); } if (rem > 0) { auto condition = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); - bb.if_condition(condition, [&](region_builder &bb) { - body(bb, c_sgs_blocks.get(), true, c_rem.get()); - }); + bb.if_condition(condition, + [&](region_builder &bb) { body(bb, c_sgs_blocks, true, c_rem); }); } } -void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, - int num_tiles, value const &sg_id, +void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, + int num_tiles, tinytc_value_t sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); auto c_sgs = bb.add(make_constant(sgs, index_ty)); @@ -483,28 +483,27 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_cou auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop( std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs.get()); }, - "block"); + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs); }, "block"); auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { auto condition1 = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); bb.if_condition(condition1, [&](region_builder &bb) { auto block = bb.add(make_arith(arithmetic::mul, blocks, c_sgs)); - body(bb, block.get(), true, rem.get()); + body(bb, block, true, rem); }); }); } -void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, int block_size, - int num_tiles, value const &sg_id, +void tile_loop_uniformly_new(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, + int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body) { tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, num_tiles, std::move(sg_id), body); } void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, - int block_size, int num_tiles, value const &sg_id, + int block_size, int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); // Find minimum number of blocks such that the block sizes are smaller or equal block_size @@ -530,8 +529,7 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); bb.for_loop( std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs_1.get()); }, - "block"); + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs_1); }, "block"); } auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); @@ -540,11 +538,11 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); bb.for_loop( std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs.get()); }, "block"); + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs); }, "block"); } -void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, - int block_size, int num_tiles, value const &sg_id, +void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, + int block_size, int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); auto c1 = bb.add(make_constant(1, index_ty)); @@ -572,7 +570,7 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); bb.for_loop( std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs_1.get()); }, "block"); + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs_1); }, "block"); auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); @@ -583,7 +581,7 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_ auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); bb.for_loop( std::move(block_start), loop_trip_count, std::move(step), index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs.get()); }, "block"); + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs); }, "block"); } } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index f4a63c4c..26e29aaf 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -127,23 +127,24 @@ using sgs_loop_body_builder_new = using uniform_loop_body_builder_new = std::function; -void tile_loop_by_sgs_new(region_builder &bb, value const &loop_trip_count, int sgs, int num_tiles, - value const &sg_id, sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs_new(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, + int num_tiles, tinytc_value_t sg_id, + sgs_loop_body_builder_new const &body); void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, - int num_tiles, value const &sg_id, + int num_tiles, tinytc_value_t sg_id, sgs_loop_body_builder_new const &body); -void tile_loop_by_sgs_new_dynamic(region_builder &bb, value const &loop_trip_count, int sgs, - int num_tiles, value const &sg_id, +void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, + int num_tiles, tinytc_value_t sg_id, sgs_loop_body_builder_new const &body); -void tile_loop_uniformly_new(region_builder &bb, value const &loop_trip_count, int block_size, - int num_tiles, value const &sg_id, +void tile_loop_uniformly_new(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, + int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body); void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, - int block_size, int num_tiles, value const &sg_id, + int block_size, int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body); -void tile_loop_uniformly_new_dynamic(region_builder &bb, value const &loop_trip_count, - int block_size, int num_tiles, value const &sg_id, +void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, + int block_size, int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder_new const &body); } // namespace tinytc diff --git a/src/func.cpp b/src/func.cpp index eacda1e7..886d9987 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -20,7 +20,7 @@ using namespace tinytc; extern "C" { tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, char const *name, - uint32_t num_params, tinytc_data_type_t *param_type_list, + uint32_t num_params, const tinytc_data_type_t *param_type_list, const tinytc_location_t *loc) { if (fun == nullptr || (num_params > 0 && param_type_list == nullptr)) { return tinytc_status_invalid_arguments; diff --git a/src/inst.cpp b/src/inst.cpp index d6337a08..be989482 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -104,8 +104,7 @@ tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), - value(b, true), get_optional(loc)) + *instr = std::make_unique(enum_cast(op), a, b, get_optional(loc)) .release(); }); } @@ -116,7 +115,7 @@ tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arit return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), + *instr = std::make_unique(enum_cast(op), a, get_optional(loc)) .release(); }); @@ -128,8 +127,7 @@ tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), enum_cast(to_ty), - get_optional(loc)) + *instr = std::make_unique(a, enum_cast(to_ty), get_optional(loc)) .release(); }); } @@ -141,9 +139,9 @@ tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_conditio return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(cond), value(a, true), - value(b, true), get_optional(loc)) - .release(); + *instr = + std::make_unique(enum_cast(cond), a, b, get_optional(loc)) + .release(); }); } @@ -197,8 +195,7 @@ tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_ return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), + *instr = std::make_unique(enum_cast(tA), alpha, A, beta, B, bool(atomic), get_optional(loc)) .release(); }); @@ -206,28 +203,19 @@ tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_ tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, uint32_t static_expand_shape_size, - int64_t *static_expand_shape, uint32_t expand_shape_size, - tinytc_value_t *expand_shape, + const int64_t *static_expand_shape, + uint32_t expand_shape_size, + const tinytc_value_t *expand_shape, const tinytc_location_t *loc) { if (instr == nullptr || static_expand_shape == nullptr || (expand_shape_size > 0 && expand_shape == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto static_shape = std::vector{}; - static_shape.reserve(static_expand_shape_size); - for (uint32_t i = 0; i < static_expand_shape_size; ++i) { - static_shape.emplace_back(static_expand_shape[i]); - } - auto dynamic_shape = std::vector{}; - dynamic_shape.reserve(expand_shape_size); - for (uint32_t i = 0; i < expand_shape_size; ++i) { - dynamic_shape.emplace_back(value(expand_shape[i], true)); - } - *instr = - std::make_unique(value(a, true), expanded_mode, std::move(static_shape), - std::move(dynamic_shape), get_optional(loc)) - .release(); + *instr = std::make_unique( + a, expanded_mode, array_view{static_expand_shape, static_expand_shape_size}, + array_view{expand_shape, expand_shape_size}, get_optional(loc)) + .release(); }); } @@ -236,24 +224,19 @@ tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), from, to, get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *instr = std::make_unique(a, from, to, get_optional(loc)).release(); }); } tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, + uint32_t index_list_size, const tinytc_value_t *index_list, const tinytc_location_t *loc) { if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(a, true), std::move(il_vec), get_optional(loc)) + *instr = std::make_unique(a, array_view{index_list, index_list_size}, + get_optional(loc)) .release(); }); } @@ -286,9 +269,7 @@ tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t } return exception_to_status_code([&] { *instr = std::make_unique(enum_cast(tA), enum_cast(tB), - value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) + alpha, A, B, beta, C, bool(atomic), get_optional(loc)) .release(); }); } @@ -301,9 +282,8 @@ tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tinytc_transpose_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(B, true), value(beta, true), - value(C, true), bool(atomic), get_optional(loc)) + *instr = std::make_unique(enum_cast(tA), alpha, A, B, beta, C, + bool(atomic), get_optional(loc)) .release(); }); } @@ -316,9 +296,7 @@ tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomi return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) + *instr = std::make_unique(alpha, A, B, beta, C, bool(atomic), get_optional(loc)) .release(); }); } @@ -331,10 +309,9 @@ tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) - .release(); + *instr = + std::make_unique(alpha, A, B, beta, C, bool(atomic), get_optional(loc)) + .release(); }); } @@ -361,9 +338,8 @@ tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), mode, get_optional(loc)).release(); - }); + return exception_to_status_code( + [&] { *instr = std::make_unique(a, mode, get_optional(loc)).release(); }); } tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, @@ -396,12 +372,12 @@ tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); } -tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t static_list_size, int64_t *static_offset_list, - int64_t *static_size_list, uint32_t offset_list_size, - tinytc_value_t *offset_list, uint32_t size_list_size, - tinytc_value_t *size_list, - const tinytc_location_t *loc) { +tinytc_status_t +tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, + const int64_t *static_offset_list, const int64_t *static_size_list, + uint32_t offset_list_size, const tinytc_value_t *offset_list, + uint32_t size_list_size, const tinytc_value_t *size_list, + const tinytc_location_t *loc) { if (instr == nullptr || (static_list_size > 0 && (static_offset_list == nullptr || static_size_list == nullptr)) || (offset_list_size > 0 && offset_list == nullptr) || @@ -409,44 +385,23 @@ tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto static_offset_vec = - static_list_size > 0 - ? std::vector(static_offset_list, static_offset_list + static_list_size) - : std::vector{}; - auto static_size_vec = - static_list_size > 0 - ? std::vector(static_size_list, static_size_list + static_list_size) - : std::vector{}; - auto offset_vec = std::vector(); - auto size_vec = std::vector(); - offset_vec.reserve(offset_list_size); - size_vec.reserve(size_list_size); - for (uint32_t i = 0; i < offset_list_size; ++i) { - offset_vec.emplace_back(value(offset_list[i], true)); - } - for (uint32_t i = 0; i < size_list_size; ++i) { - size_vec.emplace_back(value(size_list[i], true)); - } - *instr = std::make_unique(value(a, true), std::move(static_offset_vec), - std::move(static_size_vec), std::move(offset_vec), - std::move(size_vec), get_optional(loc)) - .release(); + *instr = + std::make_unique(a, array_view{static_offset_list, static_list_size}, + array_view{static_size_list, static_list_size}, + array_view{offset_list, offset_list_size}, + array_view{size_list, size_list_size}, get_optional(loc)) + .release(); }); } tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, + uint32_t index_list_size, const tinytc_value_t *index_list, const tinytc_location_t *loc) { if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(val, true), value(a, true), std::move(il_vec), + *instr = std::make_unique(val, a, array_view{index_list, index_list_size}, get_optional(loc)) .release(); }); @@ -460,8 +415,7 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), + *instr = std::make_unique(enum_cast(tA), alpha, A, beta, B, bool(atomic), get_optional(loc)) .release(); }); @@ -474,9 +428,8 @@ tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(from, true), value(to, true), value(step, true), - loop_var_type, get_optional(loc)) - .release(); + *instr = + std::make_unique(from, to, step, loop_var_type, get_optional(loc)).release(); }); } @@ -487,15 +440,14 @@ tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(from, true), value(to, true), loop_var_type, - get_optional(loc)) - .release(); + *instr = + std::make_unique(from, to, loop_var_type, get_optional(loc)).release(); }); } tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, uint32_t return_type_list_size, - tinytc_data_type_t *return_type_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc) { if (instr == nullptr || condition == nullptr || (return_type_list_size > 0 && return_type_list == nullptr)) { @@ -507,36 +459,33 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi for (uint32_t i = 0; i < return_type_list_size; ++i) { rt.emplace_back(return_type_list[i]); } - *instr = std::make_unique(value(condition, true), std::move(rt), get_optional(loc)) - .release(); + *instr = std::make_unique(condition, std::move(rt), get_optional(loc)).release(); }); } tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, - tinytc_value_t *yield_list, const tinytc_location_t *loc) { + const tinytc_value_t *yield_list, + const tinytc_location_t *loc) { if (instr == nullptr || (yield_list_size != 0 && yield_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto yl = std::vector(); - yl.reserve(yield_list_size); - for (uint32_t i = 0; i < yield_list_size; ++i) { - yl.emplace_back(value(yield_list[i], true)); - } - *instr = std::make_unique(std::move(yl), get_optional(loc)).release(); + *instr = + std::make_unique(array_view{yield_list, yield_list_size}, get_optional(loc)) + .release(); }); } void tinytc_inst_destroy(tinytc_inst_t obj) { delete obj; } -tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, tinytc_value_t *result) { +tinytc_status_t tinytc_inst_get_value(tinytc_inst_t instr, tinytc_value_t *result) { if (instr == nullptr || result == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { *result = instr->result().release(); }); + return exception_to_status_code([&] { *result = instr->result(); }); } -tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *result_list_size, +tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list) { if (instr == nullptr || result_list_size == nullptr || (*result_list_size > 0 && result_list == nullptr)) { @@ -552,7 +501,7 @@ tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *resu auto results = instr->result_begin(); auto const limit = std::min(num, *result_list_size); for (uint32_t i = 0; i < limit; ++i) { - result_list[i] = value(results[i]).release(); + result_list[i] = &results[i]; } } *result_list_size = num; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index d30c6dae..79334191 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -21,23 +21,24 @@ namespace tinytc { -scalar_data_type *get_scalar_type(location const &loc, value const &v) { - auto m = dyn_cast(v->ty()); +scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_scalar); } return m; } -memref_data_type *get_memref_type(location const &loc, value const &v) { - auto m = dyn_cast(v->ty()); +memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); if (m == nullptr) { throw compilation_error(loc, status::ir_expected_memref); } return m; } -blas_a2_inst::blas_a2_inst(IK tid, value alpha, value A, value beta, value B, bool atomic) +blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic) : standard_inst{tid}, atomic_(atomic) { op(op_alpha) = std::move(alpha); op(op_A) = std::move(A); @@ -45,7 +46,8 @@ blas_a2_inst::blas_a2_inst(IK tid, value alpha, value A, value beta, value B, bo op(op_B) = std::move(B); } -blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, value C, bool atomic) +blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic) : standard_inst{tid}, atomic_(atomic) { op(op_alpha) = std::move(alpha); op(op_A) = std::move(A); @@ -54,20 +56,20 @@ blas_a3_inst::blas_a3_inst(IK tid, value alpha, value A, value B, value beta, va op(op_C) = std::move(C); } -loop_inst::loop_inst(IK tid, value from0, value to0, value step0, tinytc_data_type_t loop_var_type, - location const &lc) +loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, + tinytc_data_type_t loop_var_type, location const &lc) : standard_inst{tid, step0 ? 3 : 2} { op(op_from) = std::move(from0); op(op_to) = std::move(to0); op(op_step) = std::move(step0); - body().add_param(loop_var_type); + body().set_params(array_view{loop_var_type}, lc); loc(lc); auto lvt = get_scalar_type(loc(), loop_var()); auto fromt = get_scalar_type(loc(), from()); auto tot = get_scalar_type(loc(), to()); bool step_ok = true; - if (step()) { + if (has_step()) { auto stept = get_scalar_type(loc(), step()); step_ok = lvt->ty() == stept->ty(); } @@ -82,8 +84,8 @@ alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); - result(0) = make_value(std::move(ty)); - auto memref = dyn_cast(result(0)->ty()); + result(0) = value_node{ty, lc}; + auto memref = dyn_cast(result(0).ty()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } @@ -92,8 +94,8 @@ alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) } } -axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, - location const &lc) +axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, + tinytc_value_t B0, bool atomic, location const &lc) : blas_a2_inst(IK::axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { @@ -117,7 +119,8 @@ axpby_inst::axpby_inst(transpose tA, value alpha0, value A0, value beta0, value } } -arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const &lc) +arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b0, + location const &lc) : standard_inst{IK::arith}, operation_(operation) { op(op_a) = std::move(a0); op(op_b) = std::move(b0); @@ -149,10 +152,11 @@ arith_inst::arith_inst(arithmetic operation, value a0, value b0, location const if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = make_value(at); + result(0) = value_node{at}; } -arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, location const &lc) +arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, + location const &lc) : standard_inst{IK::arith_unary}, operation_(operation) { op(op_a) = std::move(a0); loc(lc); @@ -170,17 +174,20 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, value a0, locatio if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = make_value(at); + result(0) = value_node{at, lc}; } -cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) : standard_inst{IK::cast} { +cast_inst::cast_inst(tinytc_value_t a, scalar_type to_ty, location const &lc) + : standard_inst{IK::cast} { op(op_a) = std::move(a); loc(lc); - result(0) = make_value(scalar_data_type::get(op(op_a)->context(), std::move(to_ty))); + auto result_ty = scalar_data_type::get(op(op_a)->context(), to_ty); + result(0) = value_node{result_ty, lc}; } -compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location const &lc) +compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, + location const &lc) : standard_inst{IK::compare}, cond_(cond) { op(op_a) = std::move(a0); op(op_b) = std::move(b0); @@ -193,7 +200,8 @@ compare_inst::compare_inst(cmp_condition cond, value a0, value b0, location cons throw compilation_error(loc(), status::ir_scalar_mismatch); } - result(0) = make_value(scalar_data_type::get(at->context(), scalar_type::i1)); + auto result_ty = scalar_data_type::get(at->context(), scalar_type::i1); + result(0) = value_node{result_ty, lc}; } constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc) @@ -213,12 +221,12 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc throw compilation_error(loc(), status::ir_expected_scalar); } - result(0) = make_value(ty); + result(0) = value_node{ty, lc}; } -expand_inst::expand_inst(value op0, std::int64_t expanded_mode, - std::vector static_expand_shape0, - std::vector const &expand_shape0, location const &lc) +expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, + array_view static_expand_shape0, + array_view expand_shape0, location const &lc) : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { op(0) = std::move(op0); @@ -263,11 +271,12 @@ expand_inst::expand_inst(value op0, std::int64_t expanded_mode, stride.push_back(m->stride(i)); } - result(0) = make_value( - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); + auto result_ty = + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + result(0) = value_node{result_ty, lc}; } -fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location const &lc) +fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, location const &lc) : standard_inst{IK::fuse}, from_(from), to_(to) { op(0) = std::move(op0); loc(lc); @@ -299,12 +308,12 @@ fuse_inst::fuse_inst(value op0, std::int64_t from, std::int64_t to, location con shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } - - result(0) = make_value( - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); + auto result_ty = + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + result(0) = value_node{result_ty, lc}; } -load_inst::load_inst(value op0, std::vector const &index_list0, location const &lc) +load_inst::load_inst(tinytc_value_t op0, array_view index_list0, location const &lc) : standard_inst{IK::load, static_cast(1 + index_list0.size())} { op(0) = std::move(op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { @@ -317,20 +326,22 @@ load_inst::load_inst(value op0, std::vector const &index_list0, location if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result(0) = make_value(g.ty()); + result(0) = value_node{g.ty(), lc}; }, [&](memref_data_type &m) { if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result(0) = make_value(scalar_data_type::get(m.context(), m.element_ty())); + auto result_ty = scalar_data_type::get(m.context(), m.element_ty()); + result(0) = value_node{result_ty, lc}; }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, - *operand()->ty()); + *operand().ty()); } -gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B0, value beta0, - value C0, bool atomic, location const &lc) +gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_value_t A0, + tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, + location const &lc) : blas_a3_inst(IK::gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic), tA_(tA), tB_(tB) { @@ -359,8 +370,8 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha0, value A0, value B } } -gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0, value C0, - bool atomic, location const &lc) +gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic), tA_(tA) { @@ -387,8 +398,8 @@ gemv_inst::gemv_inst(transpose tA, value alpha0, value A0, value B0, value beta0 } } -ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, - location const &lc) +ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic) { loc(lc); @@ -413,13 +424,14 @@ ger_inst::ger_inst(value alpha0, value A0, value B0, value beta0, value C0, bool } } -foreach_inst::foreach_inst(value from, value to, tinytc_data_type_t loop_var_type, +foreach_inst::foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, location const &loc) : loop_inst{IK::foreach_loop, std::move(from), std::move(to), {}, loop_var_type, loc} { child_region(0).kind(region_kind::spmd); } -hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, value C0, bool atomic, +hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), std::move(beta0), std::move(C0), atomic) { @@ -444,13 +456,13 @@ hadamard_inst::hadamard_inst(value alpha0, value A0, value B0, value beta0, valu } } -if_inst::if_inst(value condition, std::vector const &return_types, +if_inst::if_inst(tinytc_value_t condition, array_view return_types, location const &lc) : standard_inst{IK::if_, 1, static_cast(return_types.size())} { op(0) = std::move(condition); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { - result(i) = make_value(return_types[i]); + result(i) = value_node{return_types[i], lc}; } } @@ -460,7 +472,7 @@ parallel_inst::parallel_inst(location const &lc) : standard_inst{IK::parallel} { child_region(0).kind(region_kind::spmd); } -size_inst::size_inst(value op0, std::int64_t mode, location const &lc) +size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, location const &lc) : standard_inst{IK::size}, mode_(mode) { op(0) = std::move(op0); loc(lc); @@ -470,12 +482,13 @@ size_inst::size_inst(value op0, std::int64_t mode, location const &lc) throw compilation_error(loc(), status::ir_out_of_bounds); } - result(0) = make_value(scalar_data_type::get(op(0)->context(), scalar_type::index)); + auto result_ty = scalar_data_type::get(op(0)->context(), scalar_type::index); + result(0) = value_node{result_ty, lc}; } -subview_inst::subview_inst(value op0, std::vector static_offsets0, - std::vector static_sizes0, - std::vector const &offsets0, std::vector const &sizes0, +subview_inst::subview_inst(tinytc_value_t op0, array_view static_offsets0, + array_view static_sizes0, + array_view offsets0, array_view sizes0, location const &lc) : standard_inst{IK::subview, static_cast(1 + offsets0.size() + sizes0.size())}, static_offsets_(std::move(static_offsets0)), static_sizes_(std::move(static_sizes0)) { @@ -519,12 +532,13 @@ subview_inst::subview_inst(value op0, std::vector static_offsets0, } } - result(0) = make_value( - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace())); + auto result_ty = + memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + result(0) = value_node{result_ty, lc}; } -store_inst::store_inst(value val0, value op0, std::vector const &index_list0, - location const &lc) +store_inst::store_inst(tinytc_value_t val0, tinytc_value_t op0, + array_view index_list0, location const &lc) : standard_inst{IK::store, static_cast(2 + index_list0.size())} { op(op_val) = std::move(val0); op(op_operand) = std::move(op0); @@ -548,8 +562,8 @@ store_inst::store_inst(value val0, value op0, std::vector const &index_li } } -sum_inst::sum_inst(transpose tA, value alpha0, value A0, value beta0, value B0, bool atomic, - location const &lc) +sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, + tinytc_value_t B0, bool atomic, location const &lc) : blas_a2_inst(IK::sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), std::move(B0), atomic), tA_(tA) { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index e135721f..1fca1235 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -7,6 +7,7 @@ #include "error.hpp" #include "node/data_type_node.hpp" #include "node/region_node.hpp" +#include "node/value_node.hpp" #include "support/ilist.hpp" #include "support/type_list.hpp" #include "support/util.hpp" @@ -17,10 +18,10 @@ #include #include #include +#include #include #include #include -#include namespace tinytc { @@ -84,8 +85,11 @@ using inst_nodes = class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, class yield_inst>; -using value_range = iterator_range_wrapper; -using const_value_range = iterator_range_wrapper; +using op_range = iterator_range_wrapper; +using const_op_range = iterator_range_wrapper; + +using result_range = iterator_range_wrapper; +using const_result_range = iterator_range_wrapper; using region_range = iterator_range_wrapper; using const_region_range = iterator_range_wrapper; @@ -99,44 +103,41 @@ struct tinytc_inst : tinytc::ilist_node_with_parent inline tinytc_inst(tinytc::IK tid) : tid_(tid) {} virtual ~tinytc_inst() = default; + tinytc_inst(tinytc_inst const &other) = delete; + tinytc_inst(tinytc_inst &&other) = delete; + tinytc_inst &operator=(tinytc_inst const &other) = delete; + tinytc_inst &operator=(tinytc_inst &&other) = delete; + inline auto type_id() const -> tinytc::IK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } // Iterator over operands - inline auto op_begin() -> tinytc::value * { return op_begin_; } - inline auto op_end() -> tinytc::value * { return op_end_; } - inline auto operands() -> tinytc::value_range { - return tinytc::value_range{op_begin_, op_end_}; - } - inline auto op_begin() const -> tinytc::value const * { return op_begin_; } - inline auto op_end() const -> tinytc::value const * { return op_end_; } - inline auto operands() const -> tinytc::const_value_range { - return tinytc::const_value_range{op_begin_, op_end_}; - } - inline auto op(std::size_t pos) -> tinytc::value & { return op_begin_[pos]; } - inline auto op(std::size_t pos) const -> tinytc::value const & { return op_begin_[pos]; } + inline auto op_begin() -> tinytc_value_t * { return op_begin_; } + inline auto op_end() -> tinytc_value_t * { return op_end_; } + inline auto operands() -> tinytc::op_range { return {op_begin_, op_end_}; } + inline auto op_begin() const -> const tinytc_value_t * { return op_begin_; } + inline auto op_end() const -> const tinytc_value_t * { return op_end_; } + inline auto operands() const -> tinytc::const_op_range { return {op_begin_, op_end_}; } + inline auto op(std::size_t pos) -> tinytc_value_t & { return op_begin_[pos]; } + inline auto op(std::size_t pos) const -> tinytc_value_t const & { return op_begin_[pos]; } inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } // Iterator over results - inline auto result_begin() -> tinytc::value * { return result_begin_; } - inline auto result_end() -> tinytc::value * { return result_end_; } - inline auto results() -> tinytc::value_range { - return tinytc::value_range{result_begin_, result_end_}; - } - inline auto result_begin() const -> tinytc::value const * { return result_begin_; } - inline auto result_end() const -> tinytc::value const * { return result_end_; } - inline auto results() const -> tinytc::const_value_range { - return tinytc::const_value_range{result_begin_, result_end_}; - } - inline auto result() const -> tinytc::value { - return num_results() > 0 ? result_begin_[0] : tinytc::value{}; - } - inline auto result(std::size_t pos) -> tinytc::value & { return result_begin_[pos]; } - inline auto result(std::size_t pos) const -> tinytc::value const & { - return result_begin_[pos]; - } + inline auto result_begin() -> tinytc_value_t { return result_begin_; } + inline auto result_end() -> tinytc_value_t { return result_end_; } + inline auto results() -> tinytc::result_range { return {result_begin_, result_end_}; } + inline auto result_begin() const -> const_tinytc_value_t { return result_begin_; } + inline auto result_end() const -> const_tinytc_value_t { return result_end_; } + inline auto results() const -> tinytc::const_result_range { + return {result_begin_, result_end_}; + } + inline auto result() const -> tinytc_value_t { + return num_results() > 0 ? result_begin_ : nullptr; + } + inline auto result(std::size_t pos) -> tinytc_value & { return result_begin_[pos]; } + inline auto result(std::size_t pos) const -> tinytc_value const & { return result_begin_[pos]; } inline auto num_results() const -> std::int64_t { return result_end_ - result_begin_; } // Iterator over regions @@ -207,11 +208,11 @@ struct tinytc_inst : tinytc::ilist_node_with_parent } protected: - inline auto op_range(tinytc::value *begin, tinytc::value *end) { + inline auto op_range(tinytc_value_t *begin, tinytc_value_t *end) { op_begin_ = begin; op_end_ = end; } - inline auto result_range(tinytc::value *begin, tinytc::value *end) { + inline auto result_range(tinytc_value_t begin, tinytc_value_t end) { result_begin_ = begin; result_end_ = end; } @@ -223,8 +224,8 @@ struct tinytc_inst : tinytc::ilist_node_with_parent private: tinytc::IK tid_; tinytc::location loc_; - tinytc::value *op_begin_ = nullptr, *op_end_ = nullptr, *result_begin_ = nullptr, - *result_end_ = nullptr; + tinytc_value_t *op_begin_ = nullptr, *op_end_ = nullptr; + tinytc_value_t result_begin_ = nullptr, result_end_ = nullptr; tinytc_region_t child_regions_begin_ = nullptr, child_regions_end_ = nullptr; }; @@ -282,8 +283,8 @@ class standard_inst : public inst_node { } private: - object_container ops_; - object_container results_; + object_container ops_; + object_container results_; object_container child_regions_; }; @@ -293,14 +294,15 @@ class blas_a2_inst : public standard_inst<4, 0> { return i.type_id() >= IK::blas_a2 && i.type_id() <= IK::last_blas_a2; } enum op_number { op_alpha = 0, op_A = 1, op_beta = 2, op_B = 3 }; - blas_a2_inst(IK tid, value alpha, value A, value beta, value B, bool atomic); + blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return op(op_alpha); } - inline auto A() const -> value const & { return op(op_A); } - inline auto beta() const -> value const & { return op(op_beta); } - inline auto B() const -> value const & { return op(op_B); } + inline auto alpha() const -> tinytc_value const & { return *op(op_alpha); } + inline auto A() const -> tinytc_value const & { return *op(op_A); } + inline auto beta() const -> tinytc_value const & { return *op(op_beta); } + inline auto B() const -> tinytc_value const & { return *op(op_B); } protected: bool atomic_; @@ -312,15 +314,16 @@ class blas_a3_inst : public standard_inst<5, 0> { return i.type_id() >= IK::blas_a3 && i.type_id() <= IK::last_blas_a3; } enum op_number { op_alpha = 0, op_A = 1, op_B = 2, op_beta = 3, op_C = 4 }; - blas_a3_inst(IK tid, value alpha, value A, value B, value beta, value C, bool atomic); + blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return op(op_alpha); } - inline auto A() const -> value const & { return op(op_A); } - inline auto B() const -> value const & { return op(op_B); } - inline auto beta() const -> value const & { return op(op_beta); } - inline auto C() const -> value const & { return op(op_C); } + inline auto alpha() const -> tinytc_value const & { return *op(op_alpha); } + inline auto A() const -> tinytc_value const & { return *op(op_A); } + inline auto B() const -> tinytc_value const & { return *op(op_B); } + inline auto beta() const -> tinytc_value const & { return *op(op_beta); } + inline auto C() const -> tinytc_value const & { return *op(op_C); } protected: bool atomic_; @@ -332,14 +335,16 @@ class loop_inst : public standard_inst<3, 0, 1> { return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; } enum op_number { op_from = 0, op_to = 1, op_step = 2 }; - loop_inst(IK tid, value from, value to, value step, tinytc_data_type_t loop_var_type, - location const &loc = {}); - inline auto from() const -> value const & { return op(op_from); } - inline auto to() const -> value const & { return op(op_to); } - inline auto step() const -> value const & { return op(op_step); } + loop_inst(IK tid, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + tinytc_data_type_t loop_var_type, location const &loc = {}); + inline auto from() const -> tinytc_value const & { return *op(op_from); } + inline auto to() const -> tinytc_value const & { return *op(op_to); } + inline auto has_step() const -> bool { return op(op_step) != nullptr; } + inline auto step() const -> tinytc_value const & { return *op(op_step); } inline auto body() -> tinytc_region & { return child_region(0); } inline auto body() const -> tinytc_region const & { return child_region(0); } - inline auto loop_var() const -> value const & { return body().param(0); } + inline auto loop_var() -> tinytc_value & { return body().param(0); } + inline auto loop_var() const -> tinytc_value const & { return body().param(0); } }; class alloca_inst : public standard_inst<0, 1> { @@ -357,8 +362,8 @@ class alloca_inst : public standard_inst<0, 1> { class axpby_inst : public blas_a2_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::axpby_blas_a2; } - axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); + axpby_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -370,11 +375,11 @@ class arith_inst : public standard_inst<2, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith; } enum op_number { op_a = 0, op_b = 1 }; - arith_inst(arithmetic op, value a, value b, location const &lc = {}); + arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline arithmetic operation() const { return operation_; } - inline auto a() const -> value const & { return op(op_a); } - inline auto b() const -> value const & { return op(op_b); } + inline auto a() const -> tinytc_value const & { return *op(op_a); } + inline auto b() const -> tinytc_value const & { return *op(op_b); } private: arithmetic operation_; @@ -384,10 +389,10 @@ class arith_unary_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith_unary; } enum op_number { op_a = 0 }; - arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); + arith_unary_inst(arithmetic_unary op, tinytc_value_t a, location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } - inline auto a() const -> value const & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return *op(op_a); } private: arithmetic_unary operation_; @@ -415,19 +420,19 @@ class cast_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } enum op_number { op_a = 0 }; - cast_inst(value a, scalar_type to_ty, location const &lc = {}); - inline auto a() const -> value const & { return op(op_a); } + cast_inst(tinytc_value_t a, scalar_type to_ty, location const &lc = {}); + inline auto a() const -> tinytc_value const & { return *op(op_a); } }; class compare_inst : public standard_inst<2, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::compare; } enum op_number { op_a = 0, op_b = 1 }; - compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); + compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } - inline auto a() const -> value const & { return op(op_a); } - inline auto b() const -> value const & { return op(op_b); } + inline auto a() const -> tinytc_value const & { return *op(op_a); } + inline auto b() const -> tinytc_value const & { return *op(op_b); } private: cmp_condition cond_; @@ -449,18 +454,19 @@ class constant_inst : public standard_inst<0, 1> { class expand_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } - expand_inst(value op, std::int64_t expanded_mode, std::vector static_expand_shape, - std::vector const &expand_shape, location const &lc = {}); + expand_inst(tinytc_value_t op, std::int64_t expanded_mode, + array_view static_expand_shape, + array_view expand_shape, location const &lc = {}); inline std::int64_t expanded_mode() const { return expanded_mode_; } - inline auto static_expand_shape() const -> std::vector const & { + inline auto static_expand_shape() const -> array_view { return static_expand_shape_; } - inline auto operand() const -> value const & { return op(0); } + inline auto operand() const -> tinytc_value const & { return *op(0); } inline auto expand_shape() { return operands() | std::views::drop(1); } inline auto expand_shape() const { return operands() | std::views::drop(1); } - inline auto expand_shape(std::int64_t i) const -> value const & { return op(i + 1); } + inline auto expand_shape(std::int64_t i) const -> tinytc_value const & { return *op(i + 1); } private: std::int64_t expanded_mode_; @@ -470,9 +476,9 @@ class expand_inst : public standard_inst { class fuse_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } - fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); + fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, location const &lc = {}); - inline auto operand() const -> value const & { return op(0); } + inline auto operand() const -> tinytc_value const & { return *op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } @@ -483,9 +489,9 @@ class fuse_inst : public standard_inst<1, 1> { class load_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } - load_inst(value op, std::vector const &index_list, location const &lc = {}); + load_inst(tinytc_value_t op, array_view index_list, location const &lc = {}); - inline auto operand() const -> value const & { return op(0); } + inline auto operand() const -> tinytc_value const & { return *op(0); } inline auto index_list() const { return operands() | std::views::drop(1); } }; @@ -495,7 +501,7 @@ class group_id_inst : public standard_inst<0, 1> { inline group_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::group_id} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::index)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), lc}; } }; @@ -505,24 +511,24 @@ class group_size_inst : public standard_inst<0, 1> { inline group_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::group_size} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::index)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), lc}; } }; class lifetime_stop_inst : public standard_inst<1, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } - inline lifetime_stop_inst(value obj) : standard_inst{IK::lifetime_stop} { + inline lifetime_stop_inst(tinytc_value_t obj) : standard_inst{IK::lifetime_stop} { op(0) = std::move(obj); } - inline auto object() const -> value const & { return op(0); } + inline auto object() const -> tinytc_value const & { return *op(0); } }; class gemm_inst : public blas_a3_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemm_blas_a3; } - gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic = false, location const &lc = {}); + gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } inline transpose tB() const { return tB_; } @@ -534,8 +540,8 @@ class gemm_inst : public blas_a3_inst { class gemv_inst : public blas_a3_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemv_blas_a3; } - gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + gemv_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -546,18 +552,18 @@ class gemv_inst : public blas_a3_inst { class ger_inst : public blas_a3_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::ger_blas_a3; } - ger_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + ger_inst(tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, bool atomic = false, location const &lc = {}); }; class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } - inline for_inst(value from, value to, tinytc_data_type_t loop_var_type, + inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, location const &loc = {}) : for_inst{std::move(from), std::move(to), {}, loop_var_type, loc} {} - inline for_inst(value from, value to, value step, tinytc_data_type_t loop_var_type, - location const &loc = {}) + inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + tinytc_data_type_t loop_var_type, location const &loc = {}) : loop_inst{IK::for_loop, std::move(from), std::move(to), std::move(step), loop_var_type, loc} {} }; @@ -565,23 +571,24 @@ class for_inst : public loop_inst { class foreach_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } - foreach_inst(value from, value to, tinytc_data_type_t loop_var_type, location const &loc = {}); + foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, + location const &loc = {}); }; class hadamard_inst : public blas_a3_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::hadamard_blas_a3; } - hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + hadamard_inst(tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, bool atomic = false, location const &lc = {}); }; class if_inst : public standard_inst<1, dynamic, 2> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::if_; } enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; - if_inst(value condition, std::vector const &return_types = {}, + if_inst(tinytc_value_t condition, array_view return_types = {}, location const &lc = {}); - inline auto condition() const -> value const & { return op(0); } + inline auto condition() const -> tinytc_value const & { return *op(0); } inline auto then() -> tinytc_region & { return child_region(child_region_then); } inline auto then() const -> tinytc_region const & { return child_region(child_region_then); } inline auto otherwise() -> tinytc_region & { return child_region(child_region_otherwise); } @@ -597,7 +604,7 @@ class num_subgroups_inst : public standard_inst<0, 1> { inline num_subgroups_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::num_subgroups} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; } }; @@ -613,9 +620,9 @@ class parallel_inst : public standard_inst<0, 0, 1> { class size_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } - size_inst(value op, std::int64_t mode, location const &lc = {}); + size_inst(tinytc_value_t op, std::int64_t mode, location const &lc = {}); - inline auto operand() const -> value const & { return op(0); } + inline auto operand() const -> tinytc_value const & { return *op(0); } inline std::int64_t mode() const { return mode_; } private: @@ -628,7 +635,7 @@ class subgroup_id_inst : public standard_inst<0, 1> { inline subgroup_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_id} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; } }; @@ -638,7 +645,7 @@ class subgroup_local_id_inst : public standard_inst<0, 1> { inline subgroup_local_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_local_id} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; } }; @@ -648,23 +655,21 @@ class subgroup_size_inst : public standard_inst<0, 1> { inline subgroup_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_size} { loc(lc); - result(0) = make_value(scalar_data_type::get(ctx, scalar_type::i32)); + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; } }; class subview_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } - subview_inst(value op, std::vector static_offsets, - std::vector static_sizes, std::vector const &offsets, - std::vector const &sizes, location const &lc = {}); + subview_inst(tinytc_value_t op, array_view static_offsets, + array_view static_sizes, array_view offsets, + array_view sizes, location const &lc = {}); - inline auto static_offsets() const -> std::vector const & { - return static_offsets_; - } - inline auto static_sizes() const -> std::vector const & { return static_sizes_; } + inline auto static_offsets() const -> array_view { return static_offsets_; } + inline auto static_sizes() const -> array_view { return static_sizes_; } - inline auto operand() const -> value const & { return op(0); } + inline auto operand() const -> tinytc_value const & { return *op(0); } inline auto offsets() const { return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); } @@ -679,18 +684,19 @@ class store_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::store; } enum op_number { op_val = 0, op_operand = 1 }; - store_inst(value val, value op, std::vector const &index_list, location const &lc = {}); + store_inst(tinytc_value_t val, tinytc_value_t op, array_view index_list, + location const &lc = {}); - inline auto val() const -> value const & { return op(op_val); } - inline auto operand() const -> value const & { return op(op_operand); } + inline auto val() const -> tinytc_value const & { return *op(op_val); } + inline auto operand() const -> tinytc_value const & { return *op(op_operand); } inline auto index_list() const { return operands() | std::views::drop(2); } }; class sum_inst : public blas_a2_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::sum_blas_a2; } - sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); + sum_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -701,7 +707,7 @@ class sum_inst : public blas_a2_inst { class yield_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::yield; } - inline yield_inst(std::vector const &vals, location const &lc = {}) + inline yield_inst(array_view vals, location const &lc = {}) : standard_inst{IK::yield, static_cast(vals.size())} { loc(lc); for (std::size_t i = 0; i < vals.size(); ++i) { diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 0e844aec..82cb0e55 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -21,17 +21,16 @@ void ilist_callbacks::node_removed(tinytc_inst_t node) { tinytc_ins using namespace tinytc; tinytc_region::tinytc_region(array_view param_types, location const &lc) - : kind_(region_kind::mixed) { + : kind_(region_kind::mixed), params_{param_types.size()} { loc(lc); - params_.reserve(param_types.size()); - for (auto ¶m_ty : param_types) { - params_.push_back(make_value(param_ty)); - } + set_params(std::move(param_types), lc); } tinytc_region::~tinytc_region() {} -void tinytc_region::add_param(tinytc_data_type_t param_type) { - params_.push_back(make_value(param_type)); +void tinytc_region::set_params(array_view param_types, location const &lc) { + params_.resize(param_types.size()); + for (std::size_t i = 0; i < param_types.size(); ++i) { + params_[i] = tinytc_value{param_types[i], lc}; + } } - diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index d3d895bf..b6f23dfb 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -4,6 +4,7 @@ #ifndef REGION_NODE_20230908_HPP #define REGION_NODE_20230908_HPP +#include "node/value_node.hpp" #include "support/ilist.hpp" #include "support/util.hpp" #include "tinytc/tinytc.h" @@ -56,12 +57,13 @@ struct tinytc_region final { inline auto params() { return tinytc::iterator_range_wrapper{param_begin(), param_end()}; } inline auto param_begin() const { return params_.begin(); } inline auto param_end() const { return params_.end(); } - inline auto param(std::int64_t pos) const -> tinytc::value const & { return params_[pos]; } + inline auto param(std::int64_t pos) -> tinytc_value & { return params_[pos]; } + inline auto param(std::int64_t pos) const -> tinytc_value const & { return params_[pos]; } inline auto params() const { return tinytc::iterator_range_wrapper{param_begin(), param_end()}; } inline auto num_params() const noexcept -> std::int64_t { return params_.size(); } - void add_param(tinytc_data_type_t param_type); + void set_params(tinytc::array_view param_types, tinytc::location const &lc); private: static auto inst_list_offset() -> std::size_t { @@ -71,8 +73,8 @@ struct tinytc_region final { friend struct tinytc::ilist_callbacks; tinytc::region_kind kind_; - std::vector params_; tinytc::ilist insts_; + std::vector params_; tinytc::location loc_; }; diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp new file mode 100644 index 00000000..972ac408 --- /dev/null +++ b/src/node/value_node.cpp @@ -0,0 +1,7 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/value_node.hpp" + +tinytc_value::tinytc_value(tinytc_data_type_t ty, tinytc::location const &lc) + : ty_{std::move(ty)}, loc_{lc} {} diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 0f54bdf4..6a50844f 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -6,22 +6,20 @@ #include "location.hpp" #include "node/data_type_node.hpp" -#include "reference_counted.hpp" #include "tinytc/types.h" #include #include #include -struct tinytc_value final : tinytc::reference_counted { +struct tinytc_value final { public: - inline tinytc_value(tinytc_data_type_t ty, tinytc::location const &lc = {}) - : ty_{std::move(ty)}, loc_{lc} {} + tinytc_value(tinytc_data_type_t ty = nullptr, tinytc::location const &lc = {}); inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - inline tinytc_data_type_t ty() const { return ty_; } + inline auto ty() const -> tinytc_data_type_t { return ty_; } inline auto context() const -> tinytc_compiler_context_t { return ty_->context(); } diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 1f3c6ff8..6bc4890b 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -22,7 +22,7 @@ void parse_context::pop_region() { regions_.pop(); } auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } auto parse_context::has_regions() -> bool { return !regions_.empty(); } -void parse_context::val(std::string const &id, value val, location const &l) { +void parse_context::val(std::string const &id, tinytc_value &val, location const &l) { if (id_map_.empty()) { throw parser::syntax_error(l, "No active variable scope"); } @@ -33,11 +33,11 @@ void parse_context::val(std::string const &id, value val, location const &l) { throw parser::syntax_error(l, oss.str()); } } - val->loc(l); - id_map_.back()[id] = std::move(val); + val.loc(l); + id_map_.back()[id] = &val; } -value parse_context::val(std::string const &id, location const &l) { +auto parse_context::val(std::string const &id, location const &l) -> tinytc_value_t { for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { if (auto j = it->find(id); j != it->end()) { return j->second; diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index ecad2fe3..3fa5defd 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -21,8 +21,8 @@ class parse_context { inline auto program() { return program_; } inline void program(prog p) { program_ = std::move(p); } - void val(std::string const &id, value val, location const &l); - value val(std::string const &id, location const &l); + void val(std::string const &id, tinytc_value &val, location const &l); + auto val(std::string const &id, location const &l) -> tinytc_value_t; void report_error(location const &loc, std::string const &what); @@ -38,7 +38,7 @@ class parse_context { private: compiler_context compiler_ctx_; - std::vector> id_map_; + std::vector> id_map_; std::stack regions_; prog program_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 172a1bd2..244b6811 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -21,7 +21,7 @@ class parse_context; class lexer; - using int_or_val = std::variant; + using int_or_val = std::variant; using unique_ptr_to_if_inst = std::unique_ptr; } } @@ -45,7 +45,7 @@ namespace tinytc { - void check_scalar_type(compiler_context const &ctx, value &val, scalar_type const &sty, + void check_scalar_type(compiler_context const &ctx, tinytc_value_t val, scalar_type const &sty, location &loc1, location &loc2) { if (val->ty() != get_scalar(ctx, sty)) { auto loc = loc1; @@ -53,7 +53,7 @@ throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); } } - void check_type(value &val, tinytc_data_type_t ty, location &loc1, location &loc2) { + void check_type(tinytc_value_t val, tinytc_data_type_t ty, location &loc1, location &loc2) { if (val->ty() != ty) { auto loc = loc1; loc.end = loc2.end; @@ -165,12 +165,12 @@ %nterm group_type %nterm group_offset %nterm memref_or_group_type -%nterm <::tinytc::value> var +%nterm var %nterm instruction %nterm axpby_inst %nterm atomic -%nterm > optional_value_list -%nterm > value_list +%nterm > optional_value_list +%nterm > value_list %nterm barrier_inst %nterm optional_global_attr %nterm optional_local_attr @@ -179,7 +179,7 @@ %nterm ger_inst %nterm transpose %nterm for_inst -%nterm <::tinytc::value> optional_step +%nterm optional_step %nterm foreach_inst %nterm hadamard_inst %nterm if_inst @@ -577,9 +577,9 @@ for_inst: loc.end = @for_loop_var_type.end; auto inode = std::make_unique($from, $to, $optional_step, $for_loop_var_type, loc); ctx.push_scope(); - auto loop_var = inode->loop_var(); - loop_var->name($loop_var); - ctx.val($loop_var, std::move(loop_var), @loop_var); + auto &loop_var = inode->loop_var(); + loop_var.name($loop_var); + ctx.val($loop_var, loop_var, @loop_var); ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { @@ -607,9 +607,9 @@ foreach_inst: auto inode = std::make_unique($from, $to, $for_loop_var_type, loc); ctx.push_scope(); - auto loop_var = inode->loop_var(); - loop_var->name($loop_var); - ctx.val($loop_var, std::move(loop_var), @loop_var); + auto &loop_var = inode->loop_var(); + loop_var.name($loop_var); + ctx.val($loop_var, loop_var, @loop_var); ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { @@ -638,7 +638,7 @@ var_definition: } auto results = $$->result_begin(); for (std::int64_t i = 0; i < $$->num_results(); ++i) { - results[i]->name($identifier_list[i]); + results[i].name($identifier_list[i]); ctx.val($identifier_list[i], results[i], @identifier_list); } } @@ -853,12 +853,12 @@ expand_inst: try { auto static_shape = std::vector{}; static_shape.reserve($expand_shape.size()); - auto dynamic_shape = std::vector{}; + auto dynamic_shape = std::vector{}; dynamic_shape.reserve($expand_shape.size()); for (auto &s : $expand_shape) { std::visit(overloaded{ [&](std::int64_t i) { static_shape.push_back(i); }, - [&](value const &v) { + [&](tinytc_value_t v) { static_shape.push_back(dynamic); dynamic_shape.push_back(v); }, @@ -880,7 +880,7 @@ expand_inst: expand_shape: integer_constant_or_identifier[a] TIMES integer_constant_or_identifier[b] { - $$ = std::vector>{$a, $b}; + $$ = std::vector{$a, $b}; } | expand_shape TIMES integer_constant_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } ; @@ -1066,8 +1066,8 @@ subview_inst: try { auto static_offsets = std::vector{}; auto static_sizes = std::vector{}; - auto offsets = std::vector{}; - auto sizes = std::vector{}; + auto offsets = std::vector{}; + auto sizes = std::vector{}; static_offsets.reserve($optional_slice_list.size()); static_sizes.reserve($optional_slice_list.size()); offsets.reserve($optional_slice_list.size()); @@ -1075,14 +1075,14 @@ subview_inst: for (auto &s : $optional_slice_list) { std::visit(overloaded{ [&](std::int64_t i) { static_offsets.push_back(i); }, - [&](value const &v) { + [&](tinytc_value_t v) { static_offsets.push_back(dynamic); offsets.push_back(v); }, }, s.first); std::visit(overloaded{ [&](std::int64_t i) { static_sizes.push_back(i); }, - [&](value const &v) { + [&](tinytc_value_t v) { static_sizes.push_back(dynamic); sizes.push_back(v); }, diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index a674e8a7..bb056b9f 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -115,17 +115,17 @@ convert_to_opencl_pass::convert_to_opencl_pass(::tinytc_core_info const *info) declared_vars_.push_back({}); } -auto convert_to_opencl_pass::get_dope_vector(value_node *v) -> dope_vector & { - auto dv = dope_vector_.find(std::bit_cast(v)); +auto convert_to_opencl_pass::get_dope_vector(value_node const &v) -> dope_vector & { + auto dv = dope_vector_.find(std::bit_cast(&v)); if (dv == dope_vector_.end()) { - throw compilation_error(v->loc(), status::internal_compiler_error, + throw compilation_error(v.loc(), status::internal_compiler_error, "Dope vector for value is missing"); } return dv->second; } -void convert_to_opencl_pass::set_dope_vector(value_node *v, dope_vector dv) { - uintptr_t u = std::bit_cast(v); +void convert_to_opencl_pass::set_dope_vector(value_node const &v, dope_vector dv) { + uintptr_t u = std::bit_cast(&v); dope_vector_[u] = std::move(dv); } @@ -219,24 +219,24 @@ std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) static_cast(a.stack_ptr()) + t->size_in_bytes()); // no declarations are necessary as alloca only accepts fixed-size memrefs - set_dope_vector(a.result().get(), + set_dope_vector(a.result(0), dope_vector::from_value(*a.result(), [](clir::data_type, clir::var, dope_vector::type, std::int64_t) {})); return {std::move(result)}; } std::vector convert_to_opencl_pass::operator()(axpby_inst const &inst) { - auto at = get_memref_type(*inst.A()); - auto bt = get_memref_type(*inst.B()); - auto alpha_ty = get_scalar_type(*inst.alpha()); - auto beta_ty = get_scalar_type(*inst.beta()); - auto &adv = get_dope_vector(inst.A().get()); - auto &bdv = get_dope_vector(inst.B().get()); + auto at = get_memref_type(inst.A()); + auto bt = get_memref_type(inst.B()); + auto alpha_ty = get_scalar_type(inst.alpha()); + auto beta_ty = get_scalar_type(inst.beta()); + auto &adv = get_dope_vector(inst.A()); + auto &bdv = get_dope_vector(inst.B()); auto pA = inst.tA() == transpose::T && at->dim() == 2 ? 1 : 0; - auto alpha = val(*inst.alpha()); - auto beta = val(*inst.beta()); + auto alpha = val(inst.alpha()); + auto beta = val(inst.beta()); auto const inner_loop = [&](clir::block_builder &bb, clir::expr Ab, clir::expr Bb, clir::expr trip_count, std::size_t num_tiles, clir::var sg_id) { auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); @@ -262,8 +262,8 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins }); }; - auto A = val(*inst.A()); - auto B = val(*inst.B()); + auto A = val(inst.A()); + auto B = val(inst.B()); if (bt->dim() == 0) { auto bb = clir::block_builder{}; const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, A[0]); @@ -352,10 +352,10 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) } return {}; }; - auto sty = get_scalar_type(*a.a()); + auto sty = get_scalar_type(a.a()); auto v = declare(*a.result()); return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), val(*a.a()), val(*a.b()), sty))}; + make(a.operation(), val(a.a()), val(a.b()), sty))}; } std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { @@ -371,16 +371,16 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons } return {}; }; - auto sty = get_scalar_type(*a.a()); + auto sty = get_scalar_type(a.a()); auto v = declare(*a.result()); return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), val(*a.a()), sty))}; + make(a.operation(), val(a.a()), sty))}; } std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { auto v = declare(*c.result()); auto result_ty = visit(*this, *c.result()->ty()); - auto cst = cast(result_ty, val(*c.a())); + auto cst = cast(result_ty, val(c.a())); return {declaration_assignment(std::move(result_ty), std::move(v), std::move(cst))}; } @@ -404,12 +404,12 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c }; auto v = declare(*c.result()); return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), - make(c.cond(), val(*c.a()), val(*c.b())))}; + make(c.cond(), val(c.a()), val(c.b())))}; } std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { - auto v = declare(*c.result()); - auto ty = get_scalar_type(*c.result()); + auto v = declare(c.result(0)); + auto ty = get_scalar_type(c.result(0)); auto ty_bits = static_cast(size(ty) * 8); auto rhs = std::visit(overloaded{ @@ -426,12 +426,12 @@ std::vector convert_to_opencl_pass::operator()(constant_inst const & std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { auto result_var = declare(*e.result()); - auto m = get_memref_type(*e.operand()); - auto &dv = get_dope_vector(e.operand().get()); + auto m = get_memref_type(e.operand()); + auto &dv = get_dope_vector(e.operand()); auto static_shape = e.static_expand_shape(); auto dyn_shape = e.expand_shape(); - auto rhs = val(*e.operand()); + auto rhs = val(e.operand()); auto clinst = std::vector{}; clinst.emplace_back( clir::declaration_assignment(this->operator()(*m), std::move(result_var), std::move(rhs))); @@ -468,7 +468,7 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) stride.push_back(dv.stride(i)); } - set_dope_vector(e.result().get(), + set_dope_vector(e.result(0), dope_vector::from_value(*e.result(), [&](clir::data_type a, clir::var b, dope_vector::type t, std::int64_t j) { auto init = t == dope_vector::type::stride ? stride[j] : shape[j]; @@ -479,10 +479,10 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) } std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { auto result_var = declare(*f.result()); - auto m = get_memref_type(*f.operand()); - auto &dv = get_dope_vector(f.operand().get()); + auto m = get_memref_type(f.operand()); + auto &dv = get_dope_vector(f.operand()); - auto rhs = val(*f.operand()); + auto rhs = val(f.operand()); auto shape = std::vector{}; auto stride = std::vector{}; shape.reserve(m->dim()); @@ -507,7 +507,7 @@ std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { clinst.emplace_back( clir::declaration_assignment(this->operator()(*m), std::move(result_var), std::move(rhs))); - set_dope_vector(f.result().get(), + set_dope_vector(*f.result(), dope_vector::from_value(*f.result(), [&](clir::data_type a, clir::var b, dope_vector::type t, std::int64_t j) { auto init = t == dope_vector::type::stride ? stride[j] : shape[j]; @@ -518,8 +518,7 @@ std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { } std::vector convert_to_opencl_pass::operator()(load_inst const &e) { - auto op_val = e.operand(); - auto rhs = val(*op_val); + auto rhs = val(e.operand()); auto clinst = std::vector{}; @@ -531,11 +530,11 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { auto idx = val(*e.index_list().front()); rhs = rhs + idx; - auto &dv = get_dope_vector(e.operand().get()); + auto &dv = get_dope_vector(e.operand()); rhs = clir::dereference(std::move(rhs)) + dv.offset(); set_dope_vector( - e.result().get(), + *e.result(), dope_vector::from_value( *e.result(), [&](clir::data_type a, clir::var b, dope_vector::type t, std::int64_t j) { @@ -549,7 +548,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { if (static_cast(e.index_list().size()) != m.dim()) { throw compilation_error(e.loc(), status::ir_invalid_number_of_indices); } - auto &dv = get_dope_vector(e.operand().get()); + auto &dv = get_dope_vector(e.operand()); for (std::int64_t i = 0; i < m.dim(); ++i) { rhs = rhs + val(*e.index_list()[i]) * dv.stride(i); } @@ -558,7 +557,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { [&e](auto const &) { throw compilation_error(e.loc(), status::ir_expected_memref_or_group); }}, - *e.operand()->ty()); + *e.operand().ty()); auto lhs = declare(*e.result()); auto result_type = e.result()->ty(); @@ -591,20 +590,20 @@ std::vector convert_to_opencl_pass::operator()(lifetime_stop_inst co } std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { - auto a = get_memref_type(*g.A()); - auto b = get_memref_type(*g.B()); - auto c = get_memref_type(*g.C()); - auto &adv = get_dope_vector(g.A().get()); - auto &bdv = get_dope_vector(g.B().get()); - auto &cdv = get_dope_vector(g.C().get()); + auto a = get_memref_type(g.A()); + auto b = get_memref_type(g.B()); + auto c = get_memref_type(g.C()); + auto &adv = get_dope_vector(g.A()); + auto &bdv = get_dope_vector(g.B()); + auto &cdv = get_dope_vector(g.C()); auto const M = c->shape(0); auto const N = c->shape(1); auto const ak = g.tA() == transpose::T ? 0 : 1; auto const K = a->shape(ak); - auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), - get_scalar_type(*g.beta()), c->element_ty()}; + auto gemm_ty = gemm_scalar_type{get_scalar_type(g.alpha()), a->element_ty(), b->element_ty(), + get_scalar_type(g.beta()), c->element_ty()}; auto cfg = gemm_configuration{std::move(gemm_ty), g.tA(), g.tB(), @@ -630,26 +629,26 @@ std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { } has_gemm_.emplace(name); return {clir::expression_statement(clir::call( - std::move(name), {cdv.shape(0), cdv.shape(1), adv.shape(ak), val(*g.alpha()), val(*g.A()), - adv.stride(0), adv.stride(1), val(*g.B()), bdv.stride(0), bdv.stride(1), - val(*g.beta()), val(*g.C()), cdv.stride(0), cdv.stride(1)}))}; + std::move(name), {cdv.shape(0), cdv.shape(1), adv.shape(ak), val(g.alpha()), val(g.A()), + adv.stride(0), adv.stride(1), val(g.B()), bdv.stride(0), bdv.stride(1), + val(g.beta()), val(g.C()), cdv.stride(0), cdv.stride(1)}))}; } std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { - auto a = get_memref_type(*g.A()); - auto b = get_memref_type(*g.B()); - auto c = get_memref_type(*g.C()); - auto &adv = get_dope_vector(g.A().get()); - auto &bdv = get_dope_vector(g.B().get()); - auto &cdv = get_dope_vector(g.C().get()); + auto a = get_memref_type(g.A()); + auto b = get_memref_type(g.B()); + auto c = get_memref_type(g.C()); + auto &adv = get_dope_vector(g.A()); + auto &bdv = get_dope_vector(g.B()); + auto &cdv = get_dope_vector(g.C()); auto const M = c->shape(0); auto const ak = g.tA() == transpose::T ? 0 : 1; auto const K = a->shape(ak); constexpr auto N = 1; - auto gemm_ty = gemm_scalar_type{get_scalar_type(*g.alpha()), a->element_ty(), b->element_ty(), - get_scalar_type(*g.beta()), c->element_ty()}; + auto gemm_ty = gemm_scalar_type{get_scalar_type(g.alpha()), a->element_ty(), b->element_ty(), + get_scalar_type(g.beta()), c->element_ty()}; auto cfg = gemm_configuration{std::move(gemm_ty), g.tA(), transpose::N, @@ -675,27 +674,27 @@ std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { } has_gemm_.emplace(name); return {clir::expression_statement( - clir::call(std::move(name), {cdv.shape(0), 1, adv.shape(ak), val(*g.alpha()), val(*g.A()), - adv.stride(0), adv.stride(1), val(*g.B()), bdv.stride(0), 0, - val(*g.beta()), val(*g.C()), cdv.stride(0), 0}))}; + clir::call(std::move(name), {cdv.shape(0), 1, adv.shape(ak), val(g.alpha()), val(g.A()), + adv.stride(0), adv.stride(1), val(g.B()), bdv.stride(0), 0, + val(g.beta()), val(g.C()), cdv.stride(0), 0}))}; } std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { - auto at = get_memref_type(*g.A()); - auto bt = get_memref_type(*g.B()); - auto ct = get_memref_type(*g.C()); - auto &adv = get_dope_vector(g.A().get()); - auto &bdv = get_dope_vector(g.B().get()); - auto &cdv = get_dope_vector(g.C().get()); - - auto alpha = val(*g.alpha()); - auto beta = val(*g.beta()); - auto alpha_ty = get_scalar_type(*g.alpha()); - auto beta_ty = get_scalar_type(*g.beta()); - - auto A = val(*g.A()); - auto B = val(*g.B()); - auto C = val(*g.C()); + auto at = get_memref_type(g.A()); + auto bt = get_memref_type(g.B()); + auto ct = get_memref_type(g.C()); + auto &adv = get_dope_vector(g.A()); + auto &bdv = get_dope_vector(g.B()); + auto &cdv = get_dope_vector(g.C()); + + auto alpha = val(g.alpha()); + auto beta = val(g.beta()); + auto alpha_ty = get_scalar_type(g.alpha()); + auto beta_ty = get_scalar_type(g.beta()); + + auto A = val(g.A()); + auto B = val(g.B()); + auto C = val(g.C()); auto bb = clir::block_builder{}; auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", @@ -752,11 +751,11 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { auto clinst = std::vector{}; yielded_vars_.push_back(std::vector{}); - auto lv = declare(*p.loop_var()); - auto lv_ty = visit(*this, *p.loop_var()->ty()); - auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(*p.from())); - auto condition = lv < val(*p.to()); - auto step = p.step() ? clir::add_into(lv, val(*p.step())) : ++lv; + auto lv = declare(p.loop_var()); + auto lv_ty = visit(*this, *p.loop_var().ty()); + auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(p.from())); + auto condition = lv < val(p.to()); + auto step = p.has_step() ? clir::add_into(lv, val(p.step())) : ++lv; auto body = run_on_region(p.body()); clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); @@ -766,10 +765,10 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &p) { } std::vector convert_to_opencl_pass::operator()(foreach_inst const &p) { - auto lv = declare(*p.loop_var()); - auto lv_ty = visit(*this, *p.loop_var()->ty()); - auto from = val(*p.from()); - auto to = val(*p.to()); + auto lv = declare(p.loop_var()); + auto lv_ty = visit(*this, *p.loop_var().ty()); + auto from = val(p.from()); + auto to = val(p.to()); auto bb = clir::block_builder{}; auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); @@ -784,21 +783,21 @@ std::vector convert_to_opencl_pass::operator()(foreach_inst const &p } std::vector convert_to_opencl_pass::operator()(hadamard_inst const &g) { - auto at = get_memref_type(*g.A()); - auto bt = get_memref_type(*g.B()); - auto ct = get_memref_type(*g.C()); - auto &adv = get_dope_vector(g.A().get()); - auto &bdv = get_dope_vector(g.B().get()); - auto &cdv = get_dope_vector(g.C().get()); - - auto alpha = val(*g.alpha()); - auto beta = val(*g.beta()); - auto alpha_ty = get_scalar_type(*g.alpha()); - auto beta_ty = get_scalar_type(*g.beta()); - - auto A = val(*g.A()); - auto B = val(*g.B()); - auto C = val(*g.C()); + auto at = get_memref_type(g.A()); + auto bt = get_memref_type(g.B()); + auto ct = get_memref_type(g.C()); + auto &adv = get_dope_vector(g.A()); + auto &bdv = get_dope_vector(g.B()); + auto &cdv = get_dope_vector(g.C()); + + auto alpha = val(g.alpha()); + auto beta = val(g.beta()); + auto alpha_ty = get_scalar_type(g.alpha()); + auto beta_ty = get_scalar_type(g.beta()); + + auto A = val(g.A()); + auto B = val(g.B()); + auto C = val(g.C()); auto bb = clir::block_builder{}; auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); @@ -836,11 +835,11 @@ std::vector convert_to_opencl_pass::operator()(if_inst const &in) { auto clinst = std::vector{}; yielded_vars_.push_back(std::vector{}); for (auto const &r : in.results()) { - auto v = declare(*r); - clinst.emplace_back(clir::declaration(visit(*this, *r->ty()), v)); + auto v = declare(r); + clinst.emplace_back(clir::declaration(visit(*this, *r.ty()), v)); yielded_vars_.back().emplace_back(std::move(v)); } - auto ib = clir::if_selection_builder(val(*in.condition())); + auto ib = clir::if_selection_builder(val(in.condition())); ib.set_then(run_on_region(in.then())); if (!in.is_otherwise_empty()) { ib.set_otherwise(run_on_region(in.otherwise())); @@ -863,7 +862,7 @@ std::vector convert_to_opencl_pass::operator()(parallel_inst const & std::vector convert_to_opencl_pass::operator()(size_inst const &s) { auto v = declare(*s.result()); - auto &dv = get_dope_vector(s.operand().get()); + auto &dv = get_dope_vector(s.operand()); return {clir::declaration_assignment(visit(*this, *s.result()->ty()), std::move(v), dv.shape(s.mode()))}; @@ -892,10 +891,10 @@ std::vector convert_to_opencl_pass::operator()(subgroup_size_inst co std::vector convert_to_opencl_pass::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); - auto t = get_memref_type(*s.operand()); - auto &dv = get_dope_vector(s.operand().get()); + auto t = get_memref_type(s.operand()); + auto &dv = get_dope_vector(s.operand()); - auto rhs = val(*s.operand()); + auto rhs = val(s.operand()); int j = 0; auto shape_out = std::vector{}; auto stride_out = std::vector{}; @@ -935,7 +934,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s clinst.emplace_back( clir::declaration_assignment(this->operator()(*t), std::move(result_var), std::move(rhs))); - set_dope_vector(s.result().get(), + set_dope_vector(*s.result(), dope_vector::from_value(*s.result(), [&](clir::data_type a, clir::var b, dope_vector::type t, std::int64_t j) { auto init = t == dope_vector::type::stride ? stride_out[j] : shape_out[j]; @@ -946,38 +945,38 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s } std::vector convert_to_opencl_pass::operator()(store_inst const &s) { - auto ot = get_memref_type(*s.operand()); + auto ot = get_memref_type(s.operand()); if (static_cast(s.index_list().size()) != ot->dim()) { throw compilation_error(s.loc(), status::ir_invalid_number_of_indices); } - auto lhs = val(*s.operand()); - auto &dv = get_dope_vector(s.operand().get()); + auto lhs = val(s.operand()); + auto &dv = get_dope_vector(s.operand()); for (std::int64_t i = 0; i < ot->dim(); ++i) { lhs = lhs + val(*s.index_list()[i]) * dv.stride(i); } - auto rhs = val(*s.val()); + auto rhs = val(s.val()); auto st = assignment(dereference(std::move(lhs)), std::move(rhs)); return {expression_statement(std::move(st))}; } std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) { - auto at = get_memref_type(*inst.A()); - auto bt = get_memref_type(*inst.B()); - auto &adv = get_dope_vector(inst.A().get()); - auto &bdv = get_dope_vector(inst.B().get()); + auto at = get_memref_type(inst.A()); + auto bt = get_memref_type(inst.B()); + auto &adv = get_dope_vector(inst.A()); + auto &bdv = get_dope_vector(inst.B()); - auto alpha = val(*inst.alpha()); - auto beta = val(*inst.beta()); - auto alpha_ty = get_scalar_type(*inst.alpha()); - auto beta_ty = get_scalar_type(*inst.beta()); + auto alpha = val(inst.alpha()); + auto beta = val(inst.beta()); + auto alpha_ty = get_scalar_type(inst.alpha()); + auto beta_ty = get_scalar_type(inst.beta()); auto zero = clir::expr(0.0, static_cast(size(at->element_ty()) * 8)); - auto A = val(*inst.A()); - auto B = val(*inst.B()); + auto A = val(inst.A()); + auto B = val(inst.B()); auto bb = clir::block_builder{}; auto acc = bb.declare_assign(to_clir_ty(at->element_ty()), "acc", std::move(zero)); auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); @@ -1057,8 +1056,8 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } std::vector clinst; for (std::int64_t i = 0; i < in.num_operands(); ++i) { - clinst.push_back( - clir::expression_statement(clir::assignment(yielded_vars_.back()[i], val(*in.op(i))))); + auto assign_yielded_var = clir::assignment(yielded_vars_.back()[i], val(*in.op(i))); + clinst.push_back(clir::expression_statement(std::move(assign_yielded_var))); } return clinst; } @@ -1094,22 +1093,22 @@ auto convert_to_opencl_pass::run_on_function(function_node const &fn) -> clir::f // Create prototype auto fb = clir::kernel_builder(std::string(fn.name())); for (auto const &v : fn.params()) { - fb.argument(visit(*this, *v->ty()), declare(*v)); + fb.argument(visit(*this, *v.ty()), declare(v)); auto dv = visit( overloaded{[&fb, &v](memref_data_type const &) -> std::optional { return std::make_optional(dope_vector::from_value( - *v, [&](clir::data_type a, clir::var b, dope_vector::type, - std::int64_t) { fb.argument(std::move(a), std::move(b)); })); + v, [&](clir::data_type a, clir::var b, dope_vector::type, + std::int64_t) { fb.argument(std::move(a), std::move(b)); })); }, [&fb, &v](group_data_type const &) -> std::optional { return std::make_optional(dope_vector::from_value( - *v, [&](clir::data_type a, clir::var b, dope_vector::type, - std::int64_t) { fb.argument(std::move(a), std::move(b)); })); + v, [&](clir::data_type a, clir::var b, dope_vector::type, + std::int64_t) { fb.argument(std::move(a), std::move(b)); })); }, [](auto const &) { return std::nullopt; }}, - *v->ty()); + *v.ty()); if (dv) { - set_dope_vector(v.get(), std::move(*dv)); + set_dope_vector(v, std::move(*dv)); } } diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index c71103cb..6203fdf8 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -105,8 +105,8 @@ class convert_to_opencl_pass { auto run_on_function(function_node const &fn) -> clir::func; auto val(value_node const &v) -> clir::expr; - auto get_dope_vector(value_node *v) -> dope_vector &; - void set_dope_vector(value_node *v, dope_vector dv); + auto get_dope_vector(value_node const &v) -> dope_vector &; + void set_dope_vector(value_node const &v, dope_vector dv); clir::var declare(value_node const &v); auto get_memref_type(value_node const &v) const -> const memref_data_type *; static auto get_scalar_type(value_node const &v) -> scalar_type; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 6b17c5a3..e0513c3b 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -57,47 +57,47 @@ void dump_ir_pass::dump_val(value_node const &v) { /* Inst nodes */ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { - dump_val(*g.alpha()); + dump_val(g.alpha()); *os_ << ", "; - dump_val(*g.A()); + dump_val(g.A()); *os_ << ", "; - dump_val(*g.beta()); + dump_val(g.beta()); *os_ << ", "; - dump_val(*g.B()); + dump_val(g.B()); *os_ << " : "; - visit(*this, *g.alpha()->ty()); + visit(*this, *g.alpha().ty()); *os_ << ", "; - visit(*this, *g.A()->ty()); + visit(*this, *g.A().ty()); *os_ << ", "; - visit(*this, *g.beta()->ty()); + visit(*this, *g.beta().ty()); *os_ << ", "; - visit(*this, *g.B()->ty()); + visit(*this, *g.B().ty()); } void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { - dump_val(*g.alpha()); + dump_val(g.alpha()); *os_ << ", "; - dump_val(*g.A()); + dump_val(g.A()); *os_ << ", "; - dump_val(*g.B()); + dump_val(g.B()); *os_ << ", "; - dump_val(*g.beta()); + dump_val(g.beta()); *os_ << ", "; - dump_val(*g.C()); + dump_val(g.C()); *os_ << " : "; - visit(*this, *g.alpha()->ty()); + visit(*this, *g.alpha().ty()); *os_ << ", "; - visit(*this, *g.A()->ty()); + visit(*this, *g.A().ty()); *os_ << ", "; - visit(*this, *g.B()->ty()); + visit(*this, *g.B().ty()); *os_ << ", "; - visit(*this, *g.beta()->ty()); + visit(*this, *g.beta().ty()); *os_ << ", "; - visit(*this, *g.C()->ty()); + visit(*this, *g.C().ty()); } void dump_ir_pass::operator()(alloca_inst const &a) { - dump_val(*a.result()); + dump_val(a.result(0)); *os_ << " = alloca -> "; visit(*this, *a.result()->ty()); } @@ -109,21 +109,21 @@ void dump_ir_pass::operator()(axpby_inst const &a) { } void dump_ir_pass::operator()(arith_inst const &a) { - dump_val(*a.result()); + dump_val(a.result(0)); *os_ << " = arith." << to_string(a.operation()) << " "; - dump_val(*a.a()); + dump_val(a.a()); *os_ << ", "; - dump_val(*a.b()); + dump_val(a.b()); *os_ << " : "; - visit(*this, *a.a()->ty()); + visit(*this, *a.a().ty()); } void dump_ir_pass::operator()(arith_unary_inst const &a) { - dump_val(*a.result()); + dump_val(a.result(0)); *os_ << " = arith." << to_string(a.operation()) << " "; - dump_val(*a.a()); + dump_val(a.a()); *os_ << " : "; - visit(*this, *a.a()->ty()); + visit(*this, *a.a().ty()); } void dump_ir_pass::operator()(barrier_inst const &b) { @@ -137,27 +137,27 @@ void dump_ir_pass::operator()(barrier_inst const &b) { } void dump_ir_pass::operator()(cast_inst const &c) { - dump_val(*c.result()); + dump_val(c.result(0)); *os_ << " = cast "; - dump_val(*c.a()); + dump_val(c.a()); *os_ << " : "; - visit(*this, *c.a()->ty()); + visit(*this, *c.a().ty()); *os_ << " -> "; visit(*this, *c.result()->ty()); } void dump_ir_pass::operator()(compare_inst const &a) { - dump_val(*a.result()); + dump_val(a.result(0)); *os_ << " = cmp." << to_string(a.cond()) << " "; - dump_val(*a.a()); + dump_val(a.a()); *os_ << ", "; - dump_val(*a.b()); + dump_val(a.b()); *os_ << " : "; - visit(*this, *a.a()->ty()); + visit(*this, *a.a().ty()); } void dump_ir_pass::operator()(constant_inst const &c) { - dump_val(*c.result()); + dump_val(c.result(0)); *os_ << " = constant "; std::visit(overloaded{ [&](std::int64_t i) { @@ -184,9 +184,9 @@ void dump_ir_pass::operator()(constant_inst const &c) { } void dump_ir_pass::operator()(expand_inst const &e) { - dump_val(*e.result()); + dump_val(e.result(0)); *os_ << " = expand "; - dump_val(*e.operand()); + dump_val(e.operand()); *os_ << "[" << e.expanded_mode() << "->"; auto const &ses = e.static_expand_shape(); auto es = e.expand_shape(); @@ -201,42 +201,42 @@ void dump_ir_pass::operator()(expand_inst const &e) { } } *os_ << "] : "; - visit(*this, *e.operand()->ty()); + visit(*this, *e.operand().ty()); } void dump_ir_pass::operator()(fuse_inst const &f) { - dump_val(*f.result()); + dump_val(f.result(0)); *os_ << " = fuse "; - dump_val(*f.operand()); + dump_val(f.operand()); *os_ << "[" << f.from() << "," << f.to() << "]"; *os_ << " : "; - visit(*this, *f.operand()->ty()); + visit(*this, *f.operand().ty()); } void dump_ir_pass::operator()(load_inst const &e) { - dump_val(*e.result()); + dump_val(e.result(0)); *os_ << " = load "; - dump_val(*e.operand()); + dump_val(e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), [this](auto const &i) { dump_val(*i); }); *os_ << "] : "; - visit(*this, *e.operand()->ty()); + visit(*this, *e.operand().ty()); } void dump_ir_pass::operator()(group_id_inst const &g) { - dump_val(*g.result()); + dump_val(g.result(0)); *os_ << " = group_id"; } void dump_ir_pass::operator()(group_size_inst const &g) { - dump_val(*g.result()); + dump_val(g.result(0)); *os_ << " = group_size"; } void dump_ir_pass::operator()(lifetime_stop_inst const &l) { *os_ << "lifetime_stop "; - dump_val(*l.object()); + dump_val(l.object()); } void dump_ir_pass::operator()(gemm_inst const &g) { @@ -259,30 +259,30 @@ void dump_ir_pass::operator()(ger_inst const &g) { void dump_ir_pass::operator()(for_inst const &p) { *os_ << "for "; - dump_val(*p.loop_var()); + dump_val(p.loop_var()); *os_ << "="; - dump_val(*p.from()); + dump_val(p.from()); *os_ << ","; - dump_val(*p.to()); - if (p.step()) { + dump_val(p.to()); + if (p.has_step()) { *os_ << ","; - dump_val(*p.step()); + dump_val(p.step()); } *os_ << " : "; - visit(*this, *p.loop_var()->ty()); + visit(*this, *p.loop_var().ty()); *os_ << " "; dump_region(p.body()); } void dump_ir_pass::operator()(foreach_inst const &p) { *os_ << "foreach "; - dump_val(*p.loop_var()); + dump_val(p.loop_var()); *os_ << "="; - dump_val(*p.from()); + dump_val(p.from()); *os_ << ","; - dump_val(*p.to()); + dump_val(p.to()); *os_ << " : "; - visit(*this, *p.loop_var()->ty()); + visit(*this, *p.loop_var().ty()); *os_ << " "; dump_region(p.body()); } @@ -294,7 +294,7 @@ void dump_ir_pass::operator()(hadamard_inst const &g) { void dump_ir_pass::operator()(if_inst const &in) { *os_ << "if "; - dump_val(*in.condition()); + dump_val(in.condition()); *os_ << " "; dump_region(in.then()); if (!in.is_otherwise_empty()) { @@ -304,7 +304,7 @@ void dump_ir_pass::operator()(if_inst const &in) { } void dump_ir_pass::operator()(num_subgroups_inst const &sg) { - dump_val(*sg.result()); + dump_val(sg.result(0)); *os_ << " = num_subgroups"; } @@ -314,33 +314,33 @@ void dump_ir_pass::operator()(parallel_inst const &p) { } void dump_ir_pass::operator()(size_inst const &s) { - dump_val(*s.result()); + dump_val(s.result(0)); *os_ << " = size "; - dump_val(*s.operand()); + dump_val(s.operand()); *os_ << "[" << s.mode() << "]"; *os_ << " : "; - visit(*this, *s.operand()->ty()); + visit(*this, *s.operand().ty()); } void dump_ir_pass::operator()(subgroup_id_inst const &sg) { - dump_val(*sg.result()); + dump_val(sg.result(0)); *os_ << " = subgroup_id"; } void dump_ir_pass::operator()(subgroup_local_id_inst const &sg) { - dump_val(*sg.result()); + dump_val(sg.result(0)); *os_ << " = subgroup_local_id"; } void dump_ir_pass::operator()(subgroup_size_inst const &sg) { - dump_val(*sg.result()); + dump_val(sg.result(0)); *os_ << " = subgroup_size"; } void dump_ir_pass::operator()(subview_inst const &s) { - dump_val(*s.result()); + dump_val(s.result(0)); *os_ << " = subview "; - dump_val(*s.operand()); + dump_val(s.operand()); *os_ << "["; auto dyn_offsets = s.offsets(); auto dyn_sizes = s.sizes(); @@ -366,21 +366,21 @@ void dump_ir_pass::operator()(subview_inst const &s) { } *os_ << "]"; *os_ << " : "; - visit(*this, *s.operand()->ty()); + visit(*this, *s.operand().ty()); *os_ << " ; -> "; visit(*this, *s.result()->ty()); } void dump_ir_pass::operator()(store_inst const &e) { *os_ << "store "; - dump_val(*e.val()); + dump_val(e.val()); *os_ << ", "; - dump_val(*e.operand()); + dump_val(e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), [this](auto const &i) { dump_val(*i); }); *os_ << "] : "; - visit(*this, *e.operand()->ty()); + visit(*this, *e.operand().ty()); } void dump_ir_pass::operator()(sum_inst const &a) { @@ -425,9 +425,9 @@ void dump_ir_pass::run_on_function(function_node const &fn) { do_with_infix( fn.params().begin(), fn.params().end(), [this](auto const &a) { - dump_val(*a); + dump_val(a); *os_ << ": "; - visit(*this, *a->ty()); + visit(*this, *a.ty()); }, infix); *os_ << ") "; diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 58639448..24298181 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -128,14 +128,14 @@ void insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa) auto const get_rw = [](inst_node &in) -> reads_writes { auto rw = reads_writes{}; - auto const emplace_read = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty()); m) { - rw.emplace_read(m->addrspace(), v.get()); + auto const emplace_read = [&rw](value_node const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_read(m->addrspace(), &v); } }; - auto const emplace_write = [&rw](value const &v) { - if (auto *m = dyn_cast(v->ty()); m) { - rw.emplace_write(m->addrspace(), v.get()); + auto const emplace_write = [&rw](value_node const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_write(m->addrspace(), &v); } }; diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 596e225b..3fe416eb 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -22,10 +22,10 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const return {}; } - auto allocas = std::vector{}; + auto allocas = std::vector{}; for (auto &i : reg) { if (auto alloca = dyn_cast(&i); alloca != nullptr) { - allocas.emplace_back(alloca->result(0)); + allocas.emplace_back(&alloca->result(0)); } } @@ -42,14 +42,14 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const } } for (auto &v : i.results()) { - if (isa(*v->ty())) { - rgn_ops.insert(aa.root(*v)); + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); } } auto alloca_it = allocas.begin(); while (alloca_it != allocas.end()) { - if (rgn_ops.contains(alloca_it->get())) { + if (rgn_ops.contains(*alloca_it)) { prev_it = reg.insts().insert_after( prev_it, std::make_unique(*alloca_it).release()); --prev_it; diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index 5356799a..69a87119 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -19,11 +19,11 @@ void slot_tracker::set_slot(value_node const &v) { void slot_tracker::run_on_function(function_node &fn) { slot_ = 0; for (auto const &arg : fn.params()) { - set_slot(*arg); + set_slot(arg); } walk(fn, [this](inst_node const &i) { for (auto const &result : i.results()) { - set_slot(*result); + set_slot(result); } }); } diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 0bafb883..0ca9601e 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -41,14 +41,14 @@ void set_stack_ptr_pass::run_on_function(function_node &fn) { } stack_ptr = it->stop; } - allocs.insert(it, allocation{a.result().get(), stack_ptr, stack_ptr + size}); + allocs.insert(it, allocation{a.result(), stack_ptr, stack_ptr + size}); a.stack_ptr(stack_ptr); }, [&allocs](lifetime_stop_inst &s) { int num = 0; - auto v = s.object().get(); + auto v = s.object(); for (auto it = allocs.begin(); it != allocs.end();) { - if (it->value == v) { + if (it->value == &v) { it = allocs.erase(it); ++num; } else { diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp index da0a651a..7e13d417 100644 --- a/src/pass/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -22,7 +22,7 @@ namespace tinytc { -auto get_memref_type(value_node &v) { +auto get_memref_type(value_node const &v) { auto t = dyn_cast(v.ty()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); @@ -47,7 +47,7 @@ void work_group_size_pass::run_on_function(function_node &fn) { walk(fn, [&shape_set](inst_node &i) { visit( overloaded{[&shape_set](blas_a2_inst &in) { - auto b = get_memref_type(*in.B()); + auto b = get_memref_type(in.B()); if (b->dim() == 1) { shape_set.insert({b->element_ty(), {b->shape(0), 0}}); } else if (b->dim() >= 2) { @@ -55,7 +55,7 @@ void work_group_size_pass::run_on_function(function_node &fn) { } }, [&shape_set](blas_a3_inst &in) { - auto c = get_memref_type(*in.C()); + auto c = get_memref_type(in.C()); if (c->dim() == 1) { shape_set.insert({c->element_ty(), {c->shape(0), 0}}); } else if (c->dim() >= 2) { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 4f96399b..784010d2 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -110,20 +110,19 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto const static_offsets = std::vector{0, 0}; - auto const A_static_sizes = std::vector{M, K}; - auto const B_static_sizes = std::vector{K, N}; - auto const C_static_sizes = std::vector{M, N}; - auto a = bb.add( - make_subview(value{A, true}, static_offsets, A_static_sizes, {}, {}, my_loc())); - auto b = bb.add( - make_subview(value{B, true}, static_offsets, B_static_sizes, {}, {}, my_loc())); - auto c = bb.add( - make_subview(value{C, true}, static_offsets, C_static_sizes, {}, {}, my_loc())); - auto beta_val = - is_beta_nonzero ? value{beta, true} : bb.add(make_constant(0.0, ty_, my_loc())); - bb.add(make_gemm(tA_, tB_, false, value{alpha, true}, std::move(a), std::move(b), - beta_val, std::move(c), my_loc())); + auto const static_offsets = std::vector{0, 0, dynamic}; + auto const A_static_sizes = std::vector{M, K, 0}; + auto const B_static_sizes = std::vector{K, N, 0}; + auto const C_static_sizes = std::vector{M, N, 0}; + auto a = bb.add(make_subview(A, static_offsets, A_static_sizes, + array_view{gid}, {}, my_loc())); + auto b = bb.add(make_subview(B, static_offsets, B_static_sizes, + array_view{gid}, {}, my_loc())); + auto c = bb.add(make_subview(C, static_offsets, C_static_sizes, + array_view{gid}, {}, my_loc())); + auto beta_val = is_beta_nonzero ? beta : bb.add(make_constant(0.0, ty_, my_loc())); + bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta_val, + std::move(c), my_loc())); return f; }; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 1b6dc732..a486a76f 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -104,16 +104,16 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, value const &alpha, value const &A, - value const &B, bool is_beta_nonzero, value const &beta_arg, - value const &C) { + auto const body = [&](region_builder &bb, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t B, bool is_beta_nonzero, tinytc_value_t beta_arg, + tinytc_value_t C) { auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant(0.0, ty_, my_loc())); auto const static_offsets = std::vector{dynamic, 0}; - auto const offsets = std::vector{m}; + auto const offsets = array_view{m}; auto const static_gemm = [&](region_builder &bb) { auto const A_static_sizes = std::vector{M_block_size, K}; @@ -125,10 +125,10 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; - auto const dynamic_gemm = [&](region_builder &bb, value const &dyn_block_size) { + auto const dynamic_gemm = [&](region_builder &bb, tinytc_value_t dyn_block_size) { auto const A_static_sizes = std::vector{dynamic, K}; auto const C_static_sizes = std::vector{dynamic, N}; - auto const sizes = std::vector{dyn_block_size}; + auto const sizes = array_view{dyn_block_size}; auto a = bb.add( make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); auto c = bb.add( @@ -175,8 +175,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( f.set_work_group_size(wgs[0], wgs[1]); auto bb = region_builder{fn_body}; - body(bb, value{alpha, true}, value{A, true}, value{B, true}, is_beta_nonzero, - value{beta, true}, value{C, true}); + body(bb, alpha, A, B, is_beta_nonzero, beta, C); return f; }; diff --git a/src/region.cpp b/src/region.cpp index 69d5a6ce..88619ea0 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -27,7 +27,7 @@ tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, uint32_t param_ if (reg == nullptr || result == nullptr || param_no >= reg->num_params()) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { *result = reg->param(param_no).get(); }); + return exception_to_status_code([&] { *result = ®->param(param_no); }); } tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, uint32_t *result_list_size, @@ -47,7 +47,7 @@ tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, uint32_t *resu auto results = reg->param_begin(); auto const limit = std::min(num, *result_list_size); for (uint32_t i = 0; i < limit; ++i) { - result_list[i] = results[i].get(); + result_list[i] = &results[i]; } } *result_list_size = num; diff --git a/src/value.cpp b/src/value.cpp index a6c964e2..b6316a7a 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -17,34 +17,6 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, - const tinytc_location_t *lc) { - if (vl == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *vl = std::make_unique(type, get_optional(lc)).release(); }); -} - -tinytc_status_t tinytc_value_release(tinytc_value_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_value_retain(tinytc_value_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} - tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name) { if (vl == nullptr) { return tinytc_status_invalid_arguments; From a8c93723cad713a635703e18af55e5a152d8bf44 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 1 Oct 2024 11:06:41 +0200 Subject: [PATCH 033/297] Clean up api Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 28 +--- docs/api/builder_capi.yaml | 4 +- docs/api/builder_cxxapi.rst | 7 - docs/api/builder_cxxapi.yaml | 1 - docs/api/core_cxxapi.rst | 14 ++ docs/api/core_cxxapi.yaml | 2 + include/tinytc/tinytc.h | 81 +++------ include/tinytc/tinytc.hpp | 263 +++++++++++++++--------------- src/data_type.cpp | 3 +- src/inst.cpp | 38 ++--- src/recipe/small_gemm_batched.cpp | 27 ++- src/recipe/tall_and_skinny.cpp | 20 +-- src/region.cpp | 6 +- 13 files changed, 216 insertions(+), 278 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index b30c770b..c4d32f69 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -76,6 +76,8 @@ Common * :ref:`const_tinytc_region_t` + * :ref:`const_tinytc_value_t` + Common Enumerations ------------------- @@ -236,6 +238,11 @@ const_tinytc_region_t .. doxygentypedef:: const_tinytc_region_t +const_tinytc_value_t +.................... + +.. doxygentypedef:: const_tinytc_value_t + Data Type ========= @@ -375,12 +382,8 @@ Instruction * :ref:`tinytc_yield_inst_create` - * :ref:`tinytc_inst_get_region` - * :ref:`tinytc_inst_get_regions` - * :ref:`tinytc_inst_get_value` - * :ref:`tinytc_inst_get_values` * :ref:`tinytc_inst_destroy` @@ -543,21 +546,11 @@ tinytc_yield_inst_create .. doxygenfunction:: tinytc_yield_inst_create -tinytc_inst_get_region -...................... - -.. doxygenfunction:: tinytc_inst_get_region - tinytc_inst_get_regions ....................... .. doxygenfunction:: tinytc_inst_get_regions -tinytc_inst_get_value -..................... - -.. doxygenfunction:: tinytc_inst_get_value - tinytc_inst_get_values ...................... @@ -639,8 +632,6 @@ Region * :ref:`tinytc_region_add_instruction` - * :ref:`tinytc_region_get_parameter` - * :ref:`tinytc_region_get_parameters` Region Functions @@ -651,11 +642,6 @@ tinytc_region_add_instruction .. doxygenfunction:: tinytc_region_add_instruction -tinytc_region_get_parameter -........................... - -.. doxygenfunction:: tinytc_region_get_parameter - tinytc_region_get_parameters ............................ diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 3b7218ee..b17ff7e1 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -36,6 +36,7 @@ Builder C-API: - const_tinytc_inst_t - const_tinytc_prog_t - const_tinytc_region_t + - const_tinytc_value_t Data Type: function: - tinytc_group_type_get @@ -81,9 +82,7 @@ Builder C-API: - tinytc_subview_inst_create - tinytc_sum_inst_create - tinytc_yield_inst_create - - tinytc_inst_get_region - tinytc_inst_get_regions - - tinytc_inst_get_value - tinytc_inst_get_values - tinytc_inst_destroy Program: @@ -99,7 +98,6 @@ Builder C-API: Region: function: - tinytc_region_add_instruction - - tinytc_region_get_parameter - tinytc_region_get_parameters Value: function: diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index d494f17a..6c09f8a6 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -519,8 +519,6 @@ Region * :ref:`get_num_parameters` - * :ref:`get_parameter` - * :ref:`get_parameters` * Classes @@ -540,11 +538,6 @@ get_num_parameters .. doxygenfunction:: tinytc::get_num_parameters -get_parameter -............. - -.. doxygenfunction:: tinytc::get_parameter - get_parameters .............. diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 27123482..310a90c0 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -84,7 +84,6 @@ Builder C++-API: function: - tinytc::add_instruction - tinytc::get_num_parameters - - tinytc::get_parameter - tinytc::get_parameters class: - tinytc::region_builder diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 76284917..f765f2af 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -24,8 +24,12 @@ Common * Classes + * :ref:`array_view_base` + * :ref:`array_view` + * :ref:`mutable_array_view` + * :ref:`shared_handle` * :ref:`unique_handle` @@ -68,11 +72,21 @@ CHECK_STATUS_LOC Common Classes -------------- +array_view_base +............... + +.. doxygenclass:: tinytc::array_view_base + array_view .......... .. doxygenclass:: tinytc::array_view +mutable_array_view +.................. + +.. doxygenclass:: tinytc::mutable_array_view + shared_handle ............. diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index f5e7e09c..b8c32032 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -10,7 +10,9 @@ Core C++-API: - tinytc::CHECK_STATUS - tinytc::CHECK_STATUS_LOC class: + - tinytc::array_view_base - tinytc::array_view + - tinytc::mutable_array_view - tinytc::shared_handle - tinytc::unique_handle typedef: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 027c9648..cb315fb1 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -82,7 +82,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get( tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t scalar_ty, uint32_t shape_size, const int64_t *shape, uint32_t stride_size, const int64_t *stride, - const tinytc_address_space_t addrspace, const tinytc_location_t *loc); + tinytc_address_space_t addrspace, const tinytc_location_t *loc); /** * @brief Get group data type @@ -511,10 +511,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *in /** * @brief Create parallel region * - * Takes ownership of region. - * * @code - * parallel { %body } + * parallel { } * @endcode * * @param instr [out] pointer to the inst object created @@ -659,13 +657,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt /** * @brief Create for loop * - * Takes ownership of region. - * * @code - * for %loop_var = %from, %to, %step : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) - * ; type(%loop_var) == type(%step) + * for %loop_var = %from, %to, %step : loop_var_type { } + * ; loop_var_type == type(%from) + * ; loop_var_type == type(%to) + * ; loop_var_type == type(%step) * @endcode * * @param instr [out] pointer to the inst object created @@ -685,12 +681,10 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt /** * @brief Create foreach loop * - * Takes ownership of region. - * * @code - * foreach %loop_var = %from, %to : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) + * foreach %loop_var = %from, %to : loop_var_type { } + * ; loop_var_type == type(%from) + * ; loop_var_type == type(%to) * @endcode * * @param instr [out] pointer to the inst object created @@ -709,10 +703,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, t /** * @brief Create if condition * - * Takes ownership of if and else region (if given). - * * @code - * if %condition { %then } else { %otherwise } + * if %condition -> (return_type_list, ...) { } else { } * @endcode * * @param instr [out] pointer to the inst object created @@ -756,16 +748,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, */ TINYTC_EXPORT void tinytc_inst_destroy(tinytc_inst_t instr); -/** - * @brief Get value produced by instruction - * - * @param instr [in] inst object - * @param result [out] result value; may be set to nullptr if instruction does not return a value - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(tinytc_inst_t instr, tinytc_value_t *result); - /** * @brief Get values produced by instruction * @@ -773,7 +755,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(tinytc_inst_t instr, tinytc_ * the number of results * * @param instr [in] inst object - * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * @@ -783,18 +767,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list); -/** - * @brief Get child region of instruction - * - * @param instr [in] inst object - * @param region_no [in] region index - * @param result [out] result value - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_region(tinytc_inst_t instr, uint32_t region_no, - tinytc_region_t *result); - /** * @brief Get child regions of instruction * @@ -802,7 +774,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_region(tinytc_inst_t instr, uint32 * the number of results * * @param instr [in] inst object - * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * @@ -830,18 +804,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, TINYTC_EXPORT tinytc_status_t tinytc_region_add_instruction(tinytc_region_t reg, tinytc_inst_t instruction); -/** - * @brief Get region parameter - * - * @param reg [in] region object - * @param param_no [in] parameter index - * @param result [out] result value - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, uint32_t param_no, - tinytc_value_t *result); - /** * @brief Get region parameters * @@ -849,7 +811,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, u * the number of results * * @param reg [in] region object - * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greather than the number of + * results, the value is updated with the correct number of results * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * @@ -886,7 +850,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t na /** * @brief Set work-group size * - * @param fun [out] func object (must be the function definition, not the function prototype) + * @param fun [inout] function object * @param x [in] number of rows in parallel grid; must be a multiple of the subgroup size * @param y [in] number of columns in parallel grid * @@ -897,7 +861,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_func_set_work_group_size(tinytc_func_t fun, /** * @brief Set subgroup size * - * @param fun [out] func object (must be the function definition, not the function prototype) + * @param fun [inout] function object * @param sgs [in] subgroup size; the supported values need to be queried from the compute device * * @return tinytc_status_success on success and error otherwise @@ -943,7 +907,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_comp * @brief Append function to program * * The program takes ownership of the function. - * A function must not be added to multiple programs. + * A function must not be added to multiple programs nor must the user destroy the function after + * adding it to the program. * * @param prg [inout] program object * @param fun [in,pass_ownership] function object diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index ca4454d3..5c7729dd 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -286,25 +286,25 @@ template class unique_handle { //////////////////////////// /** - * @brief Stores a view on an array (pointer + size) + * @brief Base implementation of array view * * @tparam T array element type */ -template class array_view { +template class array_view_base { public: - using const_iterator = T const *; + using iterator = T *; /** * @brief Empty array view */ - array_view() = default; + array_view_base() = default; /** * @brief Single element view * * @param single the single element */ - array_view(T const &single) : data_{&single}, size_{1} {} + array_view_base(T &single) : data_{&single}, size_{1} {} /** * @brief ctor @@ -312,7 +312,7 @@ template class array_view { * @param data base pointer * @param size array size */ - array_view(T const *data, std::size_t size) : data_{data}, size_{size} {} + array_view_base(T *data, std::size_t size) : data_{data}, size_{size} {} /** * @brief ctor @@ -320,7 +320,42 @@ template class array_view { * @param begin begin pointer * @param end end pointer (not included) */ - array_view(T const *begin, T const *end) : data_{begin}, size_{end - begin} {} + array_view_base(T *begin, T *end) : data_{begin}, size_{end - begin} {} + + //! Begin iterator + auto begin() const -> iterator { return data_; } + //! End iterator + auto end() const -> iterator { return data_ + size_; } + //! Returns true if view is empty + auto empty() const -> bool { return size_ == 0; } + //! Returns array size + auto size() const -> std::size_t { return size_; } + //! Access first element; must not call when array size is 0 + auto front() const -> T & { return data_[0]; } + //! Access last element; must not call when array size is 0 + auto back() const -> T & { return data_[size_ - 1]; } + //! Get data pointer + auto data() const -> T * { return data_; } + //! Access operator + auto operator[](std::size_t n) const -> T & { return data_[n]; } + //! Convert to vector + operator std::vector>() const { + return std::vector>(data_, data_ + size_); + } + + private: + T *data_ = nullptr; + std::size_t size_ = 0; +}; + +/** + * @brief Stores an immutable view on an array (pointer + size) + * + * @tparam T array element type + */ +template class array_view : public array_view_base { + public: + using array_view_base::array_view_base; /** * @brief Convert vector to array view @@ -328,7 +363,7 @@ template class array_view { * @param vec standard vector */ array_view(std::vector const &vec) - : data_{!vec.empty() ? vec.data() : nullptr}, size_{vec.size()} {} + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} /** * @brief Convert std::array to array view @@ -337,7 +372,7 @@ template class array_view { * @param arr standard array */ template - array_view(std::array const &arr) : data_{arr.data()}, size_{arr.size()} {} + array_view(std::array const &arr) : array_view_base{arr.data(), arr.size()} {} /** * @brief Convert initializer list to array view (array_view must be rvalue) @@ -345,32 +380,45 @@ template class array_view { * @param arr initializer list */ array_view(std::initializer_list const &arr) - : data_{arr.begin() != arr.end() ? arr.begin() : nullptr}, size_{arr.size()} {} + : array_view_base{(arr.begin() != arr.end() ? arr.begin() : nullptr), arr.size()} { + } +}; - //! Begin iterator - auto begin() const -> const_iterator { return data_; } - //! End iterator - auto end() const -> const_iterator { return data_ + size_; } - //! Returns true if view is empty - auto empty() const -> bool { return size_ == 0; } - //! Returns array size - auto size() const -> std::size_t { return size_; } - //! Access first element; must not call when array size is 0 - auto front() const -> T const & { return data_[0]; } - //! Access last element; must not call when array size is 0 - auto back() const -> T const & { return data_[size_ - 1]; } - //! Get data pointer - auto data() const -> T const * { return data_; } - //! Access operator - auto operator[](std::size_t n) const -> T const & { return data_[n]; } - //! Convert to vector - operator std::vector() const { return std::vector(data_, data_ + size_); } +template array_view(T const &) -> array_view; +template array_view(T const *, std::size_t) -> array_view; +template array_view(T const *, T const *) -> array_view; - private: - T const *data_ = nullptr; - std::size_t size_ = 0; +/** + * @brief Stores a mutable view on an array (pointer + size) + * + * @tparam T array element type + */ +template class mutable_array_view : public array_view_base { + public: + using array_view_base::array_view_base; + + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + mutable_array_view(std::vector &vec) + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + mutable_array_view(std::array &arr) : array_view_base{arr.data(), arr.size()} {} }; +template mutable_array_view(T &) -> mutable_array_view; +template mutable_array_view(T *, std::size_t) -> mutable_array_view; +template mutable_array_view(T *, T *) -> mutable_array_view; + //////////////////////////// ///// Compiler context ///// //////////////////////////// @@ -481,7 +529,7 @@ inline tinytc_data_type_t get_scalar(compiler_context const &ctx, scalar_type sc inline tinytc_data_type_t get_memref(compiler_context const &ctx, scalar_type scalar_ty, array_view shape, array_view stride = {}, - const address_space addrspace = address_space::global, + address_space addrspace = address_space::global, location const &loc = {}) { tinytc_data_type_t mt; CHECK_STATUS_LOC( @@ -607,74 +655,34 @@ class inst : public unique_handle { using unique_handle::unique_handle; /** - * @brief Get result value + * @brief Get result values * - * @return Value; may be empty - */ - inline auto get_value() const -> tinytc_value_t { - tinytc_value_t result; - CHECK_STATUS(tinytc_inst_get_value(obj_, &result)); - return result; - } - - /** - * @brief Get number of result values + * May be called with empty view (vals = {}) to get the number of results. * - * @return Number of result values - */ - inline auto get_num_values() const -> std::uint32_t { - std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, nullptr)); - return result_list_size; - } - - /** - * @brief Get result values + * @param vals view on buffer that stores results * - * @return Vector of values + * @return Minimum of view size and actual number of result values */ - inline auto get_values() const -> std::vector { - std::uint32_t result_list_size = get_num_values(); - auto values = std::vector(result_list_size); - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, values.data())); - return values; + inline auto get_values(mutable_array_view vals) const -> std::uint32_t { + std::uint32_t result_list_size = vals.size(); + CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, vals.data())); + return result_list_size; } /** - * @brief Get child region + * @brief Get child regions * - * @param region_no region index + * May be called with empty view (vals = {}) to get the number of child regions. * - * @return Region - */ - inline auto get_region(std::uint32_t region_no) const -> tinytc_region_t { - tinytc_region_t result; - CHECK_STATUS(tinytc_inst_get_region(obj_, region_no, &result)); - return result; - } - - /** - * @brief Get number of child regions + * @param regs view on buffer that stores results * - * @return Number of child regions + * @return Minimum of view size and actual number of child regions */ - inline auto get_num_regions() const -> std::uint32_t { + inline auto get_regions(mutable_array_view regs) const -> std::uint32_t { std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, nullptr)); + CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, regs.data())); return result_list_size; } - - /** - * @brief Get child regions - * - * @return Vector of regions - */ - inline auto get_regions() const -> std::vector { - std::uint32_t result_list_size = get_num_regions(); - auto regions = std::vector(result_list_size); - CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, regions.data())); - return regions; - } }; //////////////////////////// @@ -692,40 +700,20 @@ inline void add_instruction(tinytc_region_t reg, inst instruction) { } /** - * @brief Get region parameter - * - * @param reg Region object - * @param region_no Region index + * @brief Get region parameters * - * @return Parameter - */ -inline auto get_parameter(tinytc_region_t reg, std::uint32_t region_no) -> tinytc_value_t { - tinytc_value_t result; - CHECK_STATUS(tinytc_region_get_parameter(reg, region_no, &result)); - return result; -} - -/** - * @brief Get number of child regions + * May be called with empty view (vals = {}) to get the number of parameters. * - * @return Number of child regions - */ -inline auto get_num_parameters(tinytc_region_t reg) -> std::uint32_t { - std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, nullptr)); - return result_list_size; -} - -/** - * @brief Get parameters + * @param reg region object + * @param params view on buffer that stores parameters * - * @return Vector of parameters + * @return Minimum of view size and actual number of parameters */ -inline auto get_parameters(tinytc_region_t reg) -> std::vector { - std::uint32_t result_list_size = get_num_parameters(reg); - auto params = std::vector(result_list_size); +inline auto get_parameters(tinytc_region_t reg, + mutable_array_view params) -> std::uint32_t { + std::uint32_t result_list_size = params.size(); CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, params.data())); - return params; + return result_list_size; } //////////////////////////// @@ -1481,8 +1469,8 @@ class region_builder { * @return Value returned by instruction; may be empty */ [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> tinytc_value_t { - auto result = i.get_value(); - if (result && name.size() > 0) { + tinytc_value_t result = nullptr; + if (i.get_values(result) > 0 && name.size() > 0) { set_name(result, name); } add_instruction(reg_, std::move(i)); @@ -1499,7 +1487,9 @@ class region_builder { */ [[maybe_unused]] inline auto add_multivalued(inst i, std::string_view name = "") -> std::vector { - auto results = i.get_values(); + auto num_results = i.get_values({}); + auto results = std::vector(static_cast(num_results)); + results.resize(i.get_values(results)); if (name.size() > 0) { int counter = 0; auto name_str = std::string{name}; @@ -1549,8 +1539,13 @@ class region_builder { tinytc_data_type_t loop_var_ty, F &&f, std::string_view loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); - auto reg = fi.get_region(0); - auto loop_var = get_parameter(reg, 0); + tinytc_region_t reg = nullptr; + fi.get_regions(reg); + tinytc_value_t loop_var = nullptr; + get_parameters(reg, loop_var); + if (!reg || !loop_var) { + throw status::internal_compiler_error; + } set_name(loop_var, loop_var_name); add_instruction(reg_, std::move(fi)); auto bb = region_builder{reg}; @@ -1571,8 +1566,13 @@ class region_builder { void foreach (tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_ty, F && f, std::string const &loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); - auto reg = fi.get_region(0); - auto loop_var = get_parameter(reg, 0); + tinytc_region_t reg = nullptr; + fi.get_regions(reg); + tinytc_value_t loop_var = nullptr; + get_parameters(reg, loop_var); + if (!reg || !loop_var) { + throw status::internal_compiler_error; + } set_name(loop_var, loop_var_name); add_instruction(reg_, std::move(fi)); auto bb = region_builder{reg}; @@ -1595,9 +1595,13 @@ class region_builder { array_view return_type_list = {}, location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); - auto r0 = ii.get_region(0); + tinytc_region_t reg = nullptr; + ii.get_regions(reg); + if (!reg) { + throw status::internal_compiler_error; + } auto results = add_multivalued(std::move(ii)); - auto bb = region_builder{r0}; + auto bb = region_builder{reg}; then(bb); return results; } @@ -1620,12 +1624,15 @@ class region_builder { array_view return_type_list = {}, location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); - auto r0 = ii.get_region(0); - auto r1 = ii.get_region(1); + auto regs = std::array{nullptr, nullptr}; + ii.get_regions(regs); + if (!regs[0] || !regs[1]) { + throw status::internal_compiler_error; + } auto results = add_multivalued(std::move(ii)); - auto bb0 = region_builder{r0}; + auto bb0 = region_builder{regs[0]}; then(bb0); - auto bb1 = region_builder{r1}; + auto bb1 = region_builder{regs[1]}; otherwise(bb1); return results; } diff --git a/src/data_type.cpp b/src/data_type.cpp index ac033ab1..d7a54f83 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -33,8 +33,7 @@ tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_c tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t scalar_ty, uint32_t shape_size, const int64_t *shape, uint32_t stride_size, - const int64_t *stride, - const tinytc_address_space_t addrspace, + const int64_t *stride, tinytc_address_space_t addrspace, const tinytc_location_t *loc) { if (dt == nullptr || ctx == nullptr || (shape_size != 0 && shape == nullptr) || (stride_size != 0 && stride == nullptr)) { diff --git a/src/inst.cpp b/src/inst.cpp index be989482..999d79cc 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -17,7 +17,6 @@ #include #include #include -#include using namespace tinytc; @@ -454,12 +453,10 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto rt = std::vector(); - rt.reserve(return_type_list_size); - for (uint32_t i = 0; i < return_type_list_size; ++i) { - rt.emplace_back(return_type_list[i]); - } - *instr = std::make_unique(condition, std::move(rt), get_optional(loc)).release(); + *instr = std::make_unique(condition, + array_view{return_type_list, return_type_list_size}, + get_optional(loc)) + .release(); }); } @@ -478,13 +475,6 @@ tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_li void tinytc_inst_destroy(tinytc_inst_t obj) { delete obj; } -tinytc_status_t tinytc_inst_get_value(tinytc_inst_t instr, tinytc_value_t *result) { - if (instr == nullptr || result == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *result = instr->result(); }); -} - tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list) { if (instr == nullptr || result_list_size == nullptr || @@ -496,11 +486,11 @@ tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_lis if (num_results > std::numeric_limits::max()) { throw std::out_of_range("too many results"); } - auto const num = static_cast(num_results); + auto num = static_cast(num_results); if (*result_list_size > 0) { + num = std::min(num, *result_list_size); auto results = instr->result_begin(); - auto const limit = std::min(num, *result_list_size); - for (uint32_t i = 0; i < limit; ++i) { + for (uint32_t i = 0; i < num; ++i) { result_list[i] = &results[i]; } } @@ -508,14 +498,6 @@ tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_lis }); } -tinytc_status_t tinytc_inst_get_region(tinytc_inst_t instr, uint32_t region_no, - tinytc_region_t *result) { - if (instr == nullptr || result == nullptr || region_no >= instr->num_child_regions()) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *result = &instr->child_region(region_no); }); -} - tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_region_t *result_list) { if (instr == nullptr || result_list_size == nullptr || @@ -527,11 +509,11 @@ tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, uint32_t *result_li if (num_results > std::numeric_limits::max()) { throw std::out_of_range("too many results"); } - auto const num = static_cast(num_results); + auto num = static_cast(num_results); if (*result_list_size > 0) { auto results = instr->child_regions_begin(); - auto const limit = std::min(num, *result_list_size); - for (uint32_t i = 0; i < limit; ++i) { + num = std::min(num, *result_list_size); + for (uint32_t i = 0; i < num; ++i) { result_list[i] = &results[i]; } } diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 784010d2..7dac902d 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -96,16 +96,13 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( {1, ldC, strideC}, address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); - auto alpha = get_parameter(fn_body, 0); - set_name(alpha, "alpha"); - auto A = get_parameter(fn_body, 1); - set_name(A, "A"); - auto B = get_parameter(fn_body, 2); - set_name(B, "B"); - auto beta = get_parameter(fn_body, 3); - set_name(beta, "beta"); - auto C = get_parameter(fn_body, 4); - set_name(C, "C"); + auto params = std::array{}; + get_parameters(fn_body, params); + set_name(params[0], "alpha"); + set_name(params[1], "A"); + set_name(params[2], "B"); + set_name(params[3], "beta"); + set_name(params[4], "C"); auto bb = region_builder{fn_body}; @@ -114,14 +111,14 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const A_static_sizes = std::vector{M, K, 0}; auto const B_static_sizes = std::vector{K, N, 0}; auto const C_static_sizes = std::vector{M, N, 0}; - auto a = bb.add(make_subview(A, static_offsets, A_static_sizes, + auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, array_view{gid}, {}, my_loc())); - auto b = bb.add(make_subview(B, static_offsets, B_static_sizes, + auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, array_view{gid}, {}, my_loc())); - auto c = bb.add(make_subview(C, static_offsets, C_static_sizes, + auto c = bb.add(make_subview(params[3], static_offsets, C_static_sizes, array_view{gid}, {}, my_loc())); - auto beta_val = is_beta_nonzero ? beta : bb.add(make_constant(0.0, ty_, my_loc())); - bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta_val, + auto beta = is_beta_nonzero ? params[4] : bb.add(make_constant(0.0, ty_, my_loc())); + bb.add(make_gemm(tA_, tB_, false, params[0], std::move(a), std::move(b), beta, std::move(c), my_loc())); return f; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index a486a76f..ee081db9 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -160,22 +160,18 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); - auto alpha = get_parameter(fn_body, 0); - set_name(alpha, "alpha"); - auto A = get_parameter(fn_body, 1); - set_name(A, "A"); - auto B = get_parameter(fn_body, 2); - set_name(B, "B"); - auto beta = get_parameter(fn_body, 3); - set_name(beta, "beta"); - auto C = get_parameter(fn_body, 4); - set_name(C, "C"); - f.set_subgroup_size(sgs); + auto params = std::array{}; + get_parameters(fn_body, params); + set_name(params[0], "alpha"); + set_name(params[1], "A"); + set_name(params[2], "B"); + set_name(params[3], "beta"); + set_name(params[4], "C"); auto const wgs = tiling.work_group_size(sgs); f.set_work_group_size(wgs[0], wgs[1]); auto bb = region_builder{fn_body}; - body(bb, alpha, A, B, is_beta_nonzero, beta, C); + body(bb, params[0], params[1], params[2], is_beta_nonzero, params[3], params[4]); return f; }; diff --git a/src/region.cpp b/src/region.cpp index 88619ea0..9f1d9ea6 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -42,11 +42,11 @@ tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, uint32_t *resu if (num_results > std::numeric_limits::max()) { throw std::out_of_range("too many results"); } - auto const num = static_cast(num_results); + auto num = static_cast(num_results); if (*result_list_size > 0) { auto results = reg->param_begin(); - auto const limit = std::min(num, *result_list_size); - for (uint32_t i = 0; i < limit; ++i) { + num = std::min(num, *result_list_size); + for (uint32_t i = 0; i < num; ++i) { result_list[i] = &results[i]; } } From 1402884722e9d458baf679d6a1eecf3fff88f617 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 1 Oct 2024 12:11:11 +0200 Subject: [PATCH 034/297] C++-style API for region and value Signed-off-by: Carsten Uphoff --- docs/api/builder_cxxapi.rst | 114 +++++----- docs/api/builder_cxxapi.yaml | 24 +-- docs/api/core_cxxapi.rst | 7 + docs/api/core_cxxapi.yaml | 1 + include/tinytc/tinytc.hpp | 339 +++++++++++++++++------------- src/codegen_tools.cpp | 35 ++- src/codegen_tools.hpp | 31 ++- src/recipe/small_gemm_batched.cpp | 20 +- src/recipe/tall_and_skinny.cpp | 25 ++- 9 files changed, 308 insertions(+), 288 deletions(-) diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 6c09f8a6..38f1c75f 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -174,6 +174,10 @@ Data Type * :ref:`to_scalar_type` +* Typedefs + + * :ref:`data_type` + * Variables * :ref:`to_scalar_type_v` @@ -204,6 +208,14 @@ to_scalar_type .. doxygenstruct:: tinytc::to_scalar_type +Data Type Typedefs +------------------ + +data_type +......... + +.. doxygentypedef:: tinytc::data_type + Data Type Variables ------------------- @@ -248,21 +260,21 @@ Instruction * :ref:`make_axpby` - * :ref:`make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&)` + * :ref:`make_arith(arithmetic,value,value,location const&)` - * :ref:`make_arith(arithmetic_unary,tinytc_value_t,location const&)` + * :ref:`make_arith(arithmetic_unary,value,location const&)` * :ref:`make_cast` * :ref:`make_cmp` - * :ref:`make_constant(std::complex\,tinytc_data_type_t,location const&)` + * :ref:`make_constant(std::complex\,data_type,location const&)` - * :ref:`make_constant(double,tinytc_data_type_t,location const&)` + * :ref:`make_constant(double,data_type,location const&)` - * :ref:`make_constant(std::int32_t,tinytc_data_type_t,location const&)` + * :ref:`make_constant(std::int32_t,data_type,location const&)` - * :ref:`make_constant(std::int64_t,tinytc_data_type_t,location const&)` + * :ref:`make_constant(std::int64_t,data_type,location const&)` * :ref:`make_expand` @@ -325,15 +337,15 @@ make_axpby .. doxygenfunction:: tinytc::make_axpby -make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) -.................................................................... +make_arith(arithmetic,value,value,location const&) +.................................................. -.. doxygenfunction:: tinytc::make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) +.. doxygenfunction:: tinytc::make_arith(arithmetic,value,value,location const&) -make_arith(arithmetic_unary,tinytc_value_t,location const&) -........................................................... +make_arith(arithmetic_unary,value,location const&) +.................................................. -.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,tinytc_value_t,location const&) +.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value,location const&) make_cast ......... @@ -345,25 +357,25 @@ make_cmp .. doxygenfunction:: tinytc::make_cmp -make_constant(std::complex,tinytc_data_type_t,location const&) -...................................................................... +make_constant(std::complex,data_type,location const&) +............................................................. -.. doxygenfunction:: tinytc::make_constant(std::complex,tinytc_data_type_t,location const&) +.. doxygenfunction:: tinytc::make_constant(std::complex,data_type,location const&) -make_constant(double,tinytc_data_type_t,location const&) -........................................................ +make_constant(double,data_type,location const&) +............................................... -.. doxygenfunction:: tinytc::make_constant(double,tinytc_data_type_t,location const&) +.. doxygenfunction:: tinytc::make_constant(double,data_type,location const&) -make_constant(std::int32_t,tinytc_data_type_t,location const&) -.............................................................. +make_constant(std::int32_t,data_type,location const&) +..................................................... -.. doxygenfunction:: tinytc::make_constant(std::int32_t,tinytc_data_type_t,location const&) +.. doxygenfunction:: tinytc::make_constant(std::int32_t,data_type,location const&) -make_constant(std::int64_t,tinytc_data_type_t,location const&) -.............................................................. +make_constant(std::int64_t,data_type,location const&) +..................................................... -.. doxygenfunction:: tinytc::make_constant(std::int64_t,tinytc_data_type_t,location const&) +.. doxygenfunction:: tinytc::make_constant(std::int64_t,data_type,location const&) make_expand ........... @@ -513,39 +525,20 @@ prog Region ====== -* Functions - - * :ref:`add_instruction` - - * :ref:`get_num_parameters` - - * :ref:`get_parameters` - * Classes - * :ref:`region_builder` - -Region Functions ----------------- - -add_instruction -............... + * :ref:`region` -.. doxygenfunction:: tinytc::add_instruction - -get_num_parameters -.................. - -.. doxygenfunction:: tinytc::get_num_parameters - -get_parameters -.............. - -.. doxygenfunction:: tinytc::get_parameters + * :ref:`region_builder` Region Classes -------------- +region +...... + +.. doxygenclass:: tinytc::region + region_builder .............. @@ -554,22 +547,15 @@ region_builder Value ===== -* Functions - - * :ref:`get_name` - - * :ref:`set_name` - -Value Functions ---------------- +* Classes -get_name -........ + * :ref:`value` -.. doxygenfunction:: tinytc::get_name +Value Classes +------------- -set_name -........ +value +..... -.. doxygenfunction:: tinytc::set_name +.. doxygenclass:: tinytc::value diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 310a90c0..9e3fa629 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -32,6 +32,8 @@ Builder C++-API: - tinytc::get_scalar struct: - tinytc::to_scalar_type + typedef: + - tinytc::data_type variable: - tinytc::to_scalar_type_v Function: @@ -43,14 +45,14 @@ Builder C++-API: function: - tinytc::make_alloca - tinytc::make_axpby - - tinytc::make_arith(arithmetic,tinytc_value_t,tinytc_value_t,location const&) - - tinytc::make_arith(arithmetic_unary,tinytc_value_t,location const&) + - tinytc::make_arith(arithmetic,value,value,location const&) + - tinytc::make_arith(arithmetic_unary,value,location const&) - tinytc::make_cast - tinytc::make_cmp - - tinytc::make_constant(std::complex,tinytc_data_type_t,location const&) - - tinytc::make_constant(double,tinytc_data_type_t,location const&) - - tinytc::make_constant(std::int32_t,tinytc_data_type_t,location const&) - - tinytc::make_constant(std::int64_t,tinytc_data_type_t,location const&) + - tinytc::make_constant(std::complex,data_type,location const&) + - tinytc::make_constant(double,data_type,location const&) + - tinytc::make_constant(std::int32_t,data_type,location const&) + - tinytc::make_constant(std::int64_t,data_type,location const&) - tinytc::make_expand - tinytc::make_for - tinytc::make_foreach @@ -81,13 +83,9 @@ Builder C++-API: class: - tinytc::prog Region: - function: - - tinytc::add_instruction - - tinytc::get_num_parameters - - tinytc::get_parameters class: + - tinytc::region - tinytc::region_builder Value: - function: - - tinytc::get_name - - tinytc::set_name + class: + - tinytc::value diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index f765f2af..ede509f2 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -30,6 +30,8 @@ Common * :ref:`mutable_array_view` + * :ref:`handle` + * :ref:`shared_handle` * :ref:`unique_handle` @@ -87,6 +89,11 @@ mutable_array_view .. doxygenclass:: tinytc::mutable_array_view +handle +...... + +.. doxygenclass:: tinytc::handle + shared_handle ............. diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index b8c32032..a9bd4ec0 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -13,6 +13,7 @@ Core C++-API: - tinytc::array_view_base - tinytc::array_view - tinytc::mutable_array_view + - tinytc::handle - tinytc::shared_handle - tinytc::unique_handle typedef: diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 5c7729dd..44032dbf 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -127,6 +127,34 @@ template inline constexpr scalar_type to_scalar_type_v = to_scalar_ // Shared / unique handle // //////////////////////////// +template class handle { + public: + //! Create empty (invalid) handle + handle() : obj_{nullptr} {} + //! Create handle from C handle + handle(T obj) : obj_(obj) {} + + //! Dereference C handle and get reference to underlying type + auto operator*() const -> std::remove_pointer_t & { return *obj_; } + //! Convert handle to C handle + auto operator->() const -> T { return obj_; } + //! Returns C handle + auto get() const -> T { return obj_; } + + //! Check whether handle is non-empty (valid) + explicit operator bool() const noexcept { return obj_ != nullptr; } + + //! Check equality + bool operator==(handle const &other) const { return obj_ == other.obj_; } + //! Check inequality + bool operator!=(handle const &other) const { return !(*this == other); } + + operator T() const { return obj_; } + + protected: + T obj_; +}; + namespace internal { //! Wraps retain / release calls for type T template struct shared_handle_traits {}; @@ -495,6 +523,9 @@ inline auto make_compiler_context() -> compiler_context { //! Check if mode i is dynamic ('?') inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } +//! Alias for tinytc_data_type_t +using data_type = tinytc_data_type_t; + /** * @brief Get a scalar data type * @@ -505,7 +536,7 @@ inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } * * @return Data type */ -inline tinytc_data_type_t get_scalar(compiler_context const &ctx, scalar_type scalar_ty) { +inline data_type get_scalar(compiler_context const &ctx, scalar_type scalar_ty) { tinytc_data_type_t st; CHECK_STATUS( tinytc_scalar_type_get(&st, ctx.get(), static_cast(scalar_ty))); @@ -526,11 +557,10 @@ inline tinytc_data_type_t get_scalar(compiler_context const &ctx, scalar_type sc * * @return Data type */ -inline tinytc_data_type_t get_memref(compiler_context const &ctx, scalar_type scalar_ty, - array_view shape, - array_view stride = {}, - address_space addrspace = address_space::global, - location const &loc = {}) { +inline data_type get_memref(compiler_context const &ctx, scalar_type scalar_ty, + array_view shape, array_view stride = {}, + address_space addrspace = address_space::global, + location const &loc = {}) { tinytc_data_type_t mt; CHECK_STATUS_LOC( tinytc_memref_type_get(&mt, ctx.get(), static_cast(scalar_ty), @@ -550,8 +580,8 @@ inline tinytc_data_type_t get_memref(compiler_context const &ctx, scalar_type sc * * @return Data type */ -inline tinytc_data_type_t get_group(compiler_context const &ctx, tinytc_data_type_t memref_ty, - std::int64_t offset = 0, location const &loc = {}) { +inline data_type get_group(compiler_context const &ctx, data_type memref_ty, + std::int64_t offset = 0, location const &loc = {}) { tinytc_data_type_t gt; CHECK_STATUS_LOC(tinytc_group_type_get(>, ctx.get(), memref_ty, offset, &loc), loc); return gt; @@ -561,28 +591,32 @@ inline tinytc_data_type_t get_group(compiler_context const &ctx, tinytc_data_typ /////////// Value ////////// //////////////////////////// -/** - * @brief Get name - * - * @param val value object - * - * @return Name as C-string - */ -inline auto get_name(tinytc_value_t val) -> char const * { - char const *name; - CHECK_STATUS(tinytc_value_get_name(val, &name)); - return name; -} +//! @brief OO-wrapper for tinytc_value_t +class value : public handle { + public: + using handle::handle; -/** - * @brief Set value name - * - * @param val value object - * @param name Name - */ -inline void set_name(tinytc_value_t val, std::string_view name) { - CHECK_STATUS(tinytc_value_set_name_n(val, name.size(), name.data())); -} + /** + * @brief Get name + * + * @return Name as C-string + */ + inline auto get_name() -> char const * { + char const *name; + CHECK_STATUS(tinytc_value_get_name(obj_, &name)); + return name; + } + + /** + * @brief Set value name + * + * @param name Name + */ + inline void set_name(std::string_view name) { + CHECK_STATUS(tinytc_value_set_name_n(obj_, name.size(), name.data())); + } +}; +static_assert(std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t)); //////////////////////////// /////////// Inst /////////// @@ -649,6 +683,8 @@ template <> struct unique_handle_traits { }; } // namespace internal +class region; + //! @brief Reference-counting wrapper for tinytc_inst_t class inst : public unique_handle { public: @@ -663,9 +699,10 @@ class inst : public unique_handle { * * @return Minimum of view size and actual number of result values */ - inline auto get_values(mutable_array_view vals) const -> std::uint32_t { + inline auto get_values(mutable_array_view vals) const -> std::uint32_t { std::uint32_t result_list_size = vals.size(); - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, vals.data())); + tinytc_value_t *vs = reinterpret_cast(vals.data()); + CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, vs)); return result_list_size; } @@ -678,9 +715,10 @@ class inst : public unique_handle { * * @return Minimum of view size and actual number of child regions */ - inline auto get_regions(mutable_array_view regs) const -> std::uint32_t { + inline auto get_regions(mutable_array_view regs) const -> std::uint32_t { std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, regs.data())); + tinytc_region_t *rl = reinterpret_cast(regs.data()); + CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, rl)); return result_list_size; } }; @@ -689,32 +727,38 @@ class inst : public unique_handle { ////////// Region ////////// //////////////////////////// -/** - * @brief Append instruction to region - * - * @param reg region object - * @param instruction instruction object - */ -inline void add_instruction(tinytc_region_t reg, inst instruction) { - CHECK_STATUS(tinytc_region_add_instruction(reg, instruction.release())); -} +//! @brief OO-wrapper for tinytc_region_t +class region : public handle { + public: + using handle::handle; -/** - * @brief Get region parameters - * - * May be called with empty view (vals = {}) to get the number of parameters. - * - * @param reg region object - * @param params view on buffer that stores parameters - * - * @return Minimum of view size and actual number of parameters - */ -inline auto get_parameters(tinytc_region_t reg, - mutable_array_view params) -> std::uint32_t { - std::uint32_t result_list_size = params.size(); - CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, params.data())); - return result_list_size; -} + /** + * @brief Append instruction to region + * + * @param instruction instruction object + */ + inline void add_instruction(inst instruction) { + CHECK_STATUS(tinytc_region_add_instruction(obj_, instruction.release())); + } + + /** + * + * @brief Get region parameters + * + * May be called with empty view (vals = {}) to get the number of parameters. + * + * @param params view on buffer that stores parameters + * + * @return Minimum of view size and actual number of parameters + */ + inline auto get_parameters(mutable_array_view params) -> std::uint32_t { + std::uint32_t result_list_size = params.size(); + tinytc_value_t *ps = reinterpret_cast(params.data()); + CHECK_STATUS(tinytc_region_get_parameters(obj_, &result_list_size, ps)); + return result_list_size; + } +}; +static_assert(std::is_standard_layout_v && sizeof(region) == sizeof(tinytc_region_t)); //////////////////////////// /////// Instructions /////// @@ -730,8 +774,7 @@ inline auto get_parameters(tinytc_region_t reg, * * @return Instruction */ -inline inst make_arith(arithmetic op, tinytc_value_t a, tinytc_value_t b, - location const &loc = {}) { +inline inst make_arith(arithmetic op, value a, value b, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_arith_inst_create(&instr, static_cast(op), a, b, &loc), loc); @@ -747,7 +790,7 @@ inline inst make_arith(arithmetic op, tinytc_value_t a, tinytc_value_t b, * * @return Instruction */ -inline inst make_arith(arithmetic_unary op, tinytc_value_t a, location const &loc = {}) { +inline inst make_arith(arithmetic_unary op, value a, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_arith_unary_inst_create(&instr, static_cast(op), a, &loc), @@ -764,7 +807,7 @@ inline inst make_arith(arithmetic_unary op, tinytc_value_t a, location const &lo * * @return Instruction */ -inline inst make_cast(tinytc_value_t a, scalar_type to_ty, location const &loc = {}) { +inline inst make_cast(value a, scalar_type to_ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_cast_inst_create(&instr, a, static_cast(to_ty), &loc), loc); @@ -781,8 +824,7 @@ inline inst make_cast(tinytc_value_t a, scalar_type to_ty, location const &loc = * * @return Instruction */ -inline inst make_cmp(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, - location const &loc = {}) { +inline inst make_cmp(cmp_condition cond, value a, value b, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_cmp_inst_create(&instr, static_cast(cond), a, b, &loc), loc); @@ -798,8 +840,7 @@ inline inst make_cmp(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, * * @return Instruction */ -inline inst make_constant(std::complex value, tinytc_data_type_t ty, - location const &loc = {}) { +inline inst make_constant(std::complex value, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_constant_inst_create_complex(&instr, value.real(), value.imag(), ty, &loc), loc); @@ -815,7 +856,7 @@ inline inst make_constant(std::complex value, tinytc_data_type_t ty, * * @return Instruction */ -inline inst make_constant(double value, tinytc_data_type_t ty, location const &loc = {}) { +inline inst make_constant(double value, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_constant_inst_create_float(&instr, value, ty, &loc), loc); return inst(instr); @@ -830,7 +871,7 @@ inline inst make_constant(double value, tinytc_data_type_t ty, location const &l * * @return Instruction */ -inline inst make_constant(std::int32_t value, tinytc_data_type_t ty, location const &loc = {}) { +inline inst make_constant(std::int32_t value, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); return inst(instr); @@ -845,7 +886,7 @@ inline inst make_constant(std::int32_t value, tinytc_data_type_t ty, location co * * @return Instruction */ -inline inst make_constant(std::int64_t value, tinytc_data_type_t ty, location const &loc = {}) { +inline inst make_constant(std::int64_t value, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); return inst(instr); @@ -859,7 +900,7 @@ inline inst make_constant(std::int64_t value, tinytc_data_type_t ty, location co * * @return Instruction */ -inline inst make_alloca(tinytc_data_type_t ty, location const &loc = {}) { +inline inst make_alloca(data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty, &loc), loc); return inst(instr); @@ -878,8 +919,8 @@ inline inst make_alloca(tinytc_data_type_t ty, location const &loc = {}) { * * @return Instruction */ -inline inst make_axpby(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t beta, tinytc_value_t B, location const &loc = {}) { +inline inst make_axpby(transpose tA, bool atomic, value alpha, value A, value beta, value B, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_axpby_inst_create(&instr, static_cast(tA), atomic, alpha, A, beta, B, &loc), @@ -898,9 +939,9 @@ inline inst make_axpby(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_v * * @return Instruction */ -inline inst make_expand(tinytc_value_t a, std::int64_t expanded_mode, +inline inst make_expand(value a, std::int64_t expanded_mode, array_view static_expand_shape, - array_view expand_shape, location const &loc = {}) { + array_view expand_shape, location const &loc = {}) { tinytc_inst_t instr; auto static_len = static_expand_shape.size(); if (static_len > std::numeric_limits::max()) { @@ -910,9 +951,9 @@ inline inst make_expand(tinytc_value_t a, std::int64_t expanded_mode, if (len > std::numeric_limits::max()) { throw std::out_of_range("expand shape too large"); } + const tinytc_value_t *es = reinterpret_cast(expand_shape.data()); CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a, expanded_mode, static_len, - static_expand_shape.data(), len, expand_shape.data(), - &loc), + static_expand_shape.data(), len, es, &loc), loc); return inst(instr); } @@ -927,8 +968,7 @@ inline inst make_expand(tinytc_value_t a, std::int64_t expanded_mode, * * @return Instruction */ -inline inst make_fuse(tinytc_value_t a, std::int64_t from, std::int64_t to, - location const &loc = {}) { +inline inst make_fuse(value a, std::int64_t from, std::int64_t to, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a, from, to, &loc), loc); return inst(instr); @@ -943,14 +983,14 @@ inline inst make_fuse(tinytc_value_t a, std::int64_t from, std::int64_t to, * * @return Instruction */ -inline inst make_load(tinytc_value_t a, array_view index_list, - location const &loc = {}) { +inline inst make_load(value a, array_view index_list, location const &loc = {}) { tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, index_list.data(), &loc), loc); + const tinytc_value_t *il = reinterpret_cast(index_list.data()); + CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, il, &loc), loc); return inst(instr); } @@ -997,9 +1037,8 @@ inline inst make_group_size(compiler_context const &ctx, location const &loc = { * * @return Instruction */ -inline inst make_gemm(transpose tA, transpose tB, bool atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, - location const &loc = {}) { +inline inst make_gemm(transpose tA, transpose tB, bool atomic, value alpha, value A, value B, + value beta, value C, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemm_inst_create(&instr, static_cast(tA), static_cast(tB), atomic, alpha, A, @@ -1022,8 +1061,7 @@ inline inst make_gemm(transpose tA, transpose tB, bool atomic, tinytc_value_t al * * @return Instruction */ -inline inst make_gemv(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, +inline inst make_gemv(transpose tA, bool atomic, value alpha, value A, value B, value beta, value C, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemv_inst_create(&instr, static_cast(tA), atomic, @@ -1045,8 +1083,8 @@ inline inst make_gemv(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_va * * @return Instruction */ -inline inst make_ger(bool atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, location const &loc = {}) { +inline inst make_ger(bool atomic, value alpha, value A, value B, value beta, value C, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); return inst(instr); @@ -1065,8 +1103,8 @@ inline inst make_ger(bool atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc * * @return Instruction */ -inline inst make_hadamard(bool atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, location const &loc = {}) { +inline inst make_hadamard(bool atomic, value alpha, value A, value B, value beta, value C, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); return inst(instr); @@ -1108,7 +1146,7 @@ inline inst make_parallel(location const &loc = {}) { * * @return Instruction */ -inline inst make_size(tinytc_value_t a, std::int64_t mode, location const &loc = {}) { +inline inst make_size(value a, std::int64_t mode, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a, mode, &loc), loc); return inst(instr); @@ -1170,10 +1208,9 @@ inline inst make_subgroup_size(compiler_context const &ctx, location const &loc * * @return Instruction */ -inline inst make_subview(tinytc_value_t a, array_view static_offset_list, - array_view static_size_list, - array_view offset_list, - array_view size_list, location const &loc = {}) { +inline inst make_subview(value a, array_view static_offset_list, + array_view static_size_list, array_view offset_list, + array_view size_list, location const &loc = {}) { tinytc_inst_t instr; if (static_offset_list.size() != static_size_list.size()) { throw std::invalid_argument( @@ -1191,9 +1228,11 @@ inline inst make_subview(tinytc_value_t a, array_view static_offse if (size_len > std::numeric_limits::max()) { throw std::out_of_range("dynamic size list too long"); } - CHECK_STATUS_LOC(tinytc_subview_inst_create( - &instr, a, static_len, static_offset_list.data(), static_size_list.data(), - offset_len, offset_list.data(), size_len, size_list.data(), &loc), + const tinytc_value_t *ol = reinterpret_cast(offset_list.data()); + const tinytc_value_t *sl = reinterpret_cast(size_list.data()); + CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a, static_len, static_offset_list.data(), + static_size_list.data(), offset_len, ol, size_len, + sl, &loc), loc); return inst(instr); } @@ -1208,14 +1247,14 @@ inline inst make_subview(tinytc_value_t a, array_view static_offse * * @return Instruction */ -inline inst make_store(tinytc_value_t val, tinytc_value_t a, array_view index_list, - location const &loc = {}) { +inline inst make_store(value val, value a, array_view index_list, location const &loc = {}) { tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val, a, len, index_list.data(), &loc), loc); + const tinytc_value_t *il = reinterpret_cast(index_list.data()); + CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val, a, len, il, &loc), loc); return inst(instr); } @@ -1232,8 +1271,8 @@ inline inst make_store(tinytc_value_t val, tinytc_value_t a, array_view(tA), atomic, alpha, A, beta, B, &loc), @@ -1252,8 +1291,8 @@ inline inst make_sum(transpose tA, bool atomic, tinytc_value_t alpha, tinytc_val * * @return Instruction */ -inline inst make_for(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_data_type_t loop_var_type, location const &loc = {}) { +inline inst make_for(value from, value to, value step, data_type loop_var_type, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, loop_var_type, &loc), loc); return inst(instr); @@ -1269,8 +1308,7 @@ inline inst make_for(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step * * @return Instruction */ -inline inst make_foreach(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, - location const &loc = {}) { +inline inst make_foreach(value from, value to, data_type loop_var_type, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from, to, loop_var_type, &loc), loc); return inst(instr); @@ -1285,7 +1323,7 @@ inline inst make_foreach(tinytc_value_t from, tinytc_value_t to, tinytc_data_typ * * @return Instruction */ -inline inst make_if(tinytc_value_t condition, array_view return_type_list = {}, +inline inst make_if(value condition, array_view return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; auto len = return_type_list.size(); @@ -1305,13 +1343,14 @@ inline inst make_if(tinytc_value_t condition, array_view ret * * @return Instruction */ -inline inst make_yield(array_view yield_list, location const &loc = {}) { +inline inst make_yield(array_view yield_list, location const &loc = {}) { tinytc_inst_t instr; auto len = yield_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("yield list too long"); } - CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yield_list.data(), &loc), loc); + const tinytc_value_t *yl = reinterpret_cast(yield_list.data()); + CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yl, &loc), loc); return inst(instr); } @@ -1338,10 +1377,10 @@ class func : public unique_handle { CHECK_STATUS(tinytc_func_set_subgroup_size(obj_, sgs)); } - auto get_body() -> tinytc_region_t { + auto get_body() -> region { tinytc_region_t body; CHECK_STATUS(tinytc_func_get_body(obj_, &body)); - return body; + return region{body}; } }; @@ -1354,7 +1393,7 @@ class func : public unique_handle { * * @return Function */ -inline func make_func(std::string_view name, array_view param_type_list, +inline func make_func(std::string_view name, array_view param_type_list, location const &loc = {}) { tinytc_func_t fun; auto len = param_type_list.size(); @@ -1395,7 +1434,7 @@ class prog : public shared_handle { * @param fun function */ inline void add_function(func fun) { - CHECK_STATUS(tinytc_prog_add_function(get(), fun.release())); + CHECK_STATUS(tinytc_prog_add_function(obj_, fun.release())); } /** @@ -1458,7 +1497,7 @@ class region_builder { * * @param reg region object */ - region_builder(tinytc_region_t reg) : reg_{reg} {} + region_builder(region reg) : reg_{reg} {} /** * @brief Add instruction @@ -1468,12 +1507,12 @@ class region_builder { * * @return Value returned by instruction; may be empty */ - [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> tinytc_value_t { - tinytc_value_t result = nullptr; + [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> value { + auto result = value{}; if (i.get_values(result) > 0 && name.size() > 0) { - set_name(result, name); + result.set_name(name); } - add_instruction(reg_, std::move(i)); + reg_.add_instruction(std::move(i)); return result; } @@ -1485,24 +1524,24 @@ class region_builder { * * @return Values returned by instruction */ - [[maybe_unused]] inline auto - add_multivalued(inst i, std::string_view name = "") -> std::vector { + [[maybe_unused]] inline auto add_multivalued(inst i, + std::string_view name = "") -> std::vector { auto num_results = i.get_values({}); - auto results = std::vector(static_cast(num_results)); + auto results = std::vector(static_cast(num_results)); results.resize(i.get_values(results)); if (name.size() > 0) { int counter = 0; auto name_str = std::string{name}; for (auto &result : results) { - set_name(result, name_str + std::to_string(counter++)); + result.set_name(name_str + std::to_string(counter++)); } } - add_instruction(reg_, std::move(i)); + reg_.add_instruction(std::move(i)); return results; } /** - * @brief Build for-loop with functor f(region_builder&, tinytc_value_t) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void * * The loop trip count is passed as second argument to the functor. * @@ -1515,13 +1554,13 @@ class region_builder { * @param loc Source code location */ template - void for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_ty, F &&f, + void for_loop(value from, value to, data_type loop_var_ty, F &&f, std::string_view loop_var_name = "", location const &loc = {}) { for_loop(std::move(from), std::move(to), nullptr, std::move(loop_var_ty), std::forward(f), std::move(loop_var_name), loc); } /** - * @brief Build for-loop with functor f(region_builder&, tinytc_value_t) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void * * The loop trip count is passed as second argument to the functor. * @@ -1535,24 +1574,23 @@ class region_builder { * @param loc Source code location */ template - void for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_data_type_t loop_var_ty, F &&f, std::string_view loop_var_name = "", - location const &loc = {}) { + void for_loop(value from, value to, value step, data_type loop_var_ty, F &&f, + std::string_view loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); - tinytc_region_t reg = nullptr; + auto reg = region{}; fi.get_regions(reg); - tinytc_value_t loop_var = nullptr; - get_parameters(reg, loop_var); + auto loop_var = value{}; + reg.get_parameters(loop_var); if (!reg || !loop_var) { throw status::internal_compiler_error; } - set_name(loop_var, loop_var_name); - add_instruction(reg_, std::move(fi)); + loop_var.set_name(loop_var_name); + reg_.add_instruction(std::move(fi)); auto bb = region_builder{reg}; f(bb, loop_var); } /** - * @brief Build foreach-loop with functor f(region_builder&, tinytc_value_t) -> void + * @brief Build foreach-loop with functor f(region_builder&, value) -> void * * @tparam F Functor type * @param from Loop variable start @@ -1563,18 +1601,18 @@ class region_builder { * @param loc Source code location */ template - void foreach (tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_ty, F && f, + void foreach (value from, value to, data_type loop_var_ty, F && f, std::string const &loop_var_name = "", location const &loc = {}) { auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); - tinytc_region_t reg = nullptr; + auto reg = region{}; fi.get_regions(reg); - tinytc_value_t loop_var = nullptr; - get_parameters(reg, loop_var); + auto loop_var = value{}; + reg.get_parameters(loop_var); if (!reg || !loop_var) { throw status::internal_compiler_error; } - set_name(loop_var, loop_var_name); - add_instruction(reg_, std::move(fi)); + loop_var.set_name(loop_var_name); + reg_.add_instruction(std::move(fi)); auto bb = region_builder{reg}; f(bb, loop_var); } @@ -1591,11 +1629,10 @@ class region_builder { * @return Returned values */ template - auto if_condition(tinytc_value_t condition, F &&then, - array_view return_type_list = {}, - location const &loc = {}) -> std::vector { + auto if_condition(value condition, F &&then, array_view return_type_list = {}, + location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); - tinytc_region_t reg = nullptr; + auto reg = region{}; ii.get_regions(reg); if (!reg) { throw status::internal_compiler_error; @@ -1620,11 +1657,11 @@ class region_builder { * @return Returned values */ template - auto ifelse(tinytc_value_t condition, F &&then, G &&otherwise, - array_view return_type_list = {}, - location const &loc = {}) -> std::vector { + auto ifelse(value condition, F &&then, G &&otherwise, + array_view return_type_list = {}, + location const &loc = {}) -> std::vector { auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); - auto regs = std::array{nullptr, nullptr}; + std::array regs = {}; ii.get_regions(regs); if (!regs[0] || !regs[1]) { throw status::internal_compiler_error; @@ -1638,7 +1675,7 @@ class region_builder { } private: - tinytc_region_t reg_; + region reg_; }; //////////////////////////// diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 05c08e03..a4eafdde 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -430,15 +430,14 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, } } -void tile_loop_by_sgs_new(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, - sgs_loop_body_builder_new const &body) { +void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_new const &body) { tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, std::move(sg_id), body); } void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, + int num_tiles, value sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); std::int64_t blocks = loop_trip_count / sgs; @@ -455,8 +454,7 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); bb.for_loop( std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs); }, - "block"); + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }, "block"); } if (rem > 0) { @@ -466,9 +464,8 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co } } -void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, - sgs_loop_body_builder_new const &body) { +void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); auto c_sgs = bb.add(make_constant(sgs, index_ty)); auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); @@ -483,7 +480,7 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_c auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop( std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs); }, "block"); + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }, "block"); auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { @@ -495,15 +492,15 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_c }); } -void tile_loop_uniformly_new(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, - int num_tiles, tinytc_value_t sg_id, +void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, + int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, num_tiles, std::move(sg_id), body); } void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, - int block_size, int num_tiles, tinytc_value_t sg_id, + int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); // Find minimum number of blocks such that the block sizes are smaller or equal block_size @@ -529,7 +526,7 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); bb.for_loop( std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs_1); }, "block"); + [&](region_builder &bb, value block) { body(bb, block, c_bs_1); }, "block"); } auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); @@ -538,11 +535,11 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); bb.for_loop( std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, c_bs); }, "block"); + [&](region_builder &bb, value block) { body(bb, block, c_bs); }, "block"); } -void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, - int block_size, int num_tiles, tinytc_value_t sg_id, +void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, int block_size, + int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); auto c1 = bb.add(make_constant(1, index_ty)); @@ -570,7 +567,7 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_tri auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); bb.for_loop( std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs_1); }, "block"); + [&](region_builder &bb, value block) { body(bb, block, bs_1); }, "block"); auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); @@ -581,7 +578,7 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_tri auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); bb.for_loop( std::move(block_start), loop_trip_count, std::move(step), index_ty, - [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs); }, "block"); + [&](region_builder &bb, value block) { body(bb, block, bs); }, "block"); } } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 26e29aaf..2241d7cc 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -122,29 +122,24 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, clir::expr beta, core_config const &core_cfg); -using sgs_loop_body_builder_new = - std::function; -using uniform_loop_body_builder_new = - std::function; - -void tile_loop_by_sgs_new(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, - sgs_loop_body_builder_new const &body); +using sgs_loop_body_builder_new = std::function; +using uniform_loop_body_builder_new = std::function; + +void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_new const &body); void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, + int num_tiles, value sg_id, sgs_loop_body_builder_new const &body); -void tile_loop_by_sgs_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, - int num_tiles, tinytc_value_t sg_id, - sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_new const &body); -void tile_loop_uniformly_new(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, - int num_tiles, tinytc_value_t sg_id, - uniform_loop_body_builder_new const &body); +void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, + int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, - int block_size, int num_tiles, tinytc_value_t sg_id, + int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); -void tile_loop_uniformly_new_dynamic(region_builder &bb, tinytc_value_t loop_trip_count, - int block_size, int num_tiles, tinytc_value_t sg_id, +void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, int block_size, + int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); } // namespace tinytc diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 7dac902d..30700601 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -96,13 +96,13 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( {1, ldC, strideC}, address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); - auto params = std::array{}; - get_parameters(fn_body, params); - set_name(params[0], "alpha"); - set_name(params[1], "A"); - set_name(params[2], "B"); - set_name(params[3], "beta"); - set_name(params[4], "C"); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("B"); + params[3].set_name("beta"); + params[4].set_name("C"); auto bb = region_builder{fn_body}; @@ -112,11 +112,11 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const B_static_sizes = std::vector{K, N, 0}; auto const C_static_sizes = std::vector{M, N, 0}; auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, my_loc())); auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, my_loc())); auto c = bb.add(make_subview(params[3], static_offsets, C_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, my_loc())); auto beta = is_beta_nonzero ? params[4] : bb.add(make_constant(0.0, ty_, my_loc())); bb.add(make_gemm(tA_, tB_, false, params[0], std::move(a), std::move(b), beta, std::move(c), my_loc())); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index ee081db9..ba8a0c68 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -104,16 +104,15 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, bool is_beta_nonzero, tinytc_value_t beta_arg, - tinytc_value_t C) { + auto const body = [&](region_builder &bb, value alpha, value A, value B, + bool is_beta_nonzero, value beta_arg, value C) { auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant(0.0, ty_, my_loc())); auto const static_offsets = std::vector{dynamic, 0}; - auto const offsets = array_view{m}; + auto const offsets = array_view{m}; auto const static_gemm = [&](region_builder &bb) { auto const A_static_sizes = std::vector{M_block_size, K}; @@ -125,10 +124,10 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; - auto const dynamic_gemm = [&](region_builder &bb, tinytc_value_t dyn_block_size) { + auto const dynamic_gemm = [&](region_builder &bb, value dyn_block_size) { auto const A_static_sizes = std::vector{dynamic, K}; auto const C_static_sizes = std::vector{dynamic, N}; - auto const sizes = array_view{dyn_block_size}; + auto const sizes = array_view{dyn_block_size}; auto a = bb.add( make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); auto c = bb.add( @@ -160,13 +159,13 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); - auto params = std::array{}; - get_parameters(fn_body, params); - set_name(params[0], "alpha"); - set_name(params[1], "A"); - set_name(params[2], "B"); - set_name(params[3], "beta"); - set_name(params[4], "C"); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("B"); + params[3].set_name("beta"); + params[4].set_name("C"); auto const wgs = tiling.work_group_size(sgs); f.set_work_group_size(wgs[0], wgs[1]); From 592bb68f4e16027924af1b8dea6c22a9161b9457 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 1 Oct 2024 17:54:09 +0200 Subject: [PATCH 035/297] Continue on lower linalg pass; fixes Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 2 +- include/tinytc/tinytc.hpp | 7 +- src/CMakeLists.txt | 4 +- src/codegen_tools.cpp | 8 +- src/compiler.cpp | 12 +- src/inst.cpp | 8 +- src/node/inst_node.cpp | 9 +- src/node/inst_node.hpp | 7 +- src/parser/parser_impl.yy | 4 +- src/pass/constant_propagation.cpp | 88 ++++++++---- src/pass/constant_propagation.hpp | 28 +--- src/pass/dump_ir.cpp | 3 + src/pass/lower_linalg.cpp | 226 ++++++++++-------------------- src/pass/lower_linalg.hpp | 41 +----- src/pass/slot_tracker.cpp | 7 +- src/pass/slot_tracker.hpp | 2 +- src/passes.def | 2 + src/prog.cpp | 2 +- src/support/ilist_base.hpp | 8 +- src/support/walk.hpp | 40 +++++- test/opt/constant-propagation.ir | 13 ++ 21 files changed, 232 insertions(+), 289 deletions(-) create mode 100644 test/opt/constant-propagation.ir diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index cb315fb1..eb182bd0 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -201,7 +201,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *inst * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, + tinytc_data_type_t to_ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 44032dbf..9b86faf5 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -716,7 +716,7 @@ class inst : public unique_handle { * @return Minimum of view size and actual number of child regions */ inline auto get_regions(mutable_array_view regs) const -> std::uint32_t { - std::uint32_t result_list_size = 0; + std::uint32_t result_list_size = regs.size(); tinytc_region_t *rl = reinterpret_cast(regs.data()); CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, rl)); return result_list_size; @@ -807,10 +807,9 @@ inline inst make_arith(arithmetic_unary op, value a, location const &loc = {}) { * * @return Instruction */ -inline inst make_cast(value a, scalar_type to_ty, location const &loc = {}) { +inline inst make_cast(value a, data_type to_ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_cast_inst_create(&instr, a, static_cast(to_ty), &loc), loc); + CHECK_STATUS_LOC(tinytc_cast_inst_create(&instr, a, to_ty, &loc), loc); return inst(instr); } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 583eec8a..6741cc52 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -41,13 +41,13 @@ set(SOURCES parser/parse_context.cpp parser.cpp pass/check_ir.cpp - #pass/constant_propagation.cpp + pass/constant_propagation.cpp pass/convert_to_opencl.cpp pass/dump_cfg.cpp pass/dump_ir.cpp pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp - #pass/lower_linalg.cpp + pass/lower_linalg.cpp pass/slot_tracker.cpp pass/stack.cpp pass/work_group_size.cpp diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index a4eafdde..bb64c3d7 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -449,7 +449,7 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); auto c_rem = bb.add(make_constant(rem, index_ty)); - auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); if (blocks > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); bb.for_loop( @@ -475,7 +475,7 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, c_sgs)); auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, c_sgs)); - auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop( @@ -521,7 +521,7 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); auto c_loop_trip_count = bb.add(make_constant(loop_trip_count, index_ty)); - auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); if (rem > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); bb.for_loop( @@ -561,7 +561,7 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); rem->name("rem"); - auto sg_id_index = bb.add(make_cast(sg_id, scalar_type::index)); + auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); diff --git a/src/compiler.cpp b/src/compiler.cpp index 64d9a082..19bdd80a 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -5,11 +5,13 @@ #include "error.hpp" #include "node/program_node.hpp" #include "pass/check_ir.hpp" +#include "pass/constant_propagation.hpp" #include "pass/convert_to_opencl.hpp" #include "pass/dump_cfg.hpp" #include "pass/dump_ir.hpp" #include "pass/insert_barrier.hpp" #include "pass/insert_lifetime_stop.hpp" +#include "pass/lower_linalg.hpp" #include "pass/stack.hpp" #include "pass/work_group_size.hpp" #include "passes.hpp" @@ -85,13 +87,11 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ run_function_pass(check_ir_pass{}, *prg); run_function_pass(insert_lifetime_stop_pass{}, *prg); run_function_pass(set_stack_ptr_pass{}, *prg); - // insert_barriers(*prg); + run_function_pass(insert_barrier_pass{}, *prg); run_function_pass(work_group_size_pass{info}, *prg); - // lower_linalg(*prg, *info); - //run_function_pass(dump_ir_pass{std::cout}, *prg); - // propagate_constants(*prg); - // dump_ir(std::cout, *prg); - // opencl + // run_function_pass(lower_linalg_pass{info}, *prg); + // run_function_pass(constant_propagation_pass{info}, *prg); + // opencl auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); clir::make_names_unique(ast); diff --git a/src/inst.cpp b/src/inst.cpp index 999d79cc..4143a72f 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -121,14 +121,12 @@ tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arit } tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, const tinytc_location_t *loc) { + tinytc_data_type_t to_ty, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *instr = std::make_unique(a, enum_cast(to_ty), get_optional(loc)) - .release(); - }); + return exception_to_status_code( + [&] { *instr = std::make_unique(a, to_ty, get_optional(loc)).release(); }); } tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 79334191..c08fd3dc 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -177,13 +177,16 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 result(0) = value_node{at, lc}; } -cast_inst::cast_inst(tinytc_value_t a, scalar_type to_ty, location const &lc) +cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc) : standard_inst{IK::cast} { op(op_a) = std::move(a); loc(lc); - auto result_ty = scalar_data_type::get(op(op_a)->context(), to_ty); - result(0) = value_node{result_ty, lc}; + if (!isa(*to_ty)) { + throw compilation_error(lc, status::ir_expected_scalar); + } + + result(0) = value_node{to_ty, lc}; } compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 1fca1235..234dd68b 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -319,10 +319,15 @@ class blas_a3_inst : public standard_inst<5, 0> { inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } + inline auto alpha() -> tinytc_value & { return *op(op_alpha); } inline auto alpha() const -> tinytc_value const & { return *op(op_alpha); } + inline auto A() -> tinytc_value & { return *op(op_A); } inline auto A() const -> tinytc_value const & { return *op(op_A); } + inline auto B() -> tinytc_value & { return *op(op_B); } inline auto B() const -> tinytc_value const & { return *op(op_B); } + inline auto beta() -> tinytc_value & { return *op(op_beta); } inline auto beta() const -> tinytc_value const & { return *op(op_beta); } + inline auto C() -> tinytc_value & { return *op(op_C); } inline auto C() const -> tinytc_value const & { return *op(op_C); } protected: @@ -420,7 +425,7 @@ class cast_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } enum op_number { op_a = 0 }; - cast_inst(tinytc_value_t a, scalar_type to_ty, location const &lc = {}); + cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc = {}); inline auto a() const -> tinytc_value const & { return *op(op_a); } }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 244b6811..20f39c3a 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -781,8 +781,8 @@ arith_unary_inst: cast_inst: - CAST var[a] COLON scalar_type[from] RETURNS scalar_type[to] { - check_scalar_type(ctx.cctx(), $a, $from, @a, @from); + CAST var[a] COLON data_type[from] RETURNS data_type[to] { + check_type($a, $from, @a, @from); try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; } catch (compilation_error const &e) { diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index d6a589fa..61682f1b 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -3,33 +3,66 @@ #include "pass/constant_propagation.hpp" #include "error.hpp" -#include "node/data_type_node.hpp" -#include "node/value_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" +#include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include namespace tinytc { +class constant_evaluator { + public: + auto operator()(inst_node &) -> inst; + // auto operator()(arith_inst &) -> inst; + auto operator()(size_inst &in) -> inst; + + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; +}; + +auto constant_evaluator::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + /* Inst nodes */ -void constant_propagation::operator()(inst_node &in) { - for (auto &op : in.operands()) { - if (op) { - uintptr_t u = std::bit_cast(op.get()); - if (auto kc = known_constants_.find(u); kc != known_constants_.end()) { - op = kc->second; - } - } +auto constant_evaluator::operator()(inst_node &in) -> inst { + // for (auto &op : in.operands()) { + // if (op) { + // uintptr_t u = std::bit_cast(op.get()); + // if (auto kc = known_constants_.find(u); kc != known_constants_.end()) { + // op = kc->second; + //} + //} + //} + return inst{}; +} + +auto constant_evaluator::operator()(size_inst &in) -> inst { + auto ct = get_memref_type(in.operand()); + + auto mode_size = ct->shape(in.mode()); + if (!is_dynamic_value(mode_size)) { + return make_constant( + mode_size, scalar_data_type::get(in.operand().context(), scalar_type::index), in.loc()); } + + return inst{}; } -void constant_propagation::operator()(arith_inst &arith) { +/*auto constant_propagation::operator()(arith_inst &arith) -> inst { this->operator()(static_cast(arith)); - auto const &a = arith.a(); - auto const &b = arith.b(); + auto &a = arith.a(); + auto &b = arith.b(); auto at = dyn_cast(a->ty().get()); if (at == nullptr) { @@ -148,25 +181,18 @@ void constant_propagation::operator()(arith_inst &arith) { } } } -} - -void constant_propagation::operator()(parallel_inst &p) { visit(*this, *p.body()); } - -/* Region nodes */ -void constant_propagation::operator()(region_node &b) { - for (auto &s : b.insts()) { - visit(*this, *s); - } -} - -/* Function nodes */ -void constant_propagation::operator()(function_node &fn) { visit(*this, *fn.body()); } +}*/ -/* Program nodes */ -void constant_propagation::operator()(program &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } +void constant_propagation_pass::run_on_function(function_node &fn) { + walk(fn, [&](region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + auto known_constant = visit(constant_evaluator{}, *it); + if (known_constant) { + it = reg.insts().erase(it); + it = reg.insts().insert(it, known_constant.release()); + } + } + }); } } // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp index 502f5eaa..59d8d68a 100644 --- a/src/pass/constant_propagation.hpp +++ b/src/pass/constant_propagation.hpp @@ -4,35 +4,13 @@ #ifndef CONSTANT_PROPAGATION_20240807_HPP #define CONSTANT_PROPAGATION_20240807_HPP -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" - -#include -#include -#include +#include "tinytc/types.h" namespace tinytc { -class constant_propagation { +class constant_propagation_pass { public: - /* Inst nodes */ - void operator()(inst_node &); - void operator()(arith_inst &arith); - void operator()(parallel_inst &p); - - /* Region nodes */ - void operator()(region_node &b); - - /* Func nodes */ - void operator()(function_node &fn); - - /* Program nodes */ - void operator()(program &p); - - private: - std::unordered_map known_constants_; + void run_on_function(::tinytc_func &fn); }; } // namespace tinytc diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index e0513c3b..50fbac4d 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -419,6 +419,9 @@ void dump_ir_pass::dump_region(region_node const ®) { } void dump_ir_pass::run_on_function(function_node const &fn) { + tracker_ = slot_tracker{}; + tracker_.run_on_function(fn); + *os_ << "func @" << fn.name() << "("; std::string infix = ",\n "; infix += std::string(fn.name().size(), ' '); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index e8eb357d..7107037b 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -4,191 +4,107 @@ #include "pass/lower_linalg.hpp" #include "codegen_tools.hpp" #include "error.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" #include "support/casting.hpp" #include "support/visit.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" #include "tinytc/tinytc.hpp" namespace tinytc { -auto lower_linalg_pass::get_memref_type(value_node const &v) const -> const memref_data_type * { - auto t = dyn_cast(v.ty().get()); - if (t == nullptr) { - throw compilation_error(v.loc(), status::ir_expected_memref); - } - return t; -} +class linalg_generator { + public: + linalg_generator(local_tiling tiling, core_config core_cfg) + : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} + auto operator()(inst_node &) -> inst { return inst{}; } + auto operator()(ger_inst &g) -> inst; -lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { - if (info_ == nullptr) { - throw std::invalid_argument("info must not be nullptr"); - } -} - -/* Data type nodes */ -// bool lower_linalg_pass::operator()(void_data_type &) { return false; } -// bool lower_linalg_pass::operator()(group_data_type &b) { return visit(*this, *b.ty()); } -// bool lower_linalg_pass::operator()(memref_data_type &m) { -// return m.addrspace() == clir::address_space::local_t; -//} -// bool lower_linalg_pass::operator()(scalar_data_type &) { return false; } - -//[> Value nodes <] -// value_node *lower_linalg_pass::operator()(float_imm &) { return nullptr; } -// value_node *lower_linalg_pass::operator()(int_imm &) { return nullptr; } -// value_node *lower_linalg_pass::operator()(val &v) { -// if (visit(*this, *v.ty())) { -// return &v; -//} -// return nullptr; -//} + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; -/* Inst nodes */ -inst lower_linalg_pass::operator()(inst_node &) { return inst{nullptr}; } + local_tiling tiling_ = {}; + core_config core_cfg_ = {}; +}; -inst lower_linalg_pass::operator()(loop_inst &p) { - visit(*this, *p.body()); - return inst{nullptr}; -} - -inst lower_linalg_pass::operator()(if_inst &in) { - visit(*this, *in.then()); - if (in.has_otherwise()) { - visit(*this, *in.otherwise()); +auto linalg_generator::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); } - return inst{nullptr}; + return t; } -inst lower_linalg_pass::operator()(parallel_inst &p) { - visit(*this, *p.body()); - return inst{nullptr}; -} +auto linalg_generator::operator()(ger_inst &g) -> inst { + auto parallel = make_parallel(g.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; -inst lower_linalg_pass::operator()(ger_inst &g) { - // auto at = get_memref_type(*g.A()); - // auto bt = get_memref_type(*g.B()); - auto ct = get_memref_type(*g.C()); + auto ctx = compiler_context{g.alpha().context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); - auto bb = region_builder{}; - auto sgid = bb.add(make_subgroup_id(g.loc())); - auto m_tiles_imm = make_imm(tiling_.m_tiles(), g.loc()); - auto sg_n = bb.add(make_arith(arithmetic::div, sgid, m_tiles_imm, g.loc())); - auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, m_tiles_imm, g.loc())); - auto m = bb.add(make_subgroup_local_id(g.loc())); - auto m_index = bb.add(make_cast(m, scalar_type::index, g.loc())); + auto sgid = bb.add(make_subgroup_id(ctx, g.loc())); + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, g.loc())); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, g.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, g.loc())); + auto m = bb.add(make_subgroup_local_id(ctx, g.loc())); + auto m_index = bb.add(make_cast(m, index_ty, g.loc())); - auto c_shape1 = is_dynamic_value(ct->shape(1)) ? bb.add(make_size(g.C(), 1, g.loc())) - : make_index(ct->shape(1), g.loc()); - auto c_shape0 = is_dynamic_value(ct->shape(0)) ? bb.add(make_size(g.C(), 0, g.loc())) - : make_index(ct->shape(0), g.loc()); + auto c_shape0 = bb.add(make_size(&g.C(), 0, g.loc())); + auto c_shape1 = bb.add(make_size(&g.C(), 1, g.loc())); tile_loop_uniformly_new( - bb, std::move(c_shape1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), + bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, [&](region_builder &bb, value block, value trip_count) { - bb.for_loop(scalar_type::index, make_index(0, g.loc()), trip_count, - [&](region_builder &bb, value const &n) { - auto nn = bb.add(make_arith(arithmetic::add, block, n, g.loc())); - auto b = bb.add(make_load(g.B(), {nn}, g.loc())); - b->name("b"); - tile_loop_by_sgs_new( - bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, - [&](region_builder &bb, value const &block, bool is_remainder, - value const &inner_trip_count) { - auto mm = bb.add( - make_arith(arithmetic::add, block, m_index, g.loc())); - auto a = bb.add(make_load(g.A(), {mm}, g.loc())); - a->name("a"); - auto ab = bb.add(make_arith(arithmetic::mul, a, b, g.loc())); - bb.add(make_store(ab, g.C(), {mm, nn}, g.loc())); - }); - }); + auto zero = bb.add(make_constant(0, index_ty)); + bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { + auto nn = bb.add(make_arith(arithmetic::add, block, n, g.loc())); + auto b = bb.add(make_load(&g.B(), {nn}, g.loc())); + b->name("b"); + tile_loop_by_sgs_new( + bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value block, bool, value) { + auto mm = bb.add(make_arith(arithmetic::add, block, m_index, g.loc())); + auto a = bb.add(make_load(&g.A(), {mm}, g.loc())); + a->name("a"); + auto ab = bb.add(make_arith(arithmetic::mul, a, b, g.loc())); + bb.add(make_store(ab, &g.C(), {mm, nn}, g.loc())); + }); + }); }); - return make_parallel(bb.get_product(), g.loc()); - - /*auto alpha = visit(*this, *g.alpha()); - auto beta = visit(*this, *g.beta()); - auto alpha_ty = get_scalar_type(*g.alpha()->ty()); - auto beta_ty = get_scalar_type(*g.beta()->ty()); - auto A = visit(*this, *g.A()); - auto B = visit(*this, *g.B()); - auto C = visit(*this, *g.C()); - - auto bb = clir::block_builder{}; - auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", - clir::get_sub_group_id() / tiling_.m_tiles()); - auto sg_m = bb.declare_assign(clir::generic_uint(), "sg_m", - clir::get_sub_group_id() % tiling_.m_tiles()); - tile_loop_uniformly( - bb, cdv.shape(1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), - [&](clir::block_builder &bb, clir::expr block, clir::expr trip_count) { - auto n = clir::var("n"); - bb.add(clir::for_loop_builder(clir::declaration_assignment(clir::generic_int(), n, - 0), n < std::move(trip_count), ++n) .body([&](clir::block_builder &bb) { auto b = - bb.declare_assign(to_clir_ty(bt->element_ty()), "b", B + (block + n) * bdv.stride(0)); auto - Cb = bb.declare_assign(this->operator()(*ct), "Cb", C + (block + n) * cdv.stride(1)); auto m - = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - tile_loop_by_sgs( - bb, cdv.shape(0), core_cfg_.subgroup_size, tiling_.m_tiles(), - sg_m, - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto a = A[(block + m) * adv.stride(0)]; - auto c = bb.declare_assign((*this)(*ct), "c", - Cb + (block + m) * - cdv.stride(0)); auto ab = bb.declare_assign( to_clir_ty(ct->element_ty()), "ab", - multiply(at->element_ty(), bt->element_ty(), - std::move(a), b)); - const auto ab_scaled = multiply(alpha_ty, - ct->element_ty(), alpha, std::move(ab)); store_helper(bb, g.atomic(), c, ct->element_ty(), - ct->addrspace(), std::move(ab_scaled), - beta_ty, beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder( - m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - }) - .get_product()); - });*/ + return parallel; } -/* Region nodes */ -void lower_linalg_pass::operator()(region_node &b) { - for (auto &s : b.insts()) { - if (auto lowered_inst = visit(*this, *s); lowered_inst) { - s = lowered_inst; - } +lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); } } -/* Function nodes */ -void lower_linalg_pass::operator()(prototype &) {} - -void lower_linalg_pass::operator()(function_node &fn) { +void lower_linalg_pass::run_on_function(function_node &fn) { auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; try { - core_cfg_ = info_->get_core_config(subgroup_size); + core_cfg = info_->get_core_config(subgroup_size); } catch (std::out_of_range const &e) { throw compilation_error(fn.loc(), status::unsupported_subgroup_size); } auto const work_group_size = fn.work_group_size(); - tiling_[0] = work_group_size[0] / subgroup_size; - tiling_[1] = work_group_size[1]; - - visit(*this, *fn.prototype()); - visit(*this, *fn.body()); -} - -/* Program nodes */ -void lower_linalg_pass::operator()(program &p) { - for (auto &fn : p.functions()) { - visit(*this, *fn); - } + local_tiling tiling = {}; + tiling[0] = work_group_size[0] / subgroup_size; + tiling[1] = work_group_size[1]; + + walk(fn, [&](region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + auto lowered_inst = visit(linalg_generator{tiling, core_cfg}, *it); + if (lowered_inst) { + it = reg.insts().erase(it); + it = reg.insts().insert(it, lowered_inst.release()); + } + } + }); } } // namespace tinytc diff --git a/src/pass/lower_linalg.hpp b/src/pass/lower_linalg.hpp index f4dcffcb..e16b92f5 100644 --- a/src/pass/lower_linalg.hpp +++ b/src/pass/lower_linalg.hpp @@ -4,57 +4,18 @@ #ifndef LOWER_LINALG_20240801_HPP #define LOWER_LINALG_20240801_HPP -#include "device_info.hpp" -#include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" -#include "node/value_node.hpp" -#include "tiling.hpp" #include "tinytc/types.h" -#include - namespace tinytc { class lower_linalg_pass { public: lower_linalg_pass(::tinytc_core_info const *info); - /* Data type nodes */ - // bool operator()(void_data_type &); - // bool operator()(group_data_type &b); - // bool operator()(memref_data_type &m); - // bool operator()(scalar_data_type &s); - - //[> Value nodes <] - // value_node *operator()(int_imm &v); - // value_node *operator()(float_imm &v); - // value_node *operator()(val &v); - - /* Stmt nodes */ - inst operator()(inst_node &); - inst operator()(loop_inst &p); - inst operator()(ger_inst &g); - inst operator()(if_inst &in); - inst operator()(parallel_inst &p); - - /* Region nodes */ - void operator()(region_node &b); - - /* Func nodes */ - void operator()(function_node &fn); - - /* Program nodes */ - void operator()(program &p); + void run_on_function(::tinytc_func &fn); private: - auto get_memref_type(value_node const &v) const -> const memref_data_type *; - ::tinytc_core_info const *info_; - local_tiling tiling_ = {}; - core_config core_cfg_ = {}; }; } // namespace tinytc diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index 69a87119..d174b5d9 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -16,12 +16,17 @@ void slot_tracker::set_slot(value_node const &v) { } } -void slot_tracker::run_on_function(function_node &fn) { +void slot_tracker::run_on_function(function_node const &fn) { slot_ = 0; for (auto const &arg : fn.params()) { set_slot(arg); } walk(fn, [this](inst_node const &i) { + for (auto const ® : i.child_regions()) { + for (auto const &p : reg.params()) { + set_slot(p); + } + } for (auto const &result : i.results()) { set_slot(result); } diff --git a/src/pass/slot_tracker.hpp b/src/pass/slot_tracker.hpp index e4d11eaa..5a572537 100644 --- a/src/pass/slot_tracker.hpp +++ b/src/pass/slot_tracker.hpp @@ -14,7 +14,7 @@ namespace tinytc { class slot_tracker { public: - void run_on_function(function_node &fn); + void run_on_function(function_node const &fn); auto get_slot(value_node const &v) -> std::int64_t; diff --git a/src/passes.def b/src/passes.def index b9251672..a5562f50 100644 --- a/src/passes.def +++ b/src/passes.def @@ -2,9 +2,11 @@ // SPDX-License-Identifier: BSD-3-Clause FUNCTION_PASS("check-ir", check_ir_pass{}) +FUNCTION_PASS("constant-propagation", constant_propagation_pass{}) FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) +FUNCTION_PASS_WITH_INFO("lower-linalg", [](tinytc_core_info const* info) { return lower_linalg_pass{info}; }) FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass(info); }) diff --git a/src/prog.cpp b/src/prog.cpp index 9b9d3e6c..6fd170e3 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -75,7 +75,7 @@ tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, return tinytc_status_invalid_arguments; } return tinytc::exception_to_status_code( - [&] { *ctx = tinytc::compiler_context(prg->get_context()).release(); }); + [&] { *ctx = tinytc::compiler_context{prg->get_context(), true}.release(); }); } tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { diff --git a/src/support/ilist_base.hpp b/src/support/ilist_base.hpp index 675c958b..ead5790f 100644 --- a/src/support/ilist_base.hpp +++ b/src/support/ilist_base.hpp @@ -167,20 +167,18 @@ class ilist_base : protected IListCallback { auto erase(iterator it) -> iterator { // let s = sentinel - // |0|: s{prev->s,next->s} // |1|: n0{prev->s,next->s}, s{prev->n0,next->n0} // |2|: n0{prev->s,next->n1}, n1{prev->n0,next->s}, s{prev->n1,next->n0} base_pointer prev = it.get_base()->prev(); - base_pointer next = it.get_base()->prev(); - prev->prev(next); + base_pointer next = it.get_base()->next(); + prev->next(next); next->prev(prev); it.get_base()->prev(nullptr); it.get_base()->next(nullptr); - // |0| (it -> s) : s{prev->s,next->s} // |1| (it -> n0): s{prev->s,next->s} // |2| (it -> n0): n1{prev->s,next->s}, s{prev->n1,next->n1} // |2| (it -> n1): n0{prev->s,next->s}, s{prev->n0,next->n0} - // this->node_removed(&*it); + this->node_removed(&*it); return iterator{next}; } auto erase(iterator begin, iterator end) -> iterator { diff --git a/src/support/walk.hpp b/src/support/walk.hpp index b705b527..57b55220 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -30,7 +30,8 @@ class walk_stage { int next_region_ = 0; }; -template void walk(inst_node &i, std::function callback) { +template +void walk(inst_node const &i, std::function callback) { if constexpr (Order == walk_order::pre_order) { callback(i); } @@ -43,9 +44,13 @@ template void walk(inst_node &i, std::function void walk(inst_node &i, std::function callback) { + walk(const_cast(i), + [c = std::move(callback)](inst_node const &i) { c(const_cast(i)); }); +} template -void walk(inst_node &i, std::function callback) { +void walk(inst_node const &i, std::function callback) { for (auto ® : i.child_regions()) { if constexpr (Order == walk_order::pre_order) { callback(reg); @@ -58,9 +63,21 @@ void walk(inst_node &i, std::function callback) { } } } +template +void walk(inst_node &i, std::function callback) { + walk( + const_cast(i), + [c = std::move(callback)](region_node const ®) { c(const_cast(reg)); }); +} void walk(inst_node &i, std::function callback); +template +void walk(function_node const &fn, std::function callback) { + for (auto &i : fn.body()) { + walk(i, callback); + } +} template void walk(function_node &fn, std::function callback) { for (auto &i : fn.body()) { @@ -68,6 +85,25 @@ void walk(function_node &fn, std::function callback) { } } +template +void walk(function_node const &fn, std::function callback) { + if constexpr (Order == walk_order::pre_order) { + callback(fn.body()); + } + for (auto &j : fn.body()) { + walk(j, callback); + } + if constexpr (Order == walk_order::post_order) { + callback(fn.body()); + } +} +template +void walk(function_node &i, std::function callback) { + walk( + const_cast(i), + [c = std::move(callback)](region_node const ®) { c(const_cast(reg)); }); +} + inline void walk(function_node &fn, std::function callback) { for (auto &i : fn.body()) { diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir new file mode 100644 index 00000000..2aafa384 --- /dev/null +++ b/test/opt/constant-propagation.ir @@ -0,0 +1,13 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt --constant-propagation < %s | filecheck %s +func @known_size(%a: memref) { + %0 = size %a[0] : memref + %1 = size %a[1] : memref + %2 = arith.add %0, %1 : index +; CHECK-LABEL: func @known_size({{.*}} +; CHECK-NEXT: %0 = constant 64 -> index +; CHECK-NEXT: %1 = constant 32 -> index +; CHECK-NEXT: %2 = constant 96 -> index +} From a3afe404e26287de013669187afed43e9951e6e4 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 1 Oct 2024 18:16:58 +0200 Subject: [PATCH 036/297] scalar_type -> data_type in memref Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 17 +++++--- include/tinytc/tinytc.hpp | 20 ++++----- src/data_type.cpp | 24 +++++------ src/node/data_type_node.cpp | 21 ++++++--- src/node/data_type_node.hpp | 17 ++++---- src/node/inst_node.cpp | 9 ++-- src/parser/parser_impl.yy | 71 +++++++++++++------------------ src/recipe/small_gemm_batched.cpp | 14 +++--- src/recipe/tall_and_skinny.cpp | 9 ++-- 9 files changed, 94 insertions(+), 108 deletions(-) diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index eb182bd0..fec78fec 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -65,8 +65,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, /** * @brief Get memref data type * + * Note: modifies compiler context + * * @param dt [out] pointer to the data type object created - * @param ctx [inout] compiler context * @param scalar_ty [in] element type * @param shape_size [in] tensor order; number of elements in shape array, must be 0 if shape == * nullptr @@ -79,16 +80,19 @@ TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get( - tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, uint32_t stride_size, const int64_t *stride, - tinytc_address_space_t addrspace, const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, + tinytc_data_type_t scalar_ty, + uint32_t shape_size, const int64_t *shape, + uint32_t stride_size, const int64_t *stride, + tinytc_address_space_t addrspace, + const tinytc_location_t *loc); /** * @brief Get group data type * + * Note: modifies compiler context + * * @param dt [out] pointer to the data type object created - * @param ctx [inout] compiler context * @param memref_ty [in] memref data type object * @param offset [in][optional] offset parameter; pass 0 for default * @param loc [in][optional] Source code location; can be nullptr @@ -96,7 +100,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get( * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, - tinytc_compiler_context_t ctx, tinytc_data_type_t memref_ty, int64_t offset, const tinytc_location_t *loc); diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 9b86faf5..ac0988e4 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -548,7 +548,6 @@ inline data_type get_scalar(compiler_context const &ctx, scalar_type scalar_ty) * * Cf. \ref tinytc_memref_type_get * - * @param ctx Compiler context * @param scalar_ty Element type * @param shape Tensor shape * @param stride Tensor stride @@ -557,33 +556,30 @@ inline data_type get_scalar(compiler_context const &ctx, scalar_type scalar_ty) * * @return Data type */ -inline data_type get_memref(compiler_context const &ctx, scalar_type scalar_ty, - array_view shape, array_view stride = {}, +inline data_type get_memref(data_type scalar_ty, array_view shape, + array_view stride = {}, address_space addrspace = address_space::global, location const &loc = {}) { tinytc_data_type_t mt; - CHECK_STATUS_LOC( - tinytc_memref_type_get(&mt, ctx.get(), static_cast(scalar_ty), - shape.size(), shape.data(), stride.size(), stride.data(), - static_cast(addrspace), &loc), - loc); + CHECK_STATUS_LOC(tinytc_memref_type_get(&mt, scalar_ty, shape.size(), shape.data(), + stride.size(), stride.data(), + static_cast(addrspace), &loc), + loc); return mt; } /** * @brief Get a group data type * - * @param ctx Compiler context * @param memref_ty Memref data type * @param offset Offset parameter * @param loc Source code location * * @return Data type */ -inline data_type get_group(compiler_context const &ctx, data_type memref_ty, - std::int64_t offset = 0, location const &loc = {}) { +inline data_type get_group(data_type memref_ty, std::int64_t offset = 0, location const &loc = {}) { tinytc_data_type_t gt; - CHECK_STATUS_LOC(tinytc_group_type_get(>, ctx.get(), memref_ty, offset, &loc), loc); + CHECK_STATUS_LOC(tinytc_group_type_get(>, memref_ty, offset, &loc), loc); return gt; } diff --git a/src/data_type.cpp b/src/data_type.cpp index d7a54f83..91866ac0 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -30,12 +30,12 @@ tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_c [&] { *dt = scalar_data_type::get(ctx, enum_cast(type)); }); } -tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, - tinytc_scalar_type_t scalar_ty, uint32_t shape_size, - const int64_t *shape, uint32_t stride_size, - const int64_t *stride, tinytc_address_space_t addrspace, +tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_data_type_t scalar_ty, + uint32_t shape_size, const int64_t *shape, + uint32_t stride_size, const int64_t *stride, + tinytc_address_space_t addrspace, const tinytc_location_t *loc) { - if (dt == nullptr || ctx == nullptr || (shape_size != 0 && shape == nullptr) || + if (dt == nullptr || (shape_size != 0 && shape == nullptr) || (stride_size != 0 && stride == nullptr)) { return tinytc_status_invalid_arguments; } @@ -51,20 +51,18 @@ tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_compiler_c std::span(stride, static_cast(stride_size)); } - *dt = memref_data_type::get(ctx, enum_cast(scalar_ty), std::move(shape_span), - std::move(stride_span), enum_cast(addrspace), - get_optional(loc)); + *dt = memref_data_type::get(scalar_ty, std::move(shape_span), std::move(stride_span), + enum_cast(addrspace), get_optional(loc)); }); } -tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, - tinytc_data_type_t memref_ty, int64_t offset, - const tinytc_location_t *loc) { - if (dt == nullptr || ctx == nullptr) { +tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, + int64_t offset, const tinytc_location_t *loc) { + if (dt == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *dt = group_data_type::get(ctx, memref_ty, offset, get_optional(loc)); }); + [&] { *dt = group_data_type::get(memref_ty, offset, get_optional(loc)); }); } } diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index caca6ffe..9154d99e 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -14,8 +14,9 @@ namespace tinytc { -auto group_data_type::get(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t offset, +auto group_data_type::get(tinytc_data_type_t ty, std::int64_t offset, location const &lc) -> tinytc_data_type_t { + auto ctx = ty->context(); auto &value = ctx->cache()->group_tys[std::make_pair(ty, offset)]; if (value == nullptr) { @@ -37,12 +38,15 @@ group_data_type::group_data_type(tinytc_compiler_context_t ctx, tinytc_data_type } } -memref_data_type::memref_data_type(tinytc_compiler_context_t ctx, scalar_type type, +memref_data_type::memref_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t element_ty, std::vector shape, std::vector stride, address_space addrspace, location const &lc) - : data_type_node(DTK::memref, ctx), element_ty_(std::move(type)), shape_(std::move(shape)), + : data_type_node(DTK::memref, ctx), element_ty_(element_ty), shape_(std::move(shape)), stride_(std::move(stride)), addrspace_(addrspace) { + if (!isa(*element_ty_)) { + throw compilation_error(lc, status::ir_expected_scalar); + } if (stride_.size() != shape_.size()) { throw compilation_error(lc, status::ir_shape_stride_mismatch); } @@ -58,10 +62,15 @@ memref_data_type::memref_data_type(tinytc_compiler_context_t ctx, scalar_type ty } } -auto memref_data_type::get(tinytc_compiler_context_t ctx, scalar_type element_ty, - std::span shape, +scalar_type memref_data_type::element_ty() const { + return dyn_cast(element_ty_)->ty(); +} + +auto memref_data_type::get(tinytc_data_type_t element_ty, std::span shape, std::span stride, address_space addrspace, location const &lc) -> tinytc_data_type_t { + auto ctx = element_ty->context(); + auto stride_buffer = std::vector{}; if (stride.empty()) { stride_buffer = canonical_stride(shape); @@ -110,7 +119,7 @@ auto memref_data_type_key::hash() -> std::uint64_t { } auto memref_data_type_key::operator==(memref_data_type const &mt) -> bool { - return element_ty == mt.element_ty() && addrspace == mt.addrspace() && + return element_ty == mt.element_data_ty() && addrspace == mt.addrspace() && std::equal(shape.begin(), shape.end(), mt.shape().begin(), mt.shape().end()) && std::equal(stride.begin(), stride.end(), mt.stride().begin(), mt.stride().end()); } diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index fc030aea..c699e27e 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -42,7 +42,7 @@ using data_type_node = ::tinytc_data_type; class group_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } - static auto get(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t offset, + static auto get(tinytc_data_type_t ty, std::int64_t offset, location const &lc = {}) -> tinytc_data_type_t; inline auto ty() const -> tinytc_data_type_t { return ty_; } @@ -61,19 +61,20 @@ class memref_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } static auto canonical_stride(std::span shape) -> std::vector; - static auto get(tinytc_compiler_context_t ctx, scalar_type element_ty, - std::span shape, std::span stride, + static auto get(tinytc_data_type_t element_ty, std::span shape, + std::span stride, address_space addrspace = address_space::global, location const &lc = {}) -> tinytc_data_type_t; - inline scalar_type element_ty() const { return element_ty_; } + scalar_type element_ty() const; + inline tinytc_data_type_t element_data_ty() const { return element_ty_; } inline std::int64_t dim() const { return shape_.size(); } inline auto const &shape() const { return shape_; } inline std::int64_t shape(std::int64_t i) const { return shape_[i]; } inline auto const &stride() const { return stride_; } inline std::int64_t stride(std::int64_t i) const { return stride_[i]; } inline std::int64_t size_in_bytes() const { - return is_dynamic() ? dynamic : size(element_ty_) * stride_.back() * shape_.back(); + return is_dynamic() ? dynamic : size(element_ty()) * stride_.back() * shape_.back(); } inline auto addrspace() const -> address_space { return addrspace_; } inline void addrspace(address_space space) { addrspace_ = space; } @@ -88,17 +89,17 @@ class memref_data_type : public data_type_node { inline bool is_canonical_stride() const { return stride_ == canonical_stride(shape_); } protected: - memref_data_type(tinytc_compiler_context_t ctx, scalar_type type, + memref_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t element_ty, std::vector shape, std::vector stride, address_space addrspace = address_space::global, location const &lc = {}); - scalar_type element_ty_; + tinytc_data_type_t element_ty_; std::vector shape_, stride_; address_space addrspace_ = address_space::global; }; struct memref_data_type_key { - scalar_type element_ty; + tinytc_data_type_t element_ty; std::span shape, stride; address_space addrspace; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index c08fd3dc..6667d0f4 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -274,8 +274,7 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, stride.push_back(m->stride(i)); } - auto result_ty = - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); result(0) = value_node{result_ty, lc}; } @@ -311,8 +310,7 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, loc shape.push_back(m->shape(i)); stride.push_back(m->stride(i)); } - auto result_ty = - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); result(0) = value_node{result_ty, lc}; } @@ -535,8 +533,7 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o } } - auto result_ty = - memref_data_type::get(m->context(), m->element_ty(), shape, stride, m->addrspace()); + auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); result(0) = value_node{result_ty, lc}; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 20f39c3a..16fc07c1 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -45,14 +45,6 @@ namespace tinytc { - void check_scalar_type(compiler_context const &ctx, tinytc_value_t val, scalar_type const &sty, - location &loc1, location &loc2) { - if (val->ty() != get_scalar(ctx, sty)) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - } void check_type(tinytc_value_t val, tinytc_data_type_t ty, location &loc1, location &loc2) { if (val->ty() != ty) { auto loc = loc1; @@ -155,7 +147,7 @@ %nterm >> attributes %nterm > attribute %nterm data_type -%nterm scalar_type +%nterm scalar_type %nterm memref_type %nterm optional_address_space %nterm > mode_list @@ -312,21 +304,20 @@ attribute: data_type: - scalar_type { $$ = get_scalar(ctx.cctx(), $scalar_type); } + scalar_type | memref_type | group_type ; scalar_type: - INTEGER_TYPE - | FLOATING_TYPE + INTEGER_TYPE { $$ = get_scalar(ctx.cctx(), $INTEGER_TYPE); } + | FLOATING_TYPE { $$ = get_scalar(ctx.cctx(), $FLOATING_TYPE); } ; memref_type: MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { try { - $$ = - get_memref(ctx.cctx(), $scalar_type, $mode_list, {}, $optional_address_space, @memref_type); + $$ = get_memref($scalar_type, $mode_list, {}, $optional_address_space, @memref_type); } catch (compilation_error const &e) { error(e.loc(), e.what()); YYERROR; @@ -339,7 +330,7 @@ memref_type: throw syntax_error(loc, "Shape and stride list must have the same length"); } try { - $$ = get_memref(ctx.cctx(), $scalar_type, $mode_list, $optional_stride_list, + $$ = get_memref($scalar_type, $mode_list, $optional_stride_list, $optional_address_space, @memref_type); } catch (compilation_error const &e) { error(e.loc(), e.what()); @@ -376,7 +367,7 @@ constant_or_dynamic: group_type: GROUP LCHEV memref_type group_offset RCHEV { - $$ = get_group(ctx.cctx(), std::move($memref_type), $group_offset, @group_type); + $$ = get_group(std::move($memref_type), $group_offset, @group_type); } ; @@ -431,9 +422,9 @@ axpby_inst: AXPBY transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($b, $mb, @b, @mb); try { $$ = inst { @@ -495,10 +486,10 @@ gemm_inst: var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -519,10 +510,10 @@ gemv_inst: var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -547,10 +538,10 @@ ger_inst: var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -655,10 +646,10 @@ hadamard_inst: var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] COMMA memref_type[mc] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); check_type($b, $mb, @b, @mb); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($c, $mc, @c, @mc); try { $$ = inst { @@ -678,9 +669,9 @@ sum_inst: SUM transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type(ctx.cctx(), $alpha, $falpha, @alpha, @falpha); + check_type($alpha, $falpha, @alpha, @falpha); check_type($a, $ma, @a, @ma); - check_scalar_type(ctx.cctx(), $beta, $fbeta, @beta, @fbeta); + check_type($beta, $fbeta, @beta, @fbeta); check_type($b, $mb, @b, @mb); try { $$ = inst { @@ -703,11 +694,7 @@ yield_inst: throw syntax_error(loc, "Identifier and scalar type list must have the same length"); } for (std::size_t i = 0; i < $vals.size(); ++i) { - if (auto ty = dyn_cast($tys[i]); ty) { - check_scalar_type(ctx.cctx(), $vals[i], ty->ty(), @vals, @tys); - } else { - throw syntax_error(@tys, "Yield only accepts scalar types"); - } + check_type($vals[i], $tys[i], @vals, @tys); } $$ = inst{std::make_unique(std::move($vals)).release()}; } @@ -749,8 +736,8 @@ alloca_inst: arith_inst: ARITH ARITHMETIC var[a] COMMA var[b] COLON scalar_type[ty] { - check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); - check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); + check_type($a, $ty, @a, @ty); + check_type($b, $ty, @b, @ty); try { $$ = inst { std::make_unique($ARITHMETIC, std::move($a), std::move($b), @arith_inst) @@ -765,7 +752,7 @@ arith_inst: arith_unary_inst: ARITH ARITHMETIC_UNARY var[a] COLON scalar_type[ty] { - check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); + check_type($a, $ty, @a, @ty); try { $$ = inst { std::make_unique($ARITHMETIC_UNARY, std::move($a), @@ -794,8 +781,8 @@ cast_inst: compare_inst: CMP CMP_CONDITION var[a] COMMA var[b] COLON scalar_type[ty] { - check_scalar_type(ctx.cctx(), $a, $ty, @a, @ty); - check_scalar_type(ctx.cctx(), $b, $ty, @b, @ty); + check_type($a, $ty, @a, @ty); + check_type($b, $ty, @b, @ty); try { $$ = inst { std::make_unique($CMP_CONDITION, std::move($a), std::move($b), @@ -887,7 +874,7 @@ expand_shape: integer_constant_or_identifier: var { - check_scalar_type(ctx.cctx(), $var, scalar_type::index, @var, @var); + check_type($var, get_scalar(ctx.cctx(), scalar_type::index), @var, @var); $$ = $var; } | INTEGER_CONSTANT { @@ -963,7 +950,7 @@ group_size_inst: if_inst: IF var[condition] optional_returned_values { - check_scalar_type(ctx.cctx(), $condition, scalar_type::i1, @condition, @condition); + check_type($condition, get_scalar(ctx.cctx(), scalar_type::i1), @condition, @condition); try { auto loc = @IF; loc.end = @optional_returned_values.end; @@ -1000,9 +987,9 @@ optional_scalar_type_list: ; scalar_type_list: - scalar_type { $$.push_back(get_scalar(ctx.cctx(), $scalar_type)); } + scalar_type { $$.push_back($scalar_type); } | scalar_type_list COMMA scalar_type { - $$ = std::move($1); $$.push_back(get_scalar(ctx.cctx(), $scalar_type)); + $$ = std::move($1); $$.push_back($scalar_type); } ; diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 30700601..4a02263f 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -86,14 +86,12 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const tB_ = enum_cast(tB); auto const kernel = [&](char const *name, bool is_beta_nonzero) { - auto A_ty = - get_memref(ctx_, enum_cast(ty), {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, address_space::global, my_loc()); - auto B_ty = - get_memref(ctx_, enum_cast(ty), {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, address_space::global, my_loc()); - auto C_ty = get_memref(ctx_, enum_cast(ty), {M, N, dynamic}, - {1, ldC, strideC}, address_space::global, my_loc()); + auto A_ty = get_memref(ty_, {selA(M, K), selA(K, M), dynamic}, {1, ldA, strideA}, + address_space::global, my_loc()); + auto B_ty = get_memref(ty_, {selB(K, N), selB(N, K), dynamic}, {1, ldB, strideB}, + address_space::global, my_loc()); + auto C_ty = get_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, + address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); auto params = std::array{}; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index ba8a0c68..e394bd79 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -151,12 +151,9 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( }; auto const kernel = [&](char const *name, bool is_beta_nonzero) { - auto A_ty = get_memref(ctx_, enum_cast(ty), {M, K}, {1, ldA}, - address_space::global, my_loc()); - auto B_ty = get_memref(ctx_, enum_cast(ty), {K, N}, {1, ldB}, - address_space::global, my_loc()); - auto C_ty = get_memref(ctx_, enum_cast(ty), {M, N}, {1, ldC}, - address_space::global, my_loc()); + auto A_ty = get_memref(ty_, {M, K}, {1, ldA}, address_space::global, my_loc()); + auto B_ty = get_memref(ty_, {K, N}, {1, ldB}, address_space::global, my_loc()); + auto C_ty = get_memref(ty_, {M, N}, {1, ldC}, address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); auto fn_body = f.get_body(); auto params = std::array{}; From 9ff41598e29ced838570d4dd0c71b447d8751309 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 2 Oct 2024 13:24:49 +0200 Subject: [PATCH 037/297] Add def-use chain datastructure Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 1 + src/compiler.cpp | 1 + src/node/inst_node.cpp | 64 ++++++++++--------- src/node/inst_node.hpp | 61 +++++++++++------- src/node/program_node.hpp | 5 +- src/node/value_node.cpp | 90 +++++++++++++++++++++++++- src/node/value_node.hpp | 87 +++++++++++++++++++++++++ src/pass/convert_to_opencl.cpp | 12 ++-- src/pass/dump_def_use.cpp | 47 ++++++++++++++ src/pass/dump_def_use.hpp | 25 ++++++++ src/pass/dump_ir.cpp | 22 ++++--- src/pass/dump_ir.hpp | 4 +- src/pass/insert_lifetime_stop.cpp | 4 +- src/passes.def | 1 + src/support/util.hpp | 101 ++++++++++++++++++------------ test/opt/dump-def-use.ir | 33 ++++++++++ 16 files changed, 442 insertions(+), 116 deletions(-) create mode 100644 src/pass/dump_def_use.cpp create mode 100644 src/pass/dump_def_use.hpp create mode 100644 test/opt/dump-def-use.ir diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6741cc52..c24f87b2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -44,6 +44,7 @@ set(SOURCES pass/constant_propagation.cpp pass/convert_to_opencl.cpp pass/dump_cfg.cpp + pass/dump_def_use.cpp pass/dump_ir.cpp pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp diff --git a/src/compiler.cpp b/src/compiler.cpp index 19bdd80a..b6b47615 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -8,6 +8,7 @@ #include "pass/constant_propagation.hpp" #include "pass/convert_to_opencl.hpp" #include "pass/dump_cfg.hpp" +#include "pass/dump_def_use.hpp" #include "pass/dump_ir.hpp" #include "pass/insert_barrier.hpp" #include "pass/insert_lifetime_stop.hpp" diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 6667d0f4..dd8cd818 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -40,28 +40,30 @@ memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, bool atomic) : standard_inst{tid}, atomic_(atomic) { - op(op_alpha) = std::move(alpha); - op(op_A) = std::move(A); - op(op_beta) = std::move(beta); - op(op_B) = std::move(B); + op(op_alpha, alpha); + op(op_A, A); + op(op_beta, beta); + op(op_B, B); } blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, bool atomic) : standard_inst{tid}, atomic_(atomic) { - op(op_alpha) = std::move(alpha); - op(op_A) = std::move(A); - op(op_B) = std::move(B); - op(op_beta) = std::move(beta); - op(op_C) = std::move(C); + op(op_alpha, alpha); + op(op_A, A); + op(op_B, B); + op(op_beta, beta); + op(op_C, C); } loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, tinytc_data_type_t loop_var_type, location const &lc) : standard_inst{tid, step0 ? 3 : 2} { - op(op_from) = std::move(from0); - op(op_to) = std::move(to0); - op(op_step) = std::move(step0); + op(op_from, from0); + op(op_to, to0); + if (step0) { + op(op_step, step0); + } body().set_params(array_view{loop_var_type}, lc); loc(lc); @@ -122,8 +124,8 @@ axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, t arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b0, location const &lc) : standard_inst{IK::arith}, operation_(operation) { - op(op_a) = std::move(a0); - op(op_b) = std::move(b0); + op(op_a, a0); + op(op_b, b0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -158,7 +160,7 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, location const &lc) : standard_inst{IK::arith_unary}, operation_(operation) { - op(op_a) = std::move(a0); + op(op_a, a0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -179,7 +181,7 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc) : standard_inst{IK::cast} { - op(op_a) = std::move(a); + op(op_a, a); loc(lc); if (!isa(*to_ty)) { @@ -192,8 +194,8 @@ cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, location const &lc) : standard_inst{IK::compare}, cond_(cond) { - op(op_a) = std::move(a0); - op(op_b) = std::move(b0); + op(op_a, a0); + op(op_b, b0); loc(lc); auto at = get_scalar_type(loc(), a()); @@ -232,9 +234,9 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, array_view expand_shape0, location const &lc) : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { - op(0) = std::move(op0); + op(0, op0); for (std::size_t i = 0; i < expand_shape0.size(); ++i) { - op(1 + i) = expand_shape0[i]; + op(1 + i, expand_shape0[i]); } loc(lc); @@ -280,7 +282,7 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, location const &lc) : standard_inst{IK::fuse}, from_(from), to_(to) { - op(0) = std::move(op0); + op(0, op0); loc(lc); auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); @@ -316,9 +318,9 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, loc load_inst::load_inst(tinytc_value_t op0, array_view index_list0, location const &lc) : standard_inst{IK::load, static_cast(1 + index_list0.size())} { - op(0) = std::move(op0); + op(0, op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { - op(1 + i) = index_list0[i]; + op(1 + i, index_list0[i]); } loc(lc); @@ -460,7 +462,7 @@ hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_va if_inst::if_inst(tinytc_value_t condition, array_view return_types, location const &lc) : standard_inst{IK::if_, 1, static_cast(return_types.size())} { - op(0) = std::move(condition); + op(0, condition); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { result(i) = value_node{return_types[i], lc}; @@ -475,7 +477,7 @@ parallel_inst::parallel_inst(location const &lc) : standard_inst{IK::parallel} { size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, location const &lc) : standard_inst{IK::size}, mode_(mode) { - op(0) = std::move(op0); + op(0, op0); loc(lc); auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); @@ -493,15 +495,15 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o location const &lc) : standard_inst{IK::subview, static_cast(1 + offsets0.size() + sizes0.size())}, static_offsets_(std::move(static_offsets0)), static_sizes_(std::move(static_sizes0)) { - op(0) = std::move(op0); + op(0, op0); { std::size_t i = 1; for (auto const &val : offsets0) { - op(i++) = val; + op(i++, val); } num_dyn_offsets_ = i - 1; for (auto const &val : sizes0) { - op(i++) = val; + op(i++, val); } } loc(lc); @@ -540,12 +542,12 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o store_inst::store_inst(tinytc_value_t val0, tinytc_value_t op0, array_view index_list0, location const &lc) : standard_inst{IK::store, static_cast(2 + index_list0.size())} { - op(op_val) = std::move(val0); - op(op_operand) = std::move(op0); + op(op_val, val0); + op(op_operand, op0); { std::size_t i = op_operand; for (auto const &val : index_list0) { - op(++i) = val; + op(++i, val); } } loc(lc); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 234dd68b..f1d2f2bc 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -85,9 +85,6 @@ using inst_nodes = class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, class yield_inst>; -using op_range = iterator_range_wrapper; -using const_op_range = iterator_range_wrapper; - using result_range = iterator_range_wrapper; using const_result_range = iterator_range_wrapper; @@ -100,6 +97,19 @@ struct tinytc_inst : tinytc::ilist_node_with_parent public: using leaves = tinytc::inst_nodes; + using op_iterator = + tinytc::indirect_random_access_iterator; + using const_op_iterator = + tinytc::indirect_random_access_iterator; + + using op_range = tinytc::iterator_range_wrapper; + using const_op_range = tinytc::iterator_range_wrapper; + + static_assert(std::random_access_iterator); + static_assert(std::random_access_iterator); + static_assert(std::ranges::random_access_range); + static_assert(std::ranges::random_access_range); + inline tinytc_inst(tinytc::IK tid) : tid_(tid) {} virtual ~tinytc_inst() = default; @@ -114,14 +124,15 @@ struct tinytc_inst : tinytc::ilist_node_with_parent inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } // Iterator over operands - inline auto op_begin() -> tinytc_value_t * { return op_begin_; } - inline auto op_end() -> tinytc_value_t * { return op_end_; } - inline auto operands() -> tinytc::op_range { return {op_begin_, op_end_}; } - inline auto op_begin() const -> const tinytc_value_t * { return op_begin_; } - inline auto op_end() const -> const tinytc_value_t * { return op_end_; } - inline auto operands() const -> tinytc::const_op_range { return {op_begin_, op_end_}; } - inline auto op(std::size_t pos) -> tinytc_value_t & { return op_begin_[pos]; } - inline auto op(std::size_t pos) const -> tinytc_value_t const & { return op_begin_[pos]; } + inline auto op_begin() -> op_iterator { return {op_begin_}; } + inline auto op_end() -> op_iterator { return {op_end_}; } + inline auto operands() -> op_range { return {op_begin(), op_end()}; } + inline auto op_begin() const -> const_op_iterator { return {op_begin_}; } + inline auto op_end() const -> const_op_iterator { return {op_end_}; } + inline auto operands() const -> const_op_range { return {op_begin(), op_end()}; } + inline auto op(std::size_t pos) -> tinytc_value_t { return op_begin_[pos].get(); } + inline auto op(std::size_t pos) const -> tinytc_value_t { return op_begin_[pos].get(); } + inline void op(std::size_t pos, tinytc_value_t val) { op_begin_[pos] = val; } inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } // Iterator over results @@ -208,15 +219,15 @@ struct tinytc_inst : tinytc::ilist_node_with_parent } protected: - inline auto op_range(tinytc_value_t *begin, tinytc_value_t *end) { + inline auto set_op_range(tinytc::use *begin, tinytc::use *end) noexcept { op_begin_ = begin; op_end_ = end; } - inline auto result_range(tinytc_value_t begin, tinytc_value_t end) { + inline auto set_result_range(tinytc_value_t begin, tinytc_value_t end) noexcept { result_begin_ = begin; result_end_ = end; } - inline auto child_regions_range(tinytc_region_t begin, tinytc_region_t end) { + inline auto set_child_regions_range(tinytc_region_t begin, tinytc_region_t end) noexcept { child_regions_begin_ = begin; child_regions_end_ = end; } @@ -224,7 +235,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent private: tinytc::IK tid_; tinytc::location loc_; - tinytc_value_t *op_begin_ = nullptr, *op_end_ = nullptr; + tinytc::use *op_begin_ = nullptr, *op_end_ = nullptr; tinytc_value_t result_begin_ = nullptr, result_end_ = nullptr; tinytc_region_t child_regions_begin_ = nullptr, child_regions_end_ = nullptr; }; @@ -272,18 +283,24 @@ class standard_inst : public inst_node { : inst_node{tid}, ops_{num_operands}, results_{num_results}, child_regions_{num_child_regions} { if (num_operands > 0) { - op_range(ops_.get(), ops_.get() + num_operands); + auto *op_begin = ops_.get(); + set_op_range(op_begin, op_begin + num_operands); + if constexpr (NumOperands != 0) { + for (std::int64_t i = 0; i < num_operands; ++i) { + op_begin[i].owner(this); + } + } } if (num_results > 0) { - result_range(results_.get(), results_.get() + num_results); + set_result_range(results_.get(), results_.get() + num_results); } if (num_child_regions > 0) { - child_regions_range(child_regions_.get(), child_regions_.get() + num_child_regions); + set_child_regions_range(child_regions_.get(), child_regions_.get() + num_child_regions); } } private: - object_container ops_; + object_container ops_; object_container results_; object_container child_regions_; }; @@ -523,9 +540,7 @@ class group_size_inst : public standard_inst<0, 1> { class lifetime_stop_inst : public standard_inst<1, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } - inline lifetime_stop_inst(tinytc_value_t obj) : standard_inst{IK::lifetime_stop} { - op(0) = std::move(obj); - } + inline lifetime_stop_inst(tinytc_value_t obj) : standard_inst{IK::lifetime_stop} { op(0, obj); } inline auto object() const -> tinytc_value const & { return *op(0); } }; @@ -716,7 +731,7 @@ class yield_inst : public standard_inst { : standard_inst{IK::yield, static_cast(vals.size())} { loc(lc); for (std::size_t i = 0; i < vals.size(); ++i) { - op(i) = vals[i]; + op(i, vals[i]); } } }; diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index d843c24b..036a0942 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -16,8 +16,9 @@ struct tinytc_prog final : tinytc::reference_counted { public: - using iterator = tinytc::indirect_iterator::iterator>; - using const_iterator = tinytc::indirect_iterator::const_iterator>; + using iterator = tinytc::indirect_random_access_iterator::iterator>; + using const_iterator = + tinytc::indirect_random_access_iterator::const_iterator>; tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc = {}); diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp index 972ac408..8b6b1357 100644 --- a/src/node/value_node.cpp +++ b/src/node/value_node.cpp @@ -3,5 +3,93 @@ #include "node/value_node.hpp" -tinytc_value::tinytc_value(tinytc_data_type_t ty, tinytc::location const &lc) +using namespace tinytc; + +tinytc_value::tinytc_value(tinytc_data_type_t ty, location const &lc) : ty_{std::move(ty)}, loc_{lc} {} + +auto tinytc_value::use_begin() -> use_iterator { return {first_use_}; } +auto tinytc_value::use_end() -> use_iterator { return {nullptr}; } +auto tinytc_value::uses() -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} +auto tinytc_value::use_begin() const -> const_use_iterator { return {first_use_}; } +auto tinytc_value::use_end() const -> const_use_iterator { return {nullptr}; } +auto tinytc_value::uses() const -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} + +namespace tinytc { + +use::use(tinytc_inst_t owner) : owner_{owner} {} + +use::~use() { + if (value_) { + remove_use_from_current_list(); + } +} + +use &use::operator=(value_node *val) { + set(val); + return *this; +} + +void use::set(value_node *value) { + if (value_) { + remove_use_from_current_list(); + } + value_ = value; + if (value_) { + add_use_to_list(&value_->first_use_); + } +} + +/* + * Let next = &A.n and we have + * + * ...A|.p|.n-->B|.p|.n-->C|.p|.n... + * ...----| ^-------| ^-------| + * + * After inserting T (T = this) we want the following new or adjusted pointers + * + * ...A|.p|.n==>T|.p|.n==>B|.p|.n-->C|.p|.n... + * ...---| ^======| ^======| ^------| + * + * We need to set + * next_ = T.n -> B = *next + * next_->prev_ = B.p -> &T.n = &next_ + * prev_ = T.p -> &A.n = next + * *next = A.n -> T = this + */ +void use::add_use_to_list(use **next) { + next_ = *next; + if (next_) { + next_->prev_ = &next_; + } + prev_ = next; + *next = this; +} + +/* + * We want to remove T (T = this): + * + * ...A|.p|.n-->T|.p|.n-->B|.p|.n-->C|.p|.n... + * ...---| ^------| ^------| ^------| + * + * After removing T we want the following adjusted pointers + * + * ...A|.p|.n==>B|.p|.n-->C|.p|.n... + * ...---| ^======| ^------| + * + * We need to set + * next_->prev_ = B.p -> &A.n = prev_ + * *prev_ = A.n -> B = next_ + */ +void use::remove_use_from_current_list() { + if (next_) { + next_->prev_ = prev_; + } + *prev_ = next_; +} + +} // namespace tinytc diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 6a50844f..615cb913 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -6,12 +6,20 @@ #include "location.hpp" #include "node/data_type_node.hpp" +#include "support/util.hpp" #include "tinytc/types.h" #include +#include #include #include +namespace tinytc { +class use; +class use_iterator; +class const_use_iterator; +}; // namespace tinytc + struct tinytc_value final { public: tinytc_value(tinytc_data_type_t ty = nullptr, tinytc::location const &lc = {}); @@ -27,16 +35,95 @@ struct tinytc_value final { inline void name(std::string name) { name_ = std::move(name); } auto has_name() const -> bool { return !name_.empty(); } + auto use_begin() -> tinytc::use_iterator; + auto use_end() -> tinytc::use_iterator; + auto uses() -> tinytc::iterator_range_wrapper; + auto use_begin() const -> tinytc::const_use_iterator; + auto use_end() const -> tinytc::const_use_iterator; + auto uses() const -> tinytc::iterator_range_wrapper; + private: tinytc_data_type_t ty_; tinytc::location loc_; std::string name_; + + friend class tinytc::use; + tinytc::use *first_use_ = nullptr; }; namespace tinytc { using value_node = ::tinytc_value; +class use { + public: + use() = default; + use(tinytc_inst_t owner); + ~use(); + + use(use &&other) = delete; + use(use const &other) = delete; + use &operator=(use &&other) = delete; + use &operator=(use const &other) = delete; + + use &operator=(value_node *val); + + inline auto get() -> value_node * { return value_; } + inline auto get() const -> value_node const * { return value_; } + void set(value_node *value); + + inline auto owner() const -> tinytc_inst_t { return owner_; } + inline void owner(tinytc_inst_t owner) { owner_ = owner; } + + inline auto next() -> use * { return next_; } + inline auto next() const -> use const * { return next_; } + + private: + void add_use_to_list(use **next); + void remove_use_from_current_list(); + + tinytc_inst_t owner_ = nullptr; + value_node *value_ = nullptr; + use **prev_ = nullptr; + use *next_ = nullptr; +}; + +namespace detail { +template class use_iterator_base { + public: + using value_type = std::conditional_t; + using pointer = value_type *; + using reference = value_type &; + using difference_type = std::ptrdiff_t; + + use_iterator_base() : pos_{nullptr} {} + use_iterator_base(pointer pos) : pos_{std::move(pos)} {} + + auto operator*() const -> reference { return *pos_; } + auto operator->() const -> pointer { return pos_; } + auto operator++() -> Derived & { + pos_ = pos_->next(); + return *static_cast(this); + } + auto operator++(int) -> Derived { + auto old_pos = pos_; + pos_ = pos_->next(); + return Derived{old_pos}; + } + auto operator==(use_iterator_base const &other) const -> bool { return pos_ == other.pos_; } + auto operator!=(use_iterator_base const &other) const -> bool { return pos_ != other.pos_; } + + private: + pointer pos_; +}; +} // namespace detail + +class use_iterator : public detail::use_iterator_base {}; +class const_use_iterator : public detail::use_iterator_base {}; + +static_assert(std::forward_iterator); +static_assert(std::forward_iterator); + } // namespace tinytc #endif // VALUE_NODE_20230309_HPP diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index bb056b9f..a106c13e 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -451,7 +451,7 @@ std::vector convert_to_opencl_pass::operator()(expand_inst const &e) int j = 0; for (auto &s : static_shape) { if (is_dynamic_value(s)) { - eshape_cl.emplace_back(val(*dyn_shape[j++])); + eshape_cl.emplace_back(val(dyn_shape[j++])); } else { eshape_cl.emplace_back(clir::expr(s, static_cast(size(scalar_type::index) * 8))); } @@ -527,7 +527,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { throw compilation_error(e.loc(), status::ir_invalid_number_of_indices); } - auto idx = val(*e.index_list().front()); + auto idx = val(e.index_list().front()); rhs = rhs + idx; auto &dv = get_dope_vector(e.operand()); @@ -550,7 +550,7 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { } auto &dv = get_dope_vector(e.operand()); for (std::int64_t i = 0; i < m.dim(); ++i) { - rhs = rhs + val(*e.index_list()[i]) * dv.stride(i); + rhs = rhs + val(e.index_list()[i]) * dv.stride(i); } rhs = clir::dereference(std::move(rhs)); }, @@ -907,7 +907,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s auto offset_cl = clir::expr{}; if (is_dynamic_value(offset)) { - offset_cl = val(*dyn_offsets[joffset++]); + offset_cl = val(dyn_offsets[joffset++]); } else { offset_cl = clir::expr(offset, static_cast(tinytc::size(scalar_type::index) * 8)); @@ -918,7 +918,7 @@ std::vector convert_to_opencl_pass::operator()(subview_inst const &s if (size > 0 || is_dynamic_value(size)) { auto size_cl = clir::expr{}; if (is_dynamic_value(size)) { - size_cl = val(*dyn_sizes[jsize++]); + size_cl = val(dyn_sizes[jsize++]); } else { size_cl = clir::expr(size, static_cast(tinytc::size(scalar_type::index) * 8)); @@ -954,7 +954,7 @@ std::vector convert_to_opencl_pass::operator()(store_inst const &s) auto lhs = val(s.operand()); auto &dv = get_dope_vector(s.operand()); for (std::int64_t i = 0; i < ot->dim(); ++i) { - lhs = lhs + val(*s.index_list()[i]) * dv.stride(i); + lhs = lhs + val(s.index_list()[i]) * dv.stride(i); } auto rhs = val(s.val()); diff --git a/src/pass/dump_def_use.cpp b/src/pass/dump_def_use.cpp new file mode 100644 index 00000000..efcfad4b --- /dev/null +++ b/src/pass/dump_def_use.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_def_use.hpp" +#include "pass/dump_ir.hpp" +#include "support/util.hpp" +#include "support/visit.hpp" +#include "support/walk.hpp" + +namespace tinytc { + +dump_def_use_pass::dump_def_use_pass(std::ostream &os) : os_(&os) {} + +void dump_def_use_pass::run_on_function(function_node const &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + dump_ir.init_slot_tracker(fn); + + *os_ << "Def-use in " << fn.name() << std::endl; + walk(fn, [&](inst_node const &i) { + if (i.num_results() > 0 || i.num_child_regions() > 0) { + *os_ << "> "; + visit(dump_ir, i); + *os_ << std::endl; + auto const def_use = [&](value_node const &v) { + *os_ << " def "; + dump_ir.dump_val(v); + *os_ << std::endl; + for (auto &u : v.uses()) { + *os_ << " > "; + visit(dump_ir, *u.owner()); + *os_ << std::endl; + } + }; + for (auto &res : i.results()) { + def_use(res); + } + for (auto ® : i.child_regions()) { + for (auto &p : reg.params()) { + def_use(p); + } + } + } + }); + *os_ << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_def_use.hpp b/src/pass/dump_def_use.hpp new file mode 100644 index 00000000..ccec17cb --- /dev/null +++ b/src/pass/dump_def_use.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_DEF_USE_20241002_HPP +#define DUMP_DEF_USE_20241002_HPP + +#include "node/function_node.hpp" + +#include + +namespace tinytc { + +class dump_def_use_pass { + public: + dump_def_use_pass(std::ostream &os); + + void run_on_function(function_node const &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_DEF_USE_20241002_HPP diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 50fbac4d..28c0b5f6 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -195,7 +195,7 @@ void dump_ir_pass::operator()(expand_inst const &e) { *os_ << " x "; } if (is_dynamic_value(ses[i])) { - dump_val(*es[j++]); + dump_val(es[j++]); } else { *os_ << ses[i]; } @@ -219,7 +219,7 @@ void dump_ir_pass::operator()(load_inst const &e) { dump_val(e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { dump_val(*i); }); + [this](auto const &i) { dump_val(i); }); *os_ << "] : "; visit(*this, *e.operand().ty()); } @@ -350,7 +350,7 @@ void dump_ir_pass::operator()(subview_inst const &s) { } auto offset = s.static_offsets()[i]; if (is_dynamic_value(offset)) { - dump_val(*dyn_offsets[joffset++]); + dump_val(dyn_offsets[joffset++]); } else { *os_ << offset; } @@ -358,7 +358,7 @@ void dump_ir_pass::operator()(subview_inst const &s) { if (size > 0 || is_dynamic_value(size)) { *os_ << ":"; if (is_dynamic_value(size)) { - dump_val(*dyn_sizes[jsize++]); + dump_val(dyn_sizes[jsize++]); } else { *os_ << size; } @@ -378,7 +378,7 @@ void dump_ir_pass::operator()(store_inst const &e) { dump_val(e.operand()); *os_ << "["; do_with_infix(e.index_list().begin(), e.index_list().end(), - [this](auto const &i) { dump_val(*i); }); + [this](auto const &i) { dump_val(i); }); *os_ << "] : "; visit(*this, *e.operand().ty()); } @@ -392,10 +392,10 @@ void dump_ir_pass::operator()(sum_inst const &a) { void dump_ir_pass::operator()(yield_inst const &y) { *os_ << "yield "; if (y.num_operands() > 0) { - do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { dump_val(*i); }, ", "); + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { dump_val(i); }, ", "); *os_ << " : "; do_with_infix( - y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i->ty()); }, ", "); + y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i.ty()); }, ", "); } else { *os_ << ":"; } @@ -419,8 +419,7 @@ void dump_ir_pass::dump_region(region_node const ®) { } void dump_ir_pass::run_on_function(function_node const &fn) { - tracker_ = slot_tracker{}; - tracker_.run_on_function(fn); + init_slot_tracker(fn); *os_ << "func @" << fn.name() << "("; std::string infix = ",\n "; @@ -449,4 +448,9 @@ void dump_ir_pass::run_on_function(function_node const &fn) { void dump_ir_pass::run_on_region(region_node const ®) { dump_region(reg); } void dump_ir_pass::run_on_instruction(inst_node const &in) { visit(*this, in); } +void dump_ir_pass::init_slot_tracker(function_node const &fn) { + tracker_ = slot_tracker{}; + tracker_.run_on_function(fn); +} + } // namespace tinytc diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index a6f070a9..788115da 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -65,11 +65,13 @@ class dump_ir_pass { void run_on_region(region_node const ®); void run_on_instruction(inst_node const &in); + void dump_val(value_node const &v); + void init_slot_tracker(function_node const &fn); + private: void dump_region(region_node const ®); void dump_blas_a2(blas_a2_inst const &g); void dump_blas_a3(blas_a3_inst const &g); - void dump_val(value_node const &v); template void do_with_infix(Iterator begin, Iterator end, Action a, std::string const &infix = ",") { diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index 3fe416eb..c27c68e6 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -37,8 +37,8 @@ auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const rgn_ops.merge(run_on_region(subreg, aa)); } for (auto &v : i.operands()) { - if (isa(*v->ty())) { - rgn_ops.insert(aa.root(*v)); + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); } } for (auto &v : i.results()) { diff --git a/src/passes.def b/src/passes.def index a5562f50..caedd283 100644 --- a/src/passes.def +++ b/src/passes.def @@ -4,6 +4,7 @@ FUNCTION_PASS("check-ir", check_ir_pass{}) FUNCTION_PASS("constant-propagation", constant_propagation_pass{}) FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) +FUNCTION_PASS("dump-def-use", dump_def_use_pass{std::cout}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) diff --git a/src/support/util.hpp b/src/support/util.hpp index e4414d18..1f1672a6 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -41,85 +41,104 @@ constexpr auto fnv1a(Head &&head, Tail &&...tail) -> std::uint64_t { template class iterator_range_wrapper { public: iterator_range_wrapper(ItT begin, ItT end) : begin_(std::move(begin)), end_(std::move(end)) {} - ItT begin() const { return begin_; } - ItT end() const { return end_; } + ItT begin() { return begin_; } + ItT end() { return end_; } private: ItT begin_, end_; }; -template class indirect_iterator : public IteratorT { +template class indirection_kind_deref { public: - using value_type = std::decay_t()))>; + using iterator_value_reference = decltype(*std::declval()); + using value_type = std::decay_t())>; + using reference = value_type &; using pointer = value_type *; + + static auto ref(iterator_value_reference v) -> reference { return *v; } + static auto ptr(iterator_value_reference v) -> pointer { return &ref(v); } +}; + +template class indirection_kind_get { + public: + using iterator_value_reference = decltype(*std::declval()); + using value_type = + std::remove_reference_t().get())>; using reference = value_type &; + using pointer = value_type *; - auto operator*() const -> reference { return *(this->IteratorT::operator*()); } - auto operator->() const -> pointer { return &*(this->IteratorT::operator*()); } - auto operator[](std::size_t n) const -> reference { return *(this->IteratorT::operator[](n)); } + static auto ref(iterator_value_reference v) -> reference { return *v.get(); } + static auto ptr(iterator_value_reference v) -> pointer { return v.get(); } }; -template class pointer_iterator { +template class IndirectionKind = indirection_kind_deref> +class indirect_random_access_iterator { public: - using value_type = T; + using value_type = typename IndirectionKind::value_type; using pointer = value_type *; using reference = value_type &; using difference_type = std::ptrdiff_t; - pointer_iterator() : ptr_{nullptr} {} - pointer_iterator(pointer ptr) : ptr_{std::move(ptr)} {} + indirect_random_access_iterator() : it_{nullptr} {} + indirect_random_access_iterator(IteratorT it) : it_{std::move(it)} {} - auto operator*() const -> reference { return *ptr_; } - auto operator->() const -> pointer { return ptr_; } - auto operator[](std::size_t n) const -> reference { return ptr_[n]; } - auto operator++() -> pointer_iterator & { - ++ptr_; + auto operator*() const -> reference { return IndirectionKind::ref(*it_); } + auto operator->() const -> pointer { return IndirectionKind::ptr(*it_); } + auto operator[](std::size_t n) const -> reference { + return IndirectionKind::ref(it_[n]); + } + auto operator++() -> indirect_random_access_iterator & { + ++it_; return *this; } - auto operator++(int) -> pointer_iterator { - auto tmp = ptr_++; - return pointer_iterator{tmp}; + auto operator++(int) -> indirect_random_access_iterator { + auto tmp = it_++; + return indirect_random_access_iterator{tmp}; } - auto operator--() -> pointer_iterator & { - --ptr_; + auto operator--() -> indirect_random_access_iterator & { + --it_; return *this; } - auto operator--(int) -> pointer_iterator { - auto tmp = ptr_--; - return pointer_iterator{tmp}; + auto operator--(int) -> indirect_random_access_iterator { + auto tmp = it_--; + return indirect_random_access_iterator{tmp}; } - auto operator-(pointer_iterator const &other) const -> difference_type { - return other.ptr_ - ptr_; + auto operator-(indirect_random_access_iterator const &other) const -> difference_type { + return it_ - other.it_; } - auto operator+=(std::ptrdiff_t n) -> pointer_iterator & { - ptr_ += n; + auto operator+=(std::ptrdiff_t n) -> indirect_random_access_iterator & { + it_ += n; return *this; } - auto operator-=(std::ptrdiff_t n) -> pointer_iterator & { - ptr_ -= n; + auto operator-=(std::ptrdiff_t n) -> indirect_random_access_iterator & { + it_ -= n; return *this; } - auto operator==(pointer_iterator const &other) const -> bool { return ptr_ == other.ptr_; } - auto operator<=>(pointer_iterator const &other) const -> bool { return ptr_ <=> other.ptr_; } + auto operator==(indirect_random_access_iterator const &other) const -> bool { + return it_ == other.it_; + } + auto operator<=>(indirect_random_access_iterator const &other) const -> bool { + return it_ <=> other.it_; + } private: - pointer ptr_; + IteratorT it_; }; -template -auto operator+(pointer_iterator const &p, std::ptrdiff_t n) -> pointer_iterator { - auto q = pointer_iterator{p}; +template class Kind> +auto operator+(indirect_random_access_iterator const &p, std::ptrdiff_t n) { + auto q = indirect_random_access_iterator{p}; return q += n; } -template -auto operator+(std::ptrdiff_t n, pointer_iterator const &p) -> pointer_iterator { +template class Kind> +auto operator+(std::ptrdiff_t n, indirect_random_access_iterator const &p) { return p + n; } -template -auto operator-(pointer_iterator const &p, std::ptrdiff_t n) -> pointer_iterator { - auto q = pointer_iterator{p}; +template class Kind> +auto operator-(indirect_random_access_iterator const &p, std::ptrdiff_t n) { + auto q = indirect_random_access_iterator{p}; return q -= n; } diff --git a/test/opt/dump-def-use.ir b/test/opt/dump-def-use.ir new file mode 100644 index 00000000..8e9dc3e5 --- /dev/null +++ b/test/opt/dump-def-use.ir @@ -0,0 +1,33 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt --dump-def-use < %s | filecheck %s + +func @foobar() { + %one = constant 1 -> index + %lb = constant 0 -> index + %ub = constant 5 -> index + for %i=%lb,%ub : index { + %1 = arith.add %i, %one : index + %2 = arith.rem %1, %one : index + } +; CHECK: Def-use in foobar +; CHECK-NEXT: > %one = constant 1 -> index +; CHECK-NEXT: def %one +; CHECK-NEXT: > %2 = arith.rem %1, %one : index +; CHECK-NEXT: > %1 = arith.add %i, %one : index +; CHECK-NEXT: > %lb = constant 0 -> index +; CHECK-NEXT: def %lb +; CHECK-NEXT: > for %i=%lb,%ub : index {...} +; CHECK-NEXT: > %ub = constant 5 -> index +; CHECK-NEXT: def %ub +; CHECK-NEXT: > for %i=%lb,%ub : index {...} +; CHECK-NEXT: > for %i=%lb,%ub : index {...} +; CHECK-NEXT: def %i +; CHECK-NEXT: > %1 = arith.add %i, %one : index +; CHECK-NEXT: > %1 = arith.add %i, %one : index +; CHECK-NEXT: def %1 +; CHECK-NEXT: > %2 = arith.rem %1, %one : index +; CHECK-NEXT: > %2 = arith.rem %1, %one : index +; CHECK-NEXT: def %2 +} From 494219b8391fb7c0be19bda9181a718edb4b589c Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 2 Oct 2024 18:19:31 +0200 Subject: [PATCH 038/297] Constant propagation Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 65 +++-- include/tinytc/types.h | 2 + include/tinytc/types.hpp | 2 + src/error.cpp | 6 +- src/node/inst_node.cpp | 74 +++-- src/node/inst_node.hpp | 89 +++--- src/node/region_node.cpp | 2 +- src/node/value_node.cpp | 4 +- src/node/value_node.hpp | 12 +- src/pass/constant_propagation.cpp | 342 ++++++++++++++--------- src/pass/constant_propagation_helper.hpp | 246 ++++++++++++++++ src/pass/convert_to_opencl.cpp | 2 +- src/pass/stack.cpp | 2 +- src/scalar_type.hpp | 4 + test/codegen/scalar_arithmetic_error.ir | 2 +- test/opt/constant-propagation.ir | 68 +++++ 16 files changed, 687 insertions(+), 235 deletions(-) create mode 100644 src/pass/constant_propagation_helper.hpp diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 7e0a08eb..14fc3f42 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -61,7 +61,7 @@ Constants .. code:: abnf - constant = floating-constant / integer-constant + constant = complex-constant / floating-constant / integer-constant integer-constant = "true" / "false" / [sign] 1*DIGIT sign = "-" / "+" floating-constant = [sign] *DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] @@ -71,6 +71,7 @@ Constants floating-constant-dec = [sign] (mantissa-dec ["e" exponent] / 1*DIGIT "e" exponent) floating-constant-hex = [sign] "0x" (mantissa-hex ["p" exponent] / 1*HEXDIG "p" exponent) floating-constant = floating-constant-dec / floating-constant-hex + complex-constant = "[" floating-constant "," floating-constant "]" Integer constants must lie in the range :math:`-2^{63}+1,\dots,2^{63}-1`. @@ -138,11 +139,13 @@ Scalar types .. code:: abnf - scalar-type = integer-type / floating-type + scalar-type = integer-type / floating-type / complex-type integer-type = "i" ("1" / "8" / "16" / "32" / "64") / "index" floating-type = "f" ("32" / "64") + complex-type = "c" ("32" / "64") -Scalar types are either signless integer ("i") or floating point ("f"). +Scalar types are either signless integer ("i"), floating point ("f"), +or complex floating point ("c"). The number behind the scalar type prefix denotes the number of bits, e.g. "f64" are double precision floating point numbers. The "index" type is an integer type whose width is platform-specific. @@ -560,20 +563,23 @@ Overview Binary arithmetic operation on scalars. Both operands, as well as the returned type, have the same scalar type. -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.add scalar-type Sum of operands -.sub scalar-type Difference of operands -.mul scalar-type Product of operands -.div scalar-type Quotient of operands -.rem scalar-type Remainder from the division of operands -.shl integer-type Left shift first operand by number of bits given by second operand -.shr integer-type Arithmetic right shift first operand by number of bits given by second operand -.and integer-type Bitwise and -.or integer-type Bitwise or -.xor integer-type Bitwise xor -==== ============ ============================================================================== +The following table shows the operations' description and the types that are allowed for the operation. +The backslash "\\" is used to exclude types from the list of allowed types. + +==== ============================ ================================================================ +Op Allowed type Description +==== ============================ ================================================================ +.add scalar-type Sum of operands +.sub scalar-type Difference of operands +.mul scalar-type Product of operands +.div scalar-type Quotient of operands +.rem scalar-type \\ complex-type Remainder from the division of operands +.shl integer-type \\ i1 Left shift first operand by second operand +.shr integer-type \\ i1 Arithmetic right shift first operand by second operand +.and integer-type Bitwise and +.or integer-type Bitwise or +.xor integer-type Bitwise xor +==== ============================ ================================================================ Arithmetic (unary) .................. @@ -589,6 +595,8 @@ Overview Unary arithmetic operation on scalars. The returned value has the same type as the operand. +The following table shows the operations' description and the types that are allowed for the operation. + ==== ============ ============================================================================== Op Allowed type Description ==== ============ ============================================================================== @@ -650,16 +658,19 @@ Overview Scalar comparison. Both operands must have the same scalar type and the returned value is boolean. -==== ===================== -Cond Description -==== ===================== -.eq Equal -.ne Not equal -.gt Greater than -.ge Greater than or equal -.lt Less than -.le Less than or equal -==== ===================== +The following table shows the comparisons' description and the types that are allowed for the comparison. +The backslash "\\" is used to exclude types from the list of allowed types. + +==== =========================== ===================== +Cond Allowed type Description +==== =========================== ===================== +.eq scalar-type Equal +.ne scalar-type Not equal +.gt scalar-type \\ complex-type Greater than +.ge scalar-type \\ complex-type Greater than or equal +.lt scalar-type \\ complex-type Less than +.le scalar-type \\ complex-type Less than or equal +==== =========================== ===================== Constant ........ diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 154d5312..046b03d9 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -65,6 +65,8 @@ typedef enum { tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset + tinytc_status_ir_i1_unsupported = 0x117, ///< Instruction does not support i1 type + tinytc_status_ir_complex_unsupported = 0x118, ///< Instruction does not support complex type // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 21f600d0..d1fba78d 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -75,6 +75,8 @@ enum class status { ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, ir_invalid_offset = tinytc_status_ir_invalid_offset, + ir_i1_unsupported = tinytc_status_ir_i1_unsupported, + ir_complex_unsupported = tinytc_status_ir_complex_unsupported, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index a8b91ee6..819301b8 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -156,7 +156,7 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_ir_collective_called_from_spmd: return "Collective instruction must not be called from SPMD region"; case tinytc_status_ir_fp_unsupported: - return "Floating point type unsupported for instruction"; + return "Floating point type unsupported by instruction"; case tinytc_status_ir_spmd_called_from_collective: return "SPMD instruction must not be called from collective region"; case tinytc_status_ir_expected_local_address_space: @@ -165,6 +165,10 @@ char const *tinytc_error_string(tinytc_status_t status) { return "A memref with global address space is expected"; case tinytc_status_ir_invalid_offset: return "Offset must be non-negative or dynamic"; + case tinytc_status_ir_i1_unsupported: + return "i1 type unsupported by instruction"; + case tinytc_status_ir_complex_unsupported: + return "complex type unsupported by instruction"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index dd8cd818..f8170dc0 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -86,7 +86,7 @@ alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); - result(0) = value_node{ty, lc}; + result(0) = value_node{ty, this, lc}; auto memref = dyn_cast(result(0).ty()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); @@ -134,27 +134,41 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b if (at->ty() != bt->ty()) { throw compilation_error(loc(), status::ir_scalar_mismatch); } - bool inst_supports_fp = false; + bool inst_supports_fp = true; + bool inst_supports_complex = true; + bool inst_supports_i1 = true; switch (operation) { case arithmetic::add: case arithmetic::sub: case arithmetic::mul: case arithmetic::div: + break; case arithmetic::rem: - inst_supports_fp = true; + inst_supports_complex = false; break; - case arithmetic::shl: - case arithmetic::shr: case arithmetic::and_: case arithmetic::or_: case arithmetic::xor_: inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic::shl: + case arithmetic::shr: + inst_supports_i1 = false; + inst_supports_fp = false; + inst_supports_complex = false; break; } + if (!inst_supports_i1 && at->ty() == scalar_type::i1) { + throw compilation_error(loc(), status::ir_i1_unsupported); + } if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = value_node{at}; + if (!inst_supports_complex && is_complex_type(at->ty())) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + result(0) = value_node{at, this, lc}; } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, @@ -164,19 +178,23 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 loc(lc); auto at = get_scalar_type(loc(), a()); - bool inst_supports_fp = false; + bool inst_supports_fp = true; + bool inst_supports_complex = true; switch (operation) { case arithmetic_unary::neg: - inst_supports_fp = true; break; case arithmetic_unary::not_: inst_supports_fp = false; + inst_supports_complex = false; break; } if (!inst_supports_fp && is_floating_type(at->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - result(0) = value_node{at, lc}; + if (!inst_supports_complex && is_complex_type(at->ty())) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + result(0) = value_node{at, this, lc}; } cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc) @@ -188,7 +206,7 @@ cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const throw compilation_error(lc, status::ir_expected_scalar); } - result(0) = value_node{to_ty, lc}; + result(0) = value_node{to_ty, this, lc}; } compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, @@ -205,8 +223,24 @@ compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t throw compilation_error(loc(), status::ir_scalar_mismatch); } + bool inst_supports_complex = true; + switch (cond_) { + case cmp_condition::eq: + case cmp_condition::ne: + break; + case cmp_condition::gt: + case cmp_condition::ge: + case cmp_condition::lt: + case cmp_condition::le: + inst_supports_complex = false; + break; + } + if (!inst_supports_complex && is_complex_type(at->ty())) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + auto result_ty = scalar_data_type::get(at->context(), scalar_type::i1); - result(0) = value_node{result_ty, lc}; + result(0) = value_node{result_ty, this, lc}; } constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc) @@ -226,7 +260,7 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc throw compilation_error(loc(), status::ir_expected_scalar); } - result(0) = value_node{ty, lc}; + result(0) = value_node{ty, this, lc}; } expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, @@ -277,7 +311,7 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, } auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, lc}; + result(0) = value_node{result_ty, this, lc}; } fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, location const &lc) @@ -313,7 +347,7 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, loc stride.push_back(m->stride(i)); } auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, lc}; + result(0) = value_node{result_ty, this, lc}; } load_inst::load_inst(tinytc_value_t op0, array_view index_list0, location const &lc) @@ -329,14 +363,14 @@ load_inst::load_inst(tinytc_value_t op0, array_view index_list0, if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result(0) = value_node{g.ty(), lc}; + result(0) = value_node{g.ty(), this, lc}; }, [&](memref_data_type &m) { if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } auto result_ty = scalar_data_type::get(m.context(), m.element_ty()); - result(0) = value_node{result_ty, lc}; + result(0) = value_node{result_ty, this, lc}; }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, *operand().ty()); @@ -465,7 +499,7 @@ if_inst::if_inst(tinytc_value_t condition, array_view return op(0, condition); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { - result(i) = value_node{return_types[i], lc}; + result(i) = value_node{return_types[i], this, lc}; } } @@ -485,8 +519,8 @@ size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, location const &lc) throw compilation_error(loc(), status::ir_out_of_bounds); } - auto result_ty = scalar_data_type::get(op(0)->context(), scalar_type::index); - result(0) = value_node{result_ty, lc}; + auto result_ty = scalar_data_type::get(op(0).context(), scalar_type::index); + result(0) = value_node{result_ty, this, lc}; } subview_inst::subview_inst(tinytc_value_t op0, array_view static_offsets0, @@ -536,7 +570,7 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o } auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, lc}; + result(0) = value_node{result_ty, this, lc}; } store_inst::store_inst(tinytc_value_t val0, tinytc_value_t op0, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index f1d2f2bc..b6d618d8 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -130,9 +130,11 @@ struct tinytc_inst : tinytc::ilist_node_with_parent inline auto op_begin() const -> const_op_iterator { return {op_begin_}; } inline auto op_end() const -> const_op_iterator { return {op_end_}; } inline auto operands() const -> const_op_range { return {op_begin(), op_end()}; } - inline auto op(std::size_t pos) -> tinytc_value_t { return op_begin_[pos].get(); } - inline auto op(std::size_t pos) const -> tinytc_value_t { return op_begin_[pos].get(); } + inline auto op(std::size_t pos) -> tinytc_value & { return *op_begin_[pos].get(); } + inline auto op(std::size_t pos) const -> tinytc_value const & { return *op_begin_[pos].get(); } inline void op(std::size_t pos, tinytc_value_t val) { op_begin_[pos] = val; } + inline auto get_use(std::size_t pos) -> tinytc::use & { return op_begin_[pos]; } + inline auto get_use(std::size_t pos) const -> tinytc::use const & { return op_begin_[pos]; } inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } // Iterator over results @@ -292,7 +294,8 @@ class standard_inst : public inst_node { } } if (num_results > 0) { - set_result_range(results_.get(), results_.get() + num_results); + auto *result_begin = results_.get(); + set_result_range(result_begin, result_begin + num_results); } if (num_child_regions > 0) { set_child_regions_range(child_regions_.get(), child_regions_.get() + num_child_regions); @@ -316,10 +319,10 @@ class blas_a2_inst : public standard_inst<4, 0> { inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> tinytc_value const & { return *op(op_alpha); } - inline auto A() const -> tinytc_value const & { return *op(op_A); } - inline auto beta() const -> tinytc_value const & { return *op(op_beta); } - inline auto B() const -> tinytc_value const & { return *op(op_B); } + inline auto alpha() const -> tinytc_value const & { return op(op_alpha); } + inline auto A() const -> tinytc_value const & { return op(op_A); } + inline auto beta() const -> tinytc_value const & { return op(op_beta); } + inline auto B() const -> tinytc_value const & { return op(op_B); } protected: bool atomic_; @@ -336,16 +339,16 @@ class blas_a3_inst : public standard_inst<5, 0> { inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() -> tinytc_value & { return *op(op_alpha); } - inline auto alpha() const -> tinytc_value const & { return *op(op_alpha); } - inline auto A() -> tinytc_value & { return *op(op_A); } - inline auto A() const -> tinytc_value const & { return *op(op_A); } - inline auto B() -> tinytc_value & { return *op(op_B); } - inline auto B() const -> tinytc_value const & { return *op(op_B); } - inline auto beta() -> tinytc_value & { return *op(op_beta); } - inline auto beta() const -> tinytc_value const & { return *op(op_beta); } - inline auto C() -> tinytc_value & { return *op(op_C); } - inline auto C() const -> tinytc_value const & { return *op(op_C); } + inline auto alpha() -> tinytc_value & { return op(op_alpha); } + inline auto alpha() const -> tinytc_value const & { return op(op_alpha); } + inline auto A() -> tinytc_value & { return op(op_A); } + inline auto A() const -> tinytc_value const & { return op(op_A); } + inline auto B() -> tinytc_value & { return op(op_B); } + inline auto B() const -> tinytc_value const & { return op(op_B); } + inline auto beta() -> tinytc_value & { return op(op_beta); } + inline auto beta() const -> tinytc_value const & { return op(op_beta); } + inline auto C() -> tinytc_value & { return op(op_C); } + inline auto C() const -> tinytc_value const & { return op(op_C); } protected: bool atomic_; @@ -359,10 +362,10 @@ class loop_inst : public standard_inst<3, 0, 1> { enum op_number { op_from = 0, op_to = 1, op_step = 2 }; loop_inst(IK tid, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, tinytc_data_type_t loop_var_type, location const &loc = {}); - inline auto from() const -> tinytc_value const & { return *op(op_from); } - inline auto to() const -> tinytc_value const & { return *op(op_to); } - inline auto has_step() const -> bool { return op(op_step) != nullptr; } - inline auto step() const -> tinytc_value const & { return *op(op_step); } + inline auto from() const -> tinytc_value const & { return op(op_from); } + inline auto to() const -> tinytc_value const & { return op(op_to); } + inline auto has_step() const -> bool { return get_use(op_step).get() != nullptr; } + inline auto step() const -> tinytc_value const & { return op(op_step); } inline auto body() -> tinytc_region & { return child_region(0); } inline auto body() const -> tinytc_region const & { return child_region(0); } inline auto loop_var() -> tinytc_value & { return body().param(0); } @@ -400,8 +403,8 @@ class arith_inst : public standard_inst<2, 1> { arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline arithmetic operation() const { return operation_; } - inline auto a() const -> tinytc_value const & { return *op(op_a); } - inline auto b() const -> tinytc_value const & { return *op(op_b); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() const -> tinytc_value const & { return op(op_b); } private: arithmetic operation_; @@ -414,7 +417,7 @@ class arith_unary_inst : public standard_inst<1, 1> { arith_unary_inst(arithmetic_unary op, tinytc_value_t a, location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } - inline auto a() const -> tinytc_value const & { return *op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } private: arithmetic_unary operation_; @@ -443,7 +446,7 @@ class cast_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } enum op_number { op_a = 0 }; cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc = {}); - inline auto a() const -> tinytc_value const & { return *op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } }; class compare_inst : public standard_inst<2, 1> { @@ -453,8 +456,8 @@ class compare_inst : public standard_inst<2, 1> { compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } - inline auto a() const -> tinytc_value const & { return *op(op_a); } - inline auto b() const -> tinytc_value const & { return *op(op_b); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() const -> tinytc_value const & { return op(op_b); } private: cmp_condition cond_; @@ -485,10 +488,10 @@ class expand_inst : public standard_inst { return static_expand_shape_; } - inline auto operand() const -> tinytc_value const & { return *op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline auto expand_shape() { return operands() | std::views::drop(1); } inline auto expand_shape() const { return operands() | std::views::drop(1); } - inline auto expand_shape(std::int64_t i) const -> tinytc_value const & { return *op(i + 1); } + inline auto expand_shape(std::int64_t i) const -> tinytc_value const & { return op(i + 1); } private: std::int64_t expanded_mode_; @@ -500,7 +503,7 @@ class fuse_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, location const &lc = {}); - inline auto operand() const -> tinytc_value const & { return *op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } @@ -513,7 +516,7 @@ class load_inst : public standard_inst { inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } load_inst(tinytc_value_t op, array_view index_list, location const &lc = {}); - inline auto operand() const -> tinytc_value const & { return *op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline auto index_list() const { return operands() | std::views::drop(1); } }; @@ -523,7 +526,7 @@ class group_id_inst : public standard_inst<0, 1> { inline group_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::group_id} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), this, lc}; } }; @@ -533,7 +536,7 @@ class group_size_inst : public standard_inst<0, 1> { inline group_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::group_size} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), this, lc}; } }; @@ -541,7 +544,7 @@ class lifetime_stop_inst : public standard_inst<1, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } inline lifetime_stop_inst(tinytc_value_t obj) : standard_inst{IK::lifetime_stop} { op(0, obj); } - inline auto object() const -> tinytc_value const & { return *op(0); } + inline auto object() const -> tinytc_value const & { return op(0); } }; class gemm_inst : public blas_a3_inst { @@ -608,7 +611,7 @@ class if_inst : public standard_inst<1, dynamic, 2> { enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; if_inst(tinytc_value_t condition, array_view return_types = {}, location const &lc = {}); - inline auto condition() const -> tinytc_value const & { return *op(0); } + inline auto condition() const -> tinytc_value const & { return op(0); } inline auto then() -> tinytc_region & { return child_region(child_region_then); } inline auto then() const -> tinytc_region const & { return child_region(child_region_then); } inline auto otherwise() -> tinytc_region & { return child_region(child_region_otherwise); } @@ -624,7 +627,7 @@ class num_subgroups_inst : public standard_inst<0, 1> { inline num_subgroups_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::num_subgroups} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; } }; @@ -642,7 +645,7 @@ class size_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } size_inst(tinytc_value_t op, std::int64_t mode, location const &lc = {}); - inline auto operand() const -> tinytc_value const & { return *op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline std::int64_t mode() const { return mode_; } private: @@ -655,7 +658,7 @@ class subgroup_id_inst : public standard_inst<0, 1> { inline subgroup_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_id} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; } }; @@ -665,7 +668,7 @@ class subgroup_local_id_inst : public standard_inst<0, 1> { inline subgroup_local_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_local_id} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; } }; @@ -675,7 +678,7 @@ class subgroup_size_inst : public standard_inst<0, 1> { inline subgroup_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) : standard_inst{IK::subgroup_size} { loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), lc}; + result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; } }; @@ -689,7 +692,7 @@ class subview_inst : public standard_inst { inline auto static_offsets() const -> array_view { return static_offsets_; } inline auto static_sizes() const -> array_view { return static_sizes_; } - inline auto operand() const -> tinytc_value const & { return *op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline auto offsets() const { return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); } @@ -707,8 +710,8 @@ class store_inst : public standard_inst { store_inst(tinytc_value_t val, tinytc_value_t op, array_view index_list, location const &lc = {}); - inline auto val() const -> tinytc_value const & { return *op(op_val); } - inline auto operand() const -> tinytc_value const & { return *op(op_operand); } + inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } inline auto index_list() const { return operands() | std::views::drop(2); } }; diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 82cb0e55..0bc6891c 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -31,6 +31,6 @@ tinytc_region::~tinytc_region() {} void tinytc_region::set_params(array_view param_types, location const &lc) { params_.resize(param_types.size()); for (std::size_t i = 0; i < param_types.size(); ++i) { - params_[i] = tinytc_value{param_types[i], lc}; + params_[i] = tinytc_value{param_types[i], nullptr, lc}; } } diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp index 8b6b1357..0954a43d 100644 --- a/src/node/value_node.cpp +++ b/src/node/value_node.cpp @@ -5,8 +5,8 @@ using namespace tinytc; -tinytc_value::tinytc_value(tinytc_data_type_t ty, location const &lc) - : ty_{std::move(ty)}, loc_{lc} {} +tinytc_value::tinytc_value(tinytc_data_type_t ty, tinytc_inst_t def_inst, location const &lc) + : ty_{std::move(ty)}, loc_{lc}, def_inst_{def_inst} {} auto tinytc_value::use_begin() -> use_iterator { return {first_use_}; } auto tinytc_value::use_end() -> use_iterator { return {nullptr}; } diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 615cb913..91ae9d3c 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -22,7 +22,13 @@ class const_use_iterator; struct tinytc_value final { public: - tinytc_value(tinytc_data_type_t ty = nullptr, tinytc::location const &lc = {}); + tinytc_value(tinytc_data_type_t ty = nullptr, tinytc_inst_t def_inst_ = nullptr, + tinytc::location const &lc = {}); + + tinytc_value(tinytc_value const &) = delete; + tinytc_value(tinytc_value &&) = default; + tinytc_value &operator=(tinytc_value const &) = delete; + tinytc_value &operator=(tinytc_value &&) = default; inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } @@ -42,9 +48,13 @@ struct tinytc_value final { auto use_end() const -> tinytc::const_use_iterator; auto uses() const -> tinytc::iterator_range_wrapper; + // Can be nullptr, e.g. if value is a region parameter + auto defining_inst() const -> tinytc_inst_t { return def_inst_; } + private: tinytc_data_type_t ty_; tinytc::location loc_; + tinytc_inst_t def_inst_ = nullptr; std::string name_; friend class tinytc::use; diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index 61682f1b..be2b9ed9 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -6,19 +6,133 @@ #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "pass/constant_propagation_helper.hpp" +#include "scalar_type.hpp" #include "support/casting.hpp" #include "support/visit.hpp" #include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include +#include +#include namespace tinytc { +template class unary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + unary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(std::int64_t const &A) -> inst { + switch (switch_ty) { + case scalar_type::i1: + return computer.template operator()(A); + case scalar_type::i8: + return computer.template operator()(A); + case scalar_type::i16: + return computer.template operator()(A); + case scalar_type::i32: + return computer.template operator()(A); + case scalar_type::i64: + return computer.template operator()(A); + case scalar_type::index: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A) -> inst { + switch (switch_ty) { + case scalar_type::f32: + return computer.template operator()(A); + case scalar_type::f64: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A) -> inst { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A); + case scalar_type::c64: + return computer.template operator()>(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } +}; + +template class binary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + binary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(std::int64_t const &A, std::int64_t const &B) -> inst { + switch (switch_ty) { + case scalar_type::i1: + return computer.template operator()(A, B); + case scalar_type::i8: + return computer.template operator()(A, B); + case scalar_type::i16: + return computer.template operator()(A, B); + case scalar_type::i32: + return computer.template operator()(A, B); + case scalar_type::i64: + return computer.template operator()(A, B); + case scalar_type::index: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A, double const &B) -> inst { + switch (switch_ty) { + case scalar_type::f32: + return computer.template operator()(A, B); + case scalar_type::f64: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A, std::complex const &B) -> inst { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A, B); + case scalar_type::c64: + return computer.template operator()>(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + template auto operator()(T const &, U const &) -> inst { + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + } +}; + class constant_evaluator { public: auto operator()(inst_node &) -> inst; - // auto operator()(arith_inst &) -> inst; + auto operator()(arith_inst &) -> inst; + auto operator()(arith_unary_inst &) -> inst; + auto operator()(cast_inst &) -> inst; + auto operator()(compare_inst &) -> inst; auto operator()(size_inst &in) -> inst; private: @@ -33,17 +147,81 @@ auto constant_evaluator::get_memref_type(value_node const &v) const -> const mem return t; } -/* Inst nodes */ -auto constant_evaluator::operator()(inst_node &in) -> inst { - // for (auto &op : in.operands()) { - // if (op) { - // uintptr_t u = std::bit_cast(op.get()); - // if (auto kc = known_constants_.find(u); kc != known_constants_.end()) { - // op = kc->second; - //} - //} - //} - return inst{}; +auto constant_evaluator::operator()(inst_node &) -> inst { return {}; } + +auto constant_evaluator::operator()(arith_inst &in) -> inst { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + if (a_const == nullptr || b_const == nullptr) { + return inst{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); +} + +auto constant_evaluator::operator()(arith_unary_inst &in) -> inst { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return inst{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + auto computer = compute_unary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); +} + +auto constant_evaluator::operator()(cast_inst &in) -> inst { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return inst{}; + } + + auto rt = dyn_cast(in.result(0).ty()); + if (rt == nullptr) { + throw compilation_error(in.result(0).loc(), status::ir_expected_scalar); + } + + return std::visit(overloaded{[&](auto A) -> inst { return compute_cast(rt, A, in.loc()); }}, + a_const->value()); +} + +auto constant_evaluator::operator()(compare_inst &in) -> inst { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + if (a_const == nullptr || b_const == nullptr) { + return inst{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + auto computer = compute_compare{in.cond(), in.result(0).ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); } auto constant_evaluator::operator()(size_inst &in) -> inst { @@ -58,137 +236,27 @@ auto constant_evaluator::operator()(size_inst &in) -> inst { return inst{}; } -/*auto constant_propagation::operator()(arith_inst &arith) -> inst { - this->operator()(static_cast(arith)); - - auto &a = arith.a(); - auto &b = arith.b(); - - auto at = dyn_cast(a->ty().get()); - if (at == nullptr) { - throw compilation_error(a->loc(), status::ir_expected_scalar); - } - - if (is_floating_type(at->ty())) { - auto av = dyn_cast(a.get()); - auto bv = dyn_cast(b.get()); - if (av != nullptr && bv != nullptr) { - auto const compute = [&arith](auto a, auto b) { - switch (arith.operation()) { - case arithmetic::add: - return a + b; - case arithmetic::sub: - return a - b; - case arithmetic::mul: - return a * b; - case arithmetic::div: - return a / b; - case arithmetic::rem: - return std::fmod(a, b); - default: - break; - } - throw compilation_error(arith.loc(), status::ir_fp_unsupported); - }; - - auto constant_val = value{}; - switch (at->ty()) { - case scalar_type::f32: - constant_val = make_imm( - compute(static_cast(av->value()), static_cast(bv->value())), - scalar_type::f32, arith.loc()); - break; - case scalar_type::f64: - constant_val = - make_imm(compute(av->value(), bv->value()), scalar_type::f64, arith.loc()); - break; - default: - break; - }; - if (constant_val) { - uintptr_t u = std::bit_cast(arith.result().get()); - known_constants_[u] = std::move(constant_val); - } - } - } else { - auto av = dyn_cast(a.get()); - auto bv = dyn_cast(b.get()); - if (av != nullptr && bv != nullptr) { - auto const compute = [&arith](auto a, auto b) { - switch (arith.operation()) { - case arithmetic::add: - return a + b; - case arithmetic::sub: - return a - b; - case arithmetic::mul: - return a * b; - case arithmetic::div: - return a / b; - case arithmetic::rem: - return a % b; - case arithmetic::shl: - return a << b; - case arithmetic::shr: - return a >> b; - case arithmetic::and_: - return a & b; - case arithmetic::or_: - return a | b; - case arithmetic::xor_: - return a ^ b; - } - throw compilation_error(arith.loc(), status::runtime_error); - }; - - auto constant_val = value{}; - switch (at->ty()) { - case scalar_type::i1: { - bool const val = - compute(static_cast(av->value()), static_cast(bv->value())); - constant_val = - make_imm(static_cast(val), scalar_type::i1, arith.loc()); - break; - } - case scalar_type::i8: - constant_val = make_imm(compute(static_cast(av->value()), - static_cast(bv->value())), - arith.loc()); - break; - case scalar_type::i16: - constant_val = make_imm(compute(static_cast(av->value()), - static_cast(bv->value())), - arith.loc()); - break; - case scalar_type::i32: - constant_val = make_imm(compute(static_cast(av->value()), - static_cast(bv->value())), - arith.loc()); - break; - case scalar_type::i64: - constant_val = - make_imm(compute(av->value(), bv->value()), scalar_type::i64, arith.loc()); - break; - case scalar_type::index: - constant_val = - make_imm(compute(av->value(), bv->value()), scalar_type::index, arith.loc()); - break; - default: - break; - }; - if (constant_val) { - uintptr_t u = std::bit_cast(arith.result().get()); - known_constants_[u] = std::move(constant_val); - } - } - } -}*/ - void constant_propagation_pass::run_on_function(function_node &fn) { walk(fn, [&](region_node ®) { for (auto it = reg.begin(); it != reg.end(); ++it) { auto known_constant = visit(constant_evaluator{}, *it); if (known_constant) { + // update uses + if (it->num_results() != known_constant->num_results()) { + throw status::internal_compiler_error; + } + auto r_old = it->result_begin(); + auto r_new = known_constant->result_begin(); + for (; r_old != it->result_end() && r_new != known_constant->result_end(); + ++r_old, ++r_new) { + r_new->name(r_old->name()); + for (auto &u : r_old->uses()) { + u.set(&*r_new); + } + } + // delete old instruction it = reg.insts().erase(it); + // insert new instruction it = reg.insts().insert(it, known_constant.release()); } } diff --git a/src/pass/constant_propagation_helper.hpp b/src/pass/constant_propagation_helper.hpp new file mode 100644 index 00000000..c215830c --- /dev/null +++ b/src/pass/constant_propagation_helper.hpp @@ -0,0 +1,246 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_PROPAGATION_HELPER_20241002_HPP +#define CONSTANT_PROPAGATION_HELPER_20241002_HPP + +#include "scalar_type.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc { + +struct compute_unary_op { + arithmetic_unary operation; + data_type ty; + location const &loc; + + template + requires(std::is_integral_v) + auto operator()(T a) { + T val = 0; + switch (operation) { + case arithmetic_unary::neg: + val = -a; + break; + case arithmetic_unary::not_: + if constexpr (std::is_same_v) { + val = !a; + } else { + val = ~a; + } + break; + } + return make_constant(val, ty, loc); + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A) -> inst { + const auto a = static_cast(A); + T val = {}; + switch (operation) { + case arithmetic_unary::neg: + val = -a; + break; + default: + if constexpr (!std::is_floating_point_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } + throw compilation_error(loc, status::ir_fp_unsupported); + break; + } + return make_constant(val, ty, loc); + } +}; + +struct compute_binary_op { + arithmetic operation; + data_type ty; + location const &loc; + + template + requires(std::is_integral_v) + auto operator()(T a, T b) { + T val = 0; + switch (operation) { + case arithmetic::add: + val = a + b; + break; + case arithmetic::sub: + val = a - b; + break; + case arithmetic::mul: + val = a * b; + break; + case arithmetic::div: + val = a / b; + break; + case arithmetic::rem: + val = a % b; + break; + case arithmetic::shl: + if constexpr (std::is_same_v) { + throw compilation_error(loc, status::ir_i1_unsupported); + } else { + val = a << b; + } + break; + case arithmetic::shr: + if constexpr (std::is_same_v) { + throw compilation_error(loc, status::ir_i1_unsupported); + } else { + val = a >> b; + } + break; + case arithmetic::and_: + val = a & b; + break; + case arithmetic::or_: + val = a | b; + break; + case arithmetic::xor_: + val = a ^ b; + break; + } + return make_constant(val, ty, loc); + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A, U const &B) -> inst { + const auto a = static_cast(A); + const auto b = static_cast(B); + T val = {}; + switch (operation) { + case arithmetic::add: + val = a + b; + break; + case arithmetic::sub: + val = a - b; + break; + case arithmetic::mul: + val = a * b; + break; + case arithmetic::div: + val = a / b; + break; + case arithmetic::rem: + if constexpr (!std::is_floating_point_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::fmod(a, b); + } + break; + default: + if constexpr (!std::is_floating_point_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } + throw compilation_error(loc, status::ir_fp_unsupported); + break; + } + return make_constant(val, ty, loc); + } +}; + +struct compute_compare { + cmp_condition cond; + data_type ty; + location const &loc; + + template + requires(std::is_integral_v || std::is_floating_point_v) + auto operator()(T a, T b) { + bool val = false; + switch (cond) { + case cmp_condition::eq: + val = (a == b); + break; + case cmp_condition::ne: + val = (a != b); + break; + case cmp_condition::gt: + val = (a > b); + break; + case cmp_condition::ge: + val = (a >= b); + break; + case cmp_condition::lt: + val = (a < b); + break; + case cmp_condition::le: + val = (a <= b); + break; + }; + return make_constant(val, ty, loc); + } + + template + auto operator()(std::complex const &A, std::complex const &B) { + const auto a = static_cast(A); + const auto b = static_cast(B); + bool val = false; + switch (cond) { + case cmp_condition::eq: + val = (a == b); + break; + case cmp_condition::ne: + val = (a != b); + break; + default: + throw compilation_error(loc, status::ir_complex_unsupported); + break; + }; + return make_constant(val, ty, loc); + } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return static_cast(u); } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return u != U{}; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &u) { return u != std::complex{}; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &u) { return static_cast(u.real()); } +}; + +template auto value_cast(U const &u) { return value_cast_impl{}(u); } + +template auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> inst { + switch (to_ty->ty()) { + case scalar_type::i1: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i8: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i16: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i32: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i64: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::index: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f32: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f64: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::c32: + return make_constant(value_cast>(A), to_ty, loc); + case scalar_type::c64: + return make_constant(value_cast>(A), to_ty, loc); + }; + return {}; +}; + +} // namespace tinytc + +#endif // CONSTANT_PROPAGATION_HELPER_20241002_HPP diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index a106c13e..3a617337 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1056,7 +1056,7 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } std::vector clinst; for (std::int64_t i = 0; i < in.num_operands(); ++i) { - auto assign_yielded_var = clir::assignment(yielded_vars_.back()[i], val(*in.op(i))); + auto assign_yielded_var = clir::assignment(yielded_vars_.back()[i], val(in.op(i))); clinst.push_back(clir::expression_statement(std::move(assign_yielded_var))); } return clinst; diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 0ca9601e..9c233afe 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -46,7 +46,7 @@ void set_stack_ptr_pass::run_on_function(function_node &fn) { }, [&allocs](lifetime_stop_inst &s) { int num = 0; - auto v = s.object(); + auto &v = s.object(); for (auto it = allocs.begin(); it != allocs.end();) { if (it->value == &v) { it = allocs.erase(it); diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index 7380a957..6f88e771 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -9,8 +9,12 @@ #include #include +#include + namespace tinytc { +using host_index_type = std::int64_t; + bool is_floating_type(scalar_type ty); bool is_complex_type(scalar_type ty); bool is_integer_type(scalar_type ty); diff --git a/test/codegen/scalar_arithmetic_error.ir b/test/codegen/scalar_arithmetic_error.ir index 82c34a9a..e51c6a0c 100644 --- a/test/codegen/scalar_arithmetic_error.ir +++ b/test/codegen/scalar_arithmetic_error.ir @@ -6,5 +6,5 @@ func @t1(%a: f32, %b: f32) { %1 = arith.and %a, %b : f32 ; CHECK: %1 = arith.and %a, %b : f32 ; CHECK-NEXT: ~~~~~~~~~~~~~~~~~~~~~~ -; CHECK-NEXT::6.8-29: Floating point type unsupported for instruction +; CHECK-NEXT::6.8-29: Floating point type unsupported by instruction } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 2aafa384..3e862bf9 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -11,3 +11,71 @@ func @known_size(%a: memref) { ; CHECK-NEXT: %1 = constant 32 -> index ; CHECK-NEXT: %2 = constant 96 -> index } + +func @known_loop_bounds() { + %one = constant 1 -> index + %lb = constant 5 -> index + %size = constant 42 -> index + %tmp = arith.sub %size, %lb : index + %ub = arith.sub %tmp, %one : index + for %i=%lb,%ub { + } +; CHECK-LABEL: func @known_loop_bounds({{.*}} +; CHECK: %one = constant 1 -> index +; CHECK-NEXT: %lb = constant 5 -> index +; CHECK-NEXT: %size = constant 42 -> index +; CHECK-NEXT: %tmp = constant 37 -> index +; CHECK-NEXT: %ub = constant 36 -> index +; CHECK-NEXT: for %i=%lb,%ub : index { +; CHECK-NEXT: } +} + +func @known_arith() { + %0 = constant 1 -> i64 + %1 = arith.not %0 : i64 + %2 = constant 2 -> i64 + %3 = arith.add %0, %2 : i64 + %4 = constant -2.0 -> f32 + %5 = arith.neg %4 : f32 + %6 = constant [1.0, -1.0] -> c32 + %7 = arith.add %6, %6 : c32 +; CHECK-LABEL: func @known_arith({{.*}} +; CHECK: %0 = constant 1 -> i64 +; CHECK-NEXT: %1 = constant -2 -> i64 +; CHECK-NEXT: %2 = constant 2 -> i64 +; CHECK-NEXT: %3 = constant 3 -> i64 +; CHECK-NEXT: %4 = constant -0x1p+1 -> f32 +; CHECK-NEXT: %5 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %6 = constant [0x1p+0,-0x1p+0] -> c32 +; CHECK-NEXT: %7 = constant [0x1p+1,-0x1p+1] -> c32 +} + +func @known_cast() { + %0 = constant 32768 -> i32 + %1 = cast %0 : i32 -> i16 + %2 = cast %0 : i32 -> f32 + %3 = cast %0 : i32 -> c32 + %4 = cast %0 : i32 -> i1 + %5 = cast %3 : c32 -> i1 + %6 = cast %5 : i1 -> c32 +; CHECK-LABEL: func @known_cast({{.*}} +; CHECK: %0 = constant 32768 -> i32 +; CHECK-NEXT: %1 = constant -32768 -> i16 +; CHECK-NEXT: %2 = constant 0x1p+15 -> f32 +; CHECK-NEXT: %3 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %4 = constant 1 -> i1 +; CHECK-NEXT: %5 = constant 1 -> i1 +; CHECK-NEXT: %6 = constant [0x1p+0,0x0p+0] -> c32 +} + +func @known_compare() { + %0 = constant 1.0 -> f32 + %1 = constant 2.0 -> f32 + %2 = cmp.eq %0, %0 : f32 + %3 = cmp.eq %0, %1 : f32 +; CHECK-LABEL: func @known_compare({{.*}} +; CHECK: %0 = constant 0x1p+0 -> f32 +; CHECK-NEXT: %1 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %2 = constant 1 -> i1 +; CHECK-NEXT: %3 = constant 0 -> i1 +} From 6510f1c95ad12b8d4c2c74ff308902df8cfaea44 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 2 Oct 2024 19:30:26 +0200 Subject: [PATCH 039/297] Bugfix Signed-off-by: Carsten Uphoff --- src/node/value_node.cpp | 1 + src/node/value_node.hpp | 1 + src/pass/constant_propagation.cpp | 12 +++++++++--- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp index 0954a43d..a859d0dd 100644 --- a/src/node/value_node.cpp +++ b/src/node/value_node.cpp @@ -18,6 +18,7 @@ auto tinytc_value::use_end() const -> const_use_iterator { return {nullptr}; } auto tinytc_value::uses() const -> iterator_range_wrapper { return {use_begin(), use_end()}; } +auto tinytc_value::has_uses() const -> bool { return first_use_ != nullptr; } namespace tinytc { diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 91ae9d3c..8d40e9ce 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -47,6 +47,7 @@ struct tinytc_value final { auto use_begin() const -> tinytc::const_use_iterator; auto use_end() const -> tinytc::const_use_iterator; auto uses() const -> tinytc::iterator_range_wrapper; + auto has_uses() const -> bool; // Can be nullptr, e.g. if value is a region parameter auto defining_inst() const -> tinytc_inst_t { return def_inst_; } diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index be2b9ed9..0a0e9f98 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -237,7 +237,8 @@ auto constant_evaluator::operator()(size_inst &in) -> inst { } void constant_propagation_pass::run_on_function(function_node &fn) { - walk(fn, [&](region_node ®) { + // @todo: Use worklist instead of pre-order? + walk(fn, [&](region_node ®) { for (auto it = reg.begin(); it != reg.end(); ++it) { auto known_constant = visit(constant_evaluator{}, *it); if (known_constant) { @@ -250,8 +251,13 @@ void constant_propagation_pass::run_on_function(function_node &fn) { for (; r_old != it->result_end() && r_new != known_constant->result_end(); ++r_old, ++r_new) { r_new->name(r_old->name()); - for (auto &u : r_old->uses()) { - u.set(&*r_new); + auto u = r_old->use_begin(); + while (r_old->has_uses()) { + u->set(&*r_new); + u = r_old->use_begin(); + } + if (r_old->has_uses()) { + throw status::internal_compiler_error; } } // delete old instruction From 80aafc8f38b4b902250f5c298492cc661eb3f511 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 7 Oct 2024 15:03:42 +0200 Subject: [PATCH 040/297] Remove name argument from region builder to discourage using a SSA value name. (No auto-renaming is going to be implemented.) Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 4 +- include/tinytc/tinytc.hpp | 33 ++++----------- src/codegen_tools.cpp | 30 ++++++-------- src/parser/lexer.re | 15 +++++-- src/parser/parse_context.cpp | 69 ++++++++++++++++++++++--------- src/parser/parse_context.hpp | 10 +++-- src/parser/parser_impl.yy | 12 ++---- src/pass/constant_propagation.cpp | 58 +++++++++++++------------- src/pass/constant_propagation.hpp | 1 + 9 files changed, 124 insertions(+), 108 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 14fc3f42..0c2e9c2d 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -52,7 +52,9 @@ are prefixed with ``@``. .. code:: abnf - identifier = 1*DIGIT / (ALPHA *(ALPHA / DIGIT / "_")) + identifier = unnamed-identifier / named-identifier + unnamed-identifier = 1*DIGIT + named-identifier = ALPHA *(ALPHA / DIGIT / "_") local-identifier = "%" identifier global-identifier = "@" identifier diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index ac0988e4..a14c1b02 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1498,15 +1498,12 @@ class region_builder { * @brief Add instruction * * @param i Instruction - * @param name Result name * * @return Value returned by instruction; may be empty */ - [[maybe_unused]] inline auto add(inst i, std::string_view name = "") -> value { + [[maybe_unused]] inline auto add(inst i) -> value { auto result = value{}; - if (i.get_values(result) > 0 && name.size() > 0) { - result.set_name(name); - } + i.get_values(result); reg_.add_instruction(std::move(i)); return result; } @@ -1515,22 +1512,13 @@ class region_builder { * @brief Add instruction that returns multiple values * * @param i Instruction - * @param name Result name * * @return Values returned by instruction */ - [[maybe_unused]] inline auto add_multivalued(inst i, - std::string_view name = "") -> std::vector { + [[maybe_unused]] inline auto add_multivalued(inst i) -> std::vector { auto num_results = i.get_values({}); auto results = std::vector(static_cast(num_results)); results.resize(i.get_values(results)); - if (name.size() > 0) { - int counter = 0; - auto name_str = std::string{name}; - for (auto &result : results) { - result.set_name(name_str + std::to_string(counter++)); - } - } reg_.add_instruction(std::move(i)); return results; } @@ -1545,14 +1533,12 @@ class region_builder { * @param to Loop variable bound * @param loop_var_ty Type of loop variable * @param f Functor - * @param loop_var_name Loop variable name * @param loc Source code location */ template - void for_loop(value from, value to, data_type loop_var_ty, F &&f, - std::string_view loop_var_name = "", location const &loc = {}) { + void for_loop(value from, value to, data_type loop_var_ty, F &&f, location const &loc = {}) { for_loop(std::move(from), std::move(to), nullptr, std::move(loop_var_ty), - std::forward(f), std::move(loop_var_name), loc); + std::forward(f), loc); } /** * @brief Build for-loop with functor f(region_builder&, value) -> void @@ -1565,12 +1551,11 @@ class region_builder { * @param step Loop variable step * @param loop_var_ty Type of loop variable * @param f Functor - * @param loop_var_name Loop variable name * @param loc Source code location */ template void for_loop(value from, value to, value step, data_type loop_var_ty, F &&f, - std::string_view loop_var_name = "", location const &loc = {}) { + location const &loc = {}) { auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); @@ -1579,7 +1564,6 @@ class region_builder { if (!reg || !loop_var) { throw status::internal_compiler_error; } - loop_var.set_name(loop_var_name); reg_.add_instruction(std::move(fi)); auto bb = region_builder{reg}; f(bb, loop_var); @@ -1592,12 +1576,10 @@ class region_builder { * @param to Loop variable bound * @param loop_var_ty Type of loop variable * @param f functor - * @param loop_var_name Loop variable name * @param loc Source code location */ template - void foreach (value from, value to, data_type loop_var_ty, F && f, - std::string const &loop_var_name = "", location const &loc = {}) { + void foreach (value from, value to, data_type loop_var_ty, F && f, location const &loc = {}) { auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); @@ -1606,7 +1588,6 @@ class region_builder { if (!reg || !loop_var) { throw status::internal_compiler_error; } - loop_var.set_name(loop_var_name); reg_.add_instruction(std::move(fi)); auto bb = region_builder{reg}; f(bb, loop_var); diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index bb64c3d7..8190c339 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -452,9 +452,8 @@ void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_co auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); if (blocks > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); - bb.for_loop( - std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }, "block"); + bb.for_loop(std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); } if (rem > 0) { @@ -478,9 +477,8 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); - bb.for_loop( - std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }, "block"); + bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { @@ -524,18 +522,16 @@ void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); if (rem > 0) { auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); - bb.for_loop( - std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, c_bs_1); }, "block"); + bb.for_loop(std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, + [&](region_builder &bb, value block) { body(bb, block, c_bs_1); }); } auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, c_tiles)); auto tmp2 = bb.add(make_arith(arithmetic::mul, c_bs, sg_id_1)); auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); - bb.for_loop( - std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, c_bs); }, "block"); + bb.for_loop(std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, + [&](region_builder &bb, value block) { body(bb, block, c_bs); }); } void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, int block_size, @@ -565,9 +561,8 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); - bb.for_loop( - std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, - [&](region_builder &bb, value block) { body(bb, block, bs_1); }, "block"); + bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, + [&](region_builder &bb, value block) { body(bb, block, bs_1); }); auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); @@ -576,9 +571,8 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, auto tmp3 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); - bb.for_loop( - std::move(block_start), loop_trip_count, std::move(step), index_ty, - [&](region_builder &bb, value block) { body(bb, block, bs); }, "block"); + bb.for_loop(std::move(block_start), loop_trip_count, std::move(step), index_ty, + [&](region_builder &bb, value block) { body(bb, block, bs); }); } } // namespace tinytc diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 302d278f..38d96dd1 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -35,9 +35,11 @@ lex: newline = "\r"? "\n"; whitespace = [ \t\v\r]+; - identifier = [0-9]+ | ([a-zA-Z] [a-zA-Z0-9_]*); - local_identifier = "%" identifier; - global_identifier = "@" identifier; + unnamed_identifier = [0-9]+; + named_identifier = [a-zA-Z] [a-zA-Z0-9_]*; + local_unnamed_identifier = "%" unnamed_identifier; + local_named_identifier = "%" named_identifier; + global_identifier = "@" (unnamed_identifier | named_identifier); integer_type = "i" ("1" | "8" | "16" | "32" | "64") | "index"; floating_type = ("f" | "c") ("32" | "64"); @@ -54,7 +56,12 @@ lex: // identifier - local_identifier { + local_unnamed_identifier { + adv_loc(); + std::int64_t id = lex_integer_constant(b + 1, YYCURSOR); + return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); + } + local_named_identifier { adv_loc(); auto id = std::string(b + 1, YYCURSOR); return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 6bc4890b..153755b5 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -14,36 +14,65 @@ namespace tinytc { parse_context::parse_context(compiler_context compiler_ctx) : compiler_ctx_(compiler_ctx) {} -void parse_context::push_scope() { id_map_.push_back({}); } -void parse_context::pop_scope() { id_map_.pop_back(); } +void parse_context::push_scope() { + unnamed_id_map_.push_back({}); + named_id_map_.push_back({}); +} +void parse_context::pop_scope() { + named_id_map_.pop_back(); + unnamed_id_map_.pop_back(); +} void parse_context::push_region(tinytc_region_t r) { regions_.push(r); } void parse_context::pop_region() { regions_.pop(); } auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } auto parse_context::has_regions() -> bool { return !regions_.empty(); } -void parse_context::val(std::string const &id, tinytc_value &val, location const &l) { - if (id_map_.empty()) { - throw parser::syntax_error(l, "No active variable scope"); +void parse_context::val(std::variant const &id, tinytc_value &val, + location const &l) { + const auto handle_val = + [&l, &val](KeyT const &id, + std::vector> &map) { + if (map.empty()) { + throw parser::syntax_error(l, "No active scope"); + } + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto other = it->find(id); other != it->end()) { + auto oss = std::ostringstream{}; + oss << "Identifier %" << id << " was already used at " << other->second->loc(); + throw parser::syntax_error(l, std::move(oss).str()); + } + } + val.loc(l); + map.back()[id] = &val; + }; + if (std::holds_alternative(id)) { + handle_val(std::get(id), unnamed_id_map_); + } else { + auto const &sid = std::get(id); + handle_val(sid, named_id_map_); + val.name(sid); } - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto other = it->find(id); other != it->end()) { - auto oss = std::ostringstream{}; - oss << "Identifier %" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(l, oss.str()); - } - } - val.loc(l); - id_map_.back()[id] = &val; } -auto parse_context::val(std::string const &id, location const &l) -> tinytc_value_t { - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto j = it->find(id); j != it->end()) { - return j->second; - } +auto parse_context::val(std::variant const &id, + location const &l) -> tinytc_value_t { + const auto handle_val = + [&l](KeyT const &id, + std::vector> &map) { + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto j = it->find(id); j != it->end()) { + return j->second; + } + } + auto oss = std::ostringstream{}; + oss << "Undefined identifier %" << id; + throw parser::syntax_error(l, std::move(oss).str()); + }; + if (std::holds_alternative(id)) { + return handle_val(std::get(id), unnamed_id_map_); } - throw parser::syntax_error(l, "Undefined identifier %" + id); + return handle_val(std::get(id), named_id_map_); } void parse_context::report_error(location const &loc, std::string const &what) { diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 3fa5defd..e57855b0 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace tinytc { @@ -21,8 +22,10 @@ class parse_context { inline auto program() { return program_; } inline void program(prog p) { program_ = std::move(p); } - void val(std::string const &id, tinytc_value &val, location const &l); - auto val(std::string const &id, location const &l) -> tinytc_value_t; + void val(std::variant const &id, tinytc_value &val, + location const &l); + auto val(std::variant const &id, + location const &l) -> tinytc_value_t; void report_error(location const &loc, std::string const &what); @@ -38,7 +41,8 @@ class parse_context { private: compiler_context compiler_ctx_; - std::vector> id_map_; + std::vector> unnamed_id_map_; + std::vector> named_id_map_; std::stack regions_; prog program_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 16fc07c1..e26859f1 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -129,7 +129,7 @@ SUM "sum" YIELD "yield" ; -%token LOCAL_IDENTIFIER +%token > LOCAL_IDENTIFIER %token GLOBAL_IDENTIFIER %token INTEGER_CONSTANT %token FLOATING_CONSTANT @@ -142,8 +142,8 @@ %nterm prog %nterm > func_list %nterm func -%nterm ,std::vector>> parameters -%nterm > parameter +%nterm >,std::vector>> parameters +%nterm ,tinytc_data_type_t>> parameter %nterm >> attributes %nterm > attribute %nterm data_type @@ -182,7 +182,7 @@ %nterm yield_inst %nterm for_loop_var_type %nterm var_definition -%nterm > identifier_list +%nterm >> identifier_list %nterm valued_inst %nterm alloca_inst %nterm arith_inst @@ -239,7 +239,6 @@ func: ctx.push_scope(); auto name_it = $parameters.first.begin(); for (auto &p : func_node->params()) { - p.name(*name_it); ctx.val(*name_it, p, @parameters); ++name_it; } @@ -569,7 +568,6 @@ for_inst: auto inode = std::make_unique($from, $to, $optional_step, $for_loop_var_type, loc); ctx.push_scope(); auto &loop_var = inode->loop_var(); - loop_var.name($loop_var); ctx.val($loop_var, loop_var, @loop_var); ctx.push_region(&inode->body()); $$ = inst{inode.release()}; @@ -599,7 +597,6 @@ foreach_inst: std::make_unique($from, $to, $for_loop_var_type, loc); ctx.push_scope(); auto &loop_var = inode->loop_var(); - loop_var.name($loop_var); ctx.val($loop_var, loop_var, @loop_var); ctx.push_region(&inode->body()); $$ = inst{inode.release()}; @@ -629,7 +626,6 @@ var_definition: } auto results = $$->result_begin(); for (std::int64_t i = 0; i < $$->num_results(); ++i) { - results[i].name($identifier_list[i]); ctx.val($identifier_list[i], results[i], @identifier_list); } } diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index 0a0e9f98..cf88afde 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -10,7 +10,6 @@ #include "scalar_type.hpp" #include "support/casting.hpp" #include "support/visit.hpp" -#include "support/walk.hpp" #include "tinytc/tinytc.hpp" #include @@ -236,37 +235,40 @@ auto constant_evaluator::operator()(size_inst &in) -> inst { return inst{}; } -void constant_propagation_pass::run_on_function(function_node &fn) { - // @todo: Use worklist instead of pre-order? - walk(fn, [&](region_node ®) { - for (auto it = reg.begin(); it != reg.end(); ++it) { - auto known_constant = visit(constant_evaluator{}, *it); - if (known_constant) { - // update uses - if (it->num_results() != known_constant->num_results()) { - throw status::internal_compiler_error; +void constant_propagation_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } + +void constant_propagation_pass::run_on_region(region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + + auto known_constant = visit(constant_evaluator{}, *it); + if (known_constant) { + // update uses + if (it->num_results() != known_constant->num_results()) { + throw status::internal_compiler_error; + } + auto r_old = it->result_begin(); + auto r_new = known_constant->result_begin(); + for (; r_old != it->result_end() && r_new != known_constant->result_end(); + ++r_old, ++r_new) { + r_new->name(r_old->name()); + auto u = r_old->use_begin(); + while (r_old->has_uses()) { + u->set(&*r_new); + u = r_old->use_begin(); } - auto r_old = it->result_begin(); - auto r_new = known_constant->result_begin(); - for (; r_old != it->result_end() && r_new != known_constant->result_end(); - ++r_old, ++r_new) { - r_new->name(r_old->name()); - auto u = r_old->use_begin(); - while (r_old->has_uses()) { - u->set(&*r_new); - u = r_old->use_begin(); - } - if (r_old->has_uses()) { - throw status::internal_compiler_error; - } + if (r_old->has_uses()) { + throw status::internal_compiler_error; } - // delete old instruction - it = reg.insts().erase(it); - // insert new instruction - it = reg.insts().insert(it, known_constant.release()); } + // delete old instruction + it = reg.insts().erase(it); + // insert new instruction + it = reg.insts().insert(it, known_constant.release()); } - }); + } } } // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp index 59d8d68a..82964539 100644 --- a/src/pass/constant_propagation.hpp +++ b/src/pass/constant_propagation.hpp @@ -11,6 +11,7 @@ namespace tinytc { class constant_propagation_pass { public: void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); }; } // namespace tinytc From 7abb2ce6d7c8bcca2eaed48a5904f9ec57aeb708 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 7 Oct 2024 18:18:41 +0200 Subject: [PATCH 041/297] Add dead code elimination Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 4 +- src/CMakeLists.txt | 1 + src/codegen_tools.cpp | 36 +++++++------ src/compiler.cpp | 15 +++++- src/pass/dead_code_elimination.cpp | 83 ++++++++++++++++++++++++++++++ src/pass/dead_code_elimination.hpp | 19 +++++++ src/pass/lower_linalg.cpp | 2 - src/passes.def | 1 + test/opt/dead-code-elimination.ir | 54 +++++++++++++++++++ 9 files changed, 194 insertions(+), 21 deletions(-) create mode 100644 src/pass/dead_code_elimination.cpp create mode 100644 src/pass/dead_code_elimination.hpp create mode 100644 test/opt/dead-code-elimination.ir diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 0c2e9c2d..df183779 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -611,7 +611,7 @@ Barrier .. code:: abnf - barrier-instruction = "barrier" [".global"] [".local"] + instruction =/ "barrier" [".global"] [".local"] Overview ~~~~~~~~ @@ -1094,7 +1094,7 @@ Additional instructions .. code:: abnf - lifetime-stop-instruction = "lifetime_stop" local-identifier + instruction =/ "lifetime_stop" local-identifier SPMD instructions ----------------- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c24f87b2..a8e8d654 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,7 @@ set(SOURCES pass/check_ir.cpp pass/constant_propagation.cpp pass/convert_to_opencl.cpp + pass/dead_code_elimination.cpp pass/dump_cfg.cpp pass/dump_def_use.cpp pass/dump_ir.cpp diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 8190c339..fa82a2e8 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -3,6 +3,7 @@ #include "codegen_tools.hpp" #include "error.hpp" +#include "node/inst_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/util.hpp" @@ -538,31 +539,36 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); + auto c0 = bb.add(make_constant(0, index_ty)); auto c1 = bb.add(make_constant(1, index_ty)); - auto c_block_size = bb.add(make_constant(block_size, index_ty)); auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); + // Here we compute + // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size + // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * num_tiles + auto c_block_size = bb.add(make_constant(block_size, index_ty)); auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, c1)); auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, c_block_size)); - auto blocks2 = bb.add(make_arith(arithmetic::add, c1, blocks1)); - auto blocks3 = bb.add(make_arith(arithmetic::sub, blocks2, c1)); - auto blocks4 = bb.add(make_arith(arithmetic::div, blocks3, c_tiles)); - auto blocks5 = bb.add(make_arith(arithmetic::add, c1, blocks4)); - auto blocks = bb.add(make_arith(arithmetic::mul, blocks5, c_tiles)); - blocks->name("blocks"); + auto blocks2 = bb.add(make_arith(arithmetic::div, blocks1, c_tiles)); + auto blocks3 = bb.add(make_arith(arithmetic::add, c1, blocks2)); + auto blocks = bb.add(make_arith(arithmetic::mul, blocks3, c_tiles)); + auto bs = bb.add(make_arith(arithmetic::div, loop_trip_count, blocks)); - bs->name("bs"); auto bs_1 = bb.add(make_arith(arithmetic::add, bs, c1)); - bs_1->name("bs_1"); auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); - rem->name("rem"); auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); - auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); - auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); - auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); - bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, - [&](region_builder &bb, value block) { body(bb, block, bs_1); }); + // The following if makes it easy to eliminate the remainder handler in optimization if rem == 0 + // is known at compile time. Without the if, we would need to prove that block_start_1 is + // non-negative to eliminate the for-loop. + auto is_rem_0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); + bb.if_condition(is_rem_0, [&](region_builder &bb) { + auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); + auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); + auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); + bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, + [&](region_builder &bb, value block) { body(bb, block, bs_1); }); + }); auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); diff --git a/src/compiler.cpp b/src/compiler.cpp index b6b47615..8e35d6ea 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -7,6 +7,7 @@ #include "pass/check_ir.hpp" #include "pass/constant_propagation.hpp" #include "pass/convert_to_opencl.hpp" +#include "pass/dead_code_elimination.hpp" #include "pass/dump_cfg.hpp" #include "pass/dump_def_use.hpp" #include "pass/dump_ir.hpp" @@ -86,12 +87,22 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ [&] { // passes run_function_pass(check_ir_pass{}, *prg); + + // We run constant propagation + dead code elimination early to capture dead allocas + // (later on they are maybe "in use" due to the lifetime_stop instruction) + run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + run_function_pass(insert_lifetime_stop_pass{}, *prg); run_function_pass(set_stack_ptr_pass{}, *prg); run_function_pass(insert_barrier_pass{}, *prg); run_function_pass(work_group_size_pass{info}, *prg); - // run_function_pass(lower_linalg_pass{info}, *prg); - // run_function_pass(constant_propagation_pass{info}, *prg); + + run_function_pass(lower_linalg_pass{info}, *prg); + run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + run_function_pass(dump_ir_pass{std::cout}, *prg); + // opencl auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); clir::make_names_unique(ast); diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp new file mode 100644 index 00000000..5f1eabaf --- /dev/null +++ b/src/pass/dead_code_elimination.cpp @@ -0,0 +1,83 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dead_code_elimination.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" + +#include + +#include + +namespace tinytc { + +class dead_code_analysis { + public: + auto operator()(inst_node &in) -> bool; + auto operator()(if_inst &in) -> bool; + auto operator()(loop_inst &in) -> bool; +}; + +auto dead_code_analysis::operator()(inst_node &in) -> bool { + /* Instruction have side effects if either of the following is true + * + * - More than one child region (if, for, foreach, parallel, ...) + * - Instruction does not have results (barrier, GEMM, GER, ...) + * + */ + const bool has_side_effects = in.num_child_regions() > 0 || in.num_results() == 0; + + bool any_result_has_uses = false; + for (auto &res : in.results()) { + any_result_has_uses = any_result_has_uses || res.has_uses(); + } + + return !has_side_effects && !any_result_has_uses; +} + +auto dead_code_analysis::operator()(if_inst &in) -> bool { + constant_inst *cond_const = dyn_cast(in.condition().defining_inst()); + if (cond_const) { + // If-instruction is dead if condition is constant and false + return std::holds_alternative(cond_const->value()) && + std::get(cond_const->value()) == 0; + } + + return false; +} + +auto dead_code_analysis::operator()(loop_inst &in) -> bool { + constant_inst *from_const = dyn_cast(in.from().defining_inst()); + constant_inst *to_const = dyn_cast(in.to().defining_inst()); + if (from_const && to_const) { + // For-instruction is dead if from >= to + return std::holds_alternative(from_const->value()) && + std::holds_alternative(to_const->value()) && + std::get(from_const->value()) >= + std::get(to_const->value()); + } + return false; +} + +void dead_code_elimination_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } + +void dead_code_elimination_pass::run_on_region(region_node ®) { + auto prev_it = reg.end(); + while (prev_it != reg.begin()) { + auto it = --prev_it; + auto is_dead = visit(dead_code_analysis{}, *it); + if (is_dead) { + prev_it = reg.insts().erase(it); + } else { + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + } + } +} + +} // namespace tinytc diff --git a/src/pass/dead_code_elimination.hpp b/src/pass/dead_code_elimination.hpp new file mode 100644 index 00000000..e56d27f8 --- /dev/null +++ b/src/pass/dead_code_elimination.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DEAD_CODE_ELIMINATION_20241007_HPP +#define DEAD_CODE_ELIMINATION_20241007_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class dead_code_elimination_pass { + public: + void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); +}; + +} // namespace tinytc + +#endif // DEAD_CODE_ELIMINATION_20241007_HPP diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 7107037b..81b83a23 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -61,13 +61,11 @@ auto linalg_generator::operator()(ger_inst &g) -> inst { bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { auto nn = bb.add(make_arith(arithmetic::add, block, n, g.loc())); auto b = bb.add(make_load(&g.B(), {nn}, g.loc())); - b->name("b"); tile_loop_by_sgs_new( bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, [&](region_builder &bb, value block, bool, value) { auto mm = bb.add(make_arith(arithmetic::add, block, m_index, g.loc())); auto a = bb.add(make_load(&g.A(), {mm}, g.loc())); - a->name("a"); auto ab = bb.add(make_arith(arithmetic::mul, a, b, g.loc())); bb.add(make_store(ab, &g.C(), {mm, nn}, g.loc())); }); diff --git a/src/passes.def b/src/passes.def index caedd283..ad833076 100644 --- a/src/passes.def +++ b/src/passes.def @@ -3,6 +3,7 @@ FUNCTION_PASS("check-ir", check_ir_pass{}) FUNCTION_PASS("constant-propagation", constant_propagation_pass{}) +FUNCTION_PASS("dead-code-elimination", dead_code_elimination_pass{}) FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) FUNCTION_PASS("dump-def-use", dump_def_use_pass{std::cout}) FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir new file mode 100644 index 00000000..66fec77e --- /dev/null +++ b/test/opt/dead-code-elimination.ir @@ -0,0 +1,54 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt --dead-code-elimination < %s | filecheck %s +func @dead_if(%a: memref) { + %c0 = constant 0 -> i1 + if %c0 { + %c42 = constant 42.0 -> f64 + store %c42, %a[] : memref + } + %c1 = constant 1 -> i1 + if %c1 { + %c43 = constant 43.0 -> f64 + store %c43, %a[] : memref + } +; CHECK-LABEL: func @dead_if({{.*}} +; CHECK-NEXT: %c1 = constant 1 -> i1 +; CHECK-NEXT: if %c1 { +; CHECK-NEXT: %c43{{.*}} +; CHECK-NEXT: store{{.*}} +; CHECK-NEXT: } +} + +func @dead_loop(%a: memref) { + %c2 = constant 2 -> index + for %0=%c2,%c2 { + %c42 = constant 42.0 -> f64 + store %c42, %a[] : memref + } + %c5 = constant 5 -> index + %c6 = constant 6 -> index + for %0=%c5,%c6 { + %c43 = constant 43.0 -> f64 + store %c43, %a[] : memref + } +; CHECK-LABEL: func @dead_loop({{.*}} +; CHECK-NEXT: %c5 = constant 5 -> index +; CHECK-NEXT: %c6 = constant 6 -> index +; CHECK-NEXT: for %0=%c5,%c6 : index { +; CHECK-NEXT: %c43{{.*}} +; CHECK-NEXT: store{{.*}} +; CHECK-NEXT: } +} + +func @unused_alloca(%a: memref) { + %0 = alloca -> memref + %1 = alloca -> memref + %one = constant 1.0 -> f64 + axpby.n %one, %1, %one, %a : f64, memref, f64, memref +; CHECK-LABEL: func @unused_alloca({{.*}} +; CHECK-NEXT: %0 = alloca -> memref +; CHECK-NEXT: %one{{.*}} +; CHECK-NEXT: axpby.n %one, %0{{.*}} +} From 9230f497cedae84e98f139e662c3990308966221 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 7 Oct 2024 18:50:13 +0200 Subject: [PATCH 042/297] Cleanup Signed-off-by: Carsten Uphoff --- src/analysis/alias.cpp | 1 - src/analysis/cfg.cpp | 1 - src/analysis/equal.cpp | 1 - src/codegen_tools.cpp | 4 ++-- src/compiler_context_cache.cpp | 2 ++ src/compiler_context_cache.hpp | 7 +++++-- src/data_type.cpp | 17 ++--------------- src/func.cpp | 3 --- src/inst.cpp | 5 ++++- src/node/data_type_node.cpp | 16 +++++++++------- src/node/data_type_node.hpp | 12 ++++++------ src/node/inst_node.cpp | 3 --- src/node/inst_node.hpp | 5 +++++ src/node/program_node.cpp | 2 -- src/node/program_node.hpp | 3 +-- src/node/region_node.cpp | 3 +++ src/node/region_node.hpp | 2 +- src/node/value_node.hpp | 5 +++-- src/parser/parse_context.cpp | 1 + src/parser/parse_context.hpp | 4 ++++ src/parser/parser_impl.yy | 9 ++++++--- src/pass/check_ir.cpp | 1 - src/pass/constant_propagation.cpp | 9 ++++++++- src/pass/convert_to_opencl.cpp | 5 ++++- src/pass/dead_code_elimination.cpp | 7 ++++--- src/pass/dump_def_use.cpp | 7 +++++++ src/pass/dump_ir.cpp | 6 ++++++ src/pass/dump_ir.hpp | 1 - src/pass/insert_barrier.cpp | 1 - src/pass/insert_lifetime_stop.cpp | 3 ++- src/pass/lower_linalg.cpp | 11 +++++++++++ src/pass/slot_tracker.cpp | 4 +++- src/pass/stack.cpp | 1 - src/pass/work_group_size.cpp | 1 - src/recipe/small_gemm_batched.cpp | 1 + src/region.cpp | 8 ++++++-- src/support/walk.hpp | 3 ++- src/value.cpp | 5 ----- test/opt/dump-def-use.ir | 18 +++++++++--------- test/opt/insert-barrier.ir | 24 ++++++++++++------------ 40 files changed, 129 insertions(+), 93 deletions(-) diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp index a492257d..62c55fbe 100644 --- a/src/analysis/alias.cpp +++ b/src/analysis/alias.cpp @@ -9,7 +9,6 @@ #include "support/casting.hpp" #include "support/visit.hpp" #include "support/walk.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp index 8f64aad7..bd8c657c 100644 --- a/src/analysis/cfg.cpp +++ b/src/analysis/cfg.cpp @@ -5,7 +5,6 @@ #include "node/inst_node.hpp" #include "support/casting.hpp" #include "support/ilist_base.hpp" -#include "tinytc/tinytc.hpp" #include #include diff --git a/src/analysis/equal.cpp b/src/analysis/equal.cpp index a2377bf7..3962a5b4 100644 --- a/src/analysis/equal.cpp +++ b/src/analysis/equal.cpp @@ -3,7 +3,6 @@ #include "analysis/equal.hpp" #include "support/visit.hpp" -#include "tinytc/tinytc.hpp" #include diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index fa82a2e8..3fb396a0 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -3,11 +3,11 @@ #include "codegen_tools.hpp" #include "error.hpp" -#include "node/inst_node.hpp" +#include "node/data_type_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" -#include "support/util.hpp" #include "support/visit.hpp" +#include "tinytc/types.h" #include #include diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp index e415be77..c94bf416 100644 --- a/src/compiler_context_cache.cpp +++ b/src/compiler_context_cache.cpp @@ -2,7 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause #include "compiler_context_cache.hpp" +#include "compiler_context.hpp" #include "support/util.hpp" +#include "tinytc/types.hpp" namespace tinytc { diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp index 78318d38..ca85b8a4 100644 --- a/src/compiler_context_cache.hpp +++ b/src/compiler_context_cache.hpp @@ -6,12 +6,15 @@ #include "node/data_type_node.hpp" #include "support/util.hpp" -#include "tinytc/types.hpp" +#include "tinytc/types.h" #include +#include #include -#include +#include +#include #include +#include namespace std { template <> class hash> { diff --git a/src/data_type.cpp b/src/data_type.cpp index 91866ac0..2af618c5 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "compiler_context.hpp" -#include "compiler_context_cache.hpp" #include "error.hpp" #include "location.hpp" #include "node/data_type_node.hpp" @@ -13,9 +12,6 @@ #include "tinytc/types.hpp" #include -#include -#include -#include using namespace tinytc; @@ -41,17 +37,8 @@ tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_data_type_ } return exception_to_status_code([&] { - auto shape_span = std::span{}; - if (shape != nullptr) { - shape_span = std::span(shape, static_cast(shape_size)); - } - auto stride_span = std::span{}; - if (stride != nullptr) { - stride_span = - std::span(stride, static_cast(stride_size)); - } - - *dt = memref_data_type::get(scalar_ty, std::move(shape_span), std::move(stride_span), + *dt = memref_data_type::get(scalar_ty, array_view{shape, shape_size}, + array_view{stride, stride_size}, enum_cast(addrspace), get_optional(loc)); }); } diff --git a/src/func.cpp b/src/func.cpp index 886d9987..3afc9590 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -4,7 +4,6 @@ #include "error.hpp" #include "location.hpp" #include "node/function_node.hpp" -#include "node/region_node.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" @@ -12,8 +11,6 @@ #include #include #include -#include -#include using namespace tinytc; diff --git a/src/inst.cpp b/src/inst.cpp index 4143a72f..0749735f 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -1,9 +1,13 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "compiler_context.hpp" #include "error.hpp" #include "location.hpp" +#include "node/data_type_node.hpp" #include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" @@ -16,7 +20,6 @@ #include #include #include -#include using namespace tinytc; diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index 9154d99e..f917a8c9 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -9,7 +9,10 @@ #include "tinytc/types.hpp" #include +#include #include +#include +#include #include namespace tinytc { @@ -66,15 +69,15 @@ scalar_type memref_data_type::element_ty() const { return dyn_cast(element_ty_)->ty(); } -auto memref_data_type::get(tinytc_data_type_t element_ty, std::span shape, - std::span stride, address_space addrspace, +auto memref_data_type::get(tinytc_data_type_t element_ty, array_view shape, + array_view stride, address_space addrspace, location const &lc) -> tinytc_data_type_t { auto ctx = element_ty->context(); auto stride_buffer = std::vector{}; if (stride.empty()) { stride_buffer = canonical_stride(shape); - stride = std::span{stride_buffer}; + stride = array_view{stride_buffer}; } auto key = memref_data_type_key(element_ty, shape, stride, addrspace); @@ -87,13 +90,12 @@ auto memref_data_type::get(tinytc_data_type_t element_ty, std::spansecond; } } - auto new_mt = std::unique_ptr(new memref_data_type( - ctx, key.element_ty, std::vector(shape.begin(), shape.end()), - std::vector(stride.begin(), stride.end()), key.addrspace, lc)); + auto new_mt = std::unique_ptr( + new memref_data_type(ctx, key.element_ty, shape, stride, key.addrspace, lc)); return tys.emplace(map_key, new_mt.release())->second; } -auto memref_data_type::canonical_stride(std::span shape) +auto memref_data_type::canonical_stride(array_view shape) -> std::vector { if (shape.empty()) { return {}; diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index c699e27e..0e135c5a 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -4,14 +4,14 @@ #ifndef DATA_TYPE_NODE_20230309_HPP #define DATA_TYPE_NODE_20230309_HPP +#include "compiler_context.hpp" #include "support/type_list.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include -#include -#include #include namespace tinytc { @@ -60,9 +60,9 @@ class group_data_type : public data_type_node { class memref_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } - static auto canonical_stride(std::span shape) -> std::vector; - static auto get(tinytc_data_type_t element_ty, std::span shape, - std::span stride, + static auto canonical_stride(array_view shape) -> std::vector; + static auto get(tinytc_data_type_t element_ty, array_view shape, + array_view stride, address_space addrspace = address_space::global, location const &lc = {}) -> tinytc_data_type_t; @@ -100,7 +100,7 @@ class memref_data_type : public data_type_node { struct memref_data_type_key { tinytc_data_type_t element_ty; - std::span shape, stride; + array_view shape, stride; address_space addrspace; auto hash() -> std::uint64_t; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index f8170dc0..0d2f89c2 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -8,12 +8,9 @@ #include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" -#include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/types.hpp" -#include - #include #include #include diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index b6d618d8..61f8729d 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -4,24 +4,29 @@ #ifndef INST_NODE_20230327_HPP #define INST_NODE_20230327_HPP +#include "compiler_context.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" #include "support/ilist.hpp" +#include "support/ilist_base.hpp" #include "support/type_list.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include #include #include +#include #include #include #include #include +#include namespace tinytc { diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp index aa4c94bf..c2c8cc5e 100644 --- a/src/node/program_node.cpp +++ b/src/node/program_node.cpp @@ -2,8 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "node/program_node.hpp" -#include "node/function_node.hpp" -#include "tinytc/tinytc.h" #include diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 036a0942..7015d5c0 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -5,13 +5,12 @@ #define PROGRAM_NODE_20240208_HPP #include "compiler_context.hpp" -#include "node/function_node.hpp" #include "reference_counted.hpp" #include "support/util.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include +#include #include struct tinytc_prog final : tinytc::reference_counted { diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 0bc6891c..90421c98 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -3,6 +3,9 @@ #include "node/region_node.hpp" #include "node/inst_node.hpp" +#include "tinytc/tinytc.h" + +#include namespace tinytc { diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index b6f23dfb..f142db8a 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -7,12 +7,12 @@ #include "node/value_node.hpp" #include "support/ilist.hpp" #include "support/util.hpp" -#include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" #include +#include #include #include diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 8d40e9ce..bb6348fe 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -4,14 +4,15 @@ #ifndef VALUE_NODE_20230309_HPP #define VALUE_NODE_20230309_HPP -#include "location.hpp" #include "node/data_type_node.hpp" #include "support/util.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" -#include +#include #include #include +#include #include namespace tinytc { diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 153755b5..09c643cf 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -7,6 +7,7 @@ #include "node/value_node.hpp" #include "parser/parser_impl.hpp" +#include #include #include diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index e57855b0..4c511d01 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -4,9 +4,13 @@ #ifndef PARSE_CONTEXT_20231221_HPP #define PARSE_CONTEXT_20231221_HPP +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index e26859f1..97102cfd 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -5,11 +5,13 @@ %language "c++" %code requires { + #include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" + #include "node/value_node.hpp" #include "tinytc/tinytc.hpp" + #include "tinytc/types.h" #include "tinytc/types.hpp" - #include #include #include #include @@ -28,18 +30,19 @@ %code { #include "error.hpp" - #include "node/data_type_node.hpp" #include "node/program_node.hpp" #include "node/region_node.hpp" - #include "node/value_node.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" + #include "support/ilist.hpp" #include "support/util.hpp" #include "support/visit.hpp" #include + #include #include #include + #include #include #include diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index 873d805d..e8100376 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -6,7 +6,6 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "support/walk.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" namespace tinytc { diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index cf88afde..64c34527 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -3,17 +3,24 @@ #include "pass/constant_propagation.hpp" #include "error.hpp" +#include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "node/value_node.hpp" #include "pass/constant_propagation_helper.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" -#include #include +#include +#include +#include #include namespace tinytc { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 3a617337..db1e43f4 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -7,6 +7,8 @@ #include "gemm_generator.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" #include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" @@ -25,11 +27,12 @@ #include #include #include +#include #include #include +#include #include #include -#include #include namespace tinytc { diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp index 5f1eabaf..17775b8e 100644 --- a/src/pass/dead_code_elimination.cpp +++ b/src/pass/dead_code_elimination.cpp @@ -5,14 +5,15 @@ #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "node/value_node.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" #include "support/visit.hpp" -#include "tinytc/tinytc.hpp" +#include #include -#include - namespace tinytc { class dead_code_analysis { diff --git a/src/pass/dump_def_use.cpp b/src/pass/dump_def_use.cpp index efcfad4b..795e7731 100644 --- a/src/pass/dump_def_use.cpp +++ b/src/pass/dump_def_use.cpp @@ -2,11 +2,18 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/dump_def_use.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "pass/dump_ir.hpp" #include "support/util.hpp" #include "support/visit.hpp" #include "support/walk.hpp" +#include +#include +#include + namespace tinytc { dump_def_use_pass::dump_def_use_pass(std::ostream &os) : os_(&os) {} diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 28c0b5f6..498d2a31 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -2,11 +2,17 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/dump_ir.hpp" +#include "support/ilist_base.hpp" +#include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" #include +#include +#include #include +#include #include #include #include diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index 788115da..9d080acc 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -7,7 +7,6 @@ #include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" -#include "node/program_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" #include "pass/slot_tracker.hpp" diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp index 24298181..3070dc31 100644 --- a/src/pass/insert_barrier.cpp +++ b/src/pass/insert_barrier.cpp @@ -12,7 +12,6 @@ #include "support/ilist.hpp" #include "support/util.hpp" #include "support/visit.hpp" -#include "tinytc/tinytc.hpp" #include #include diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp index c27c68e6..c9563f4c 100644 --- a/src/pass/insert_lifetime_stop.cpp +++ b/src/pass/insert_lifetime_stop.cpp @@ -9,7 +9,8 @@ #include "support/casting.hpp" #include "support/ilist.hpp" #include "support/ilist_base.hpp" -#include "tinytc/tinytc.hpp" +#include "support/util.hpp" +#include "tinytc/types.h" #include #include diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 81b83a23..b27b8ba9 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -3,14 +3,25 @@ #include "pass/lower_linalg.hpp" #include "codegen_tools.hpp" +#include "device_info.hpp" #include "error.hpp" +#include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" #include "support/visit.hpp" #include "support/walk.hpp" #include "tiling.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include namespace tinytc { diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp index d174b5d9..685ed47c 100644 --- a/src/pass/slot_tracker.cpp +++ b/src/pass/slot_tracker.cpp @@ -3,10 +3,12 @@ #include "pass/slot_tracker.hpp" #include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "support/util.hpp" #include "support/walk.hpp" -#include "tinytc/tinytc.hpp" #include +#include namespace tinytc { diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index 9c233afe..a8bfdff8 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -9,7 +9,6 @@ #include "support/casting.hpp" #include "support/visit.hpp" #include "support/walk.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp index 7e13d417..00e02bb7 100644 --- a/src/pass/work_group_size.cpp +++ b/src/pass/work_group_size.cpp @@ -11,7 +11,6 @@ #include "support/visit.hpp" #include "support/walk.hpp" #include "tiling.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 4a02263f..8c969ee1 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -12,6 +12,7 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include diff --git a/src/region.cpp b/src/region.cpp index 9f1d9ea6..a5e016a3 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -2,14 +2,18 @@ // SPDX-License-Identifier: BSD-3-Clause #include "error.hpp" -#include "location.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" +#include "node/value_node.hpp" #include "support/ilist.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" -#include +#include +#include +#include +#include +#include using namespace tinytc; diff --git a/src/support/walk.hpp b/src/support/walk.hpp index 57b55220..2d0d38d6 100644 --- a/src/support/walk.hpp +++ b/src/support/walk.hpp @@ -8,9 +8,10 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "support/ilist_base.hpp" -#include "tinytc/tinytc.hpp" #include +#include +#include namespace tinytc { diff --git a/src/value.cpp b/src/value.cpp index b6316a7a..3f87165b 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -2,16 +2,11 @@ // SPDX-License-Identifier: BSD-3-Clause #include "error.hpp" -#include "location.hpp" #include "node/value_node.hpp" -#include "support/util.hpp" #include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "tinytc/types.hpp" #include -#include #include using namespace tinytc; diff --git a/test/opt/dump-def-use.ir b/test/opt/dump-def-use.ir index 8e9dc3e5..d0e88a9c 100644 --- a/test/opt/dump-def-use.ir +++ b/test/opt/dump-def-use.ir @@ -8,14 +8,14 @@ func @foobar() { %lb = constant 0 -> index %ub = constant 5 -> index for %i=%lb,%ub : index { - %1 = arith.add %i, %one : index - %2 = arith.rem %1, %one : index + %0 = arith.add %i, %one : index + %1 = arith.rem %0, %one : index } ; CHECK: Def-use in foobar ; CHECK-NEXT: > %one = constant 1 -> index ; CHECK-NEXT: def %one -; CHECK-NEXT: > %2 = arith.rem %1, %one : index -; CHECK-NEXT: > %1 = arith.add %i, %one : index +; CHECK-NEXT: > %1 = arith.rem %0, %one : index +; CHECK-NEXT: > %0 = arith.add %i, %one : index ; CHECK-NEXT: > %lb = constant 0 -> index ; CHECK-NEXT: def %lb ; CHECK-NEXT: > for %i=%lb,%ub : index {...} @@ -24,10 +24,10 @@ func @foobar() { ; CHECK-NEXT: > for %i=%lb,%ub : index {...} ; CHECK-NEXT: > for %i=%lb,%ub : index {...} ; CHECK-NEXT: def %i -; CHECK-NEXT: > %1 = arith.add %i, %one : index -; CHECK-NEXT: > %1 = arith.add %i, %one : index +; CHECK-NEXT: > %0 = arith.add %i, %one : index +; CHECK-NEXT: > %0 = arith.add %i, %one : index +; CHECK-NEXT: def %0 +; CHECK-NEXT: > %1 = arith.rem %0, %one : index +; CHECK-NEXT: > %1 = arith.rem %0, %one : index ; CHECK-NEXT: def %1 -; CHECK-NEXT: > %2 = arith.rem %1, %one : index -; CHECK-NEXT: > %2 = arith.rem %1, %one : index -; CHECK-NEXT: def %2 } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 31a1b74b..c9f78e8f 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -161,23 +161,23 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %c3 = constant 3 -> i32 %c4 = constant 4 -> i32 parallel { - %1 = subgroup_id - %2 = cmp.eq %1, %c0 : i32 - if %2 { - %3 = load %A[%c3,%c4] : memref - store %3, %A[%c3,%c4] : memref + %0 = subgroup_id + %1 = cmp.eq %1, %c0 : i32 + if %1 { + %2 = load %A[%c3,%c4] : memref + store %2, %A[%c3,%c4] : memref } } - %0 = load %A[%c3,%c4] : memref + %3 = load %A[%c3,%c4] : memref ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { -; CHECK-NEXT: %1 = subgroup_id -; CHECK-NEXT: %2 = cmp.eq %1, %c0 : i32 -; CHECK-NEXT: if %2 { -; CHECK-NEXT: %3 = load %A[%c3,%c4] : memref -; CHECK-NEXT: store %3, %A[%c3,%c4] : memref +; CHECK-NEXT: %0 = subgroup_id +; CHECK-NEXT: %1 = cmp.eq %1, %c0 : i32 +; CHECK-NEXT: if %1 { +; CHECK-NEXT: %2 = load %A[%c3,%c4] : memref +; CHECK-NEXT: store %2, %A[%c3,%c4] : memref ; CHECK-NEXT: } ; CHECK-NEXT: } ; CHECK-NEXT: barrier.global -; CHECK-NEXT: %0 = load %A[%c3,%c4] : memref +; CHECK-NEXT: %3 = load %A[%c3,%c4] : memref } From 9ee19e44cad991aaeb3f5f7a172f33f64a184c28 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 8 Oct 2024 09:12:01 +0200 Subject: [PATCH 043/297] Fix tests Signed-off-by: Carsten Uphoff --- docs/api/core_capi.rst | 7 +++ docs/api/core_capi.yaml | 1 + include/tinytc/tinytc.h | 11 ++++ include/tinytc/tinytc.hpp | 8 +++ src/compiler.cpp | 19 +++--- src/compiler_context.cpp | 9 +++ src/compiler_context.hpp | 4 ++ test/codegen/dope_vector0.ir | 12 ++-- test/codegen/dope_vector_group0.ir | 12 ++-- test/codegen/expand.ir | 98 +++++++++++++++--------------- test/codegen/for.ir | 4 +- test/codegen/fuse.ir | 14 ++--- test/codegen/if.ir | 16 ++--- test/codegen/load.ir | 6 +- test/codegen/scalar_arithmetic.ir | 68 ++++++++++----------- test/codegen/size.ir | 6 +- test/codegen/store.ir | 2 +- test/codegen/subgroup.ir | 4 +- test/opt/insert-barrier.ir | 4 +- tools/offline_compiler/args.cpp | 7 +++ tools/offline_compiler/args.hpp | 2 + tools/offline_compiler/main.cpp | 1 + tools/opt/args.cpp | 1 + 23 files changed, 186 insertions(+), 130 deletions(-) diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 8188e74f..d83e05c0 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -291,6 +291,8 @@ Compiler Context * :ref:`tinytc_compiler_context_set_error_reporter` + * :ref:`tinytc_compiler_context_set_optimization_level` + * :ref:`tinytc_compiler_context_report_error` * :ref:`tinytc_compiler_context_release` @@ -315,6 +317,11 @@ tinytc_compiler_context_set_error_reporter .. doxygenfunction:: tinytc_compiler_context_set_error_reporter +tinytc_compiler_context_set_optimization_level +.............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_level + tinytc_compiler_context_report_error .................................... diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index d921060b..f5a1da3d 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -49,6 +49,7 @@ Core C-API: - tinytc_compiler_context_create - tinytc_compiler_context_add_source - tinytc_compiler_context_set_error_reporter + - tinytc_compiler_context_set_optimization_level - tinytc_compiler_context_report_error - tinytc_compiler_context_release - tinytc_compiler_context_retain diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index fec78fec..f16b731d 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1195,6 +1195,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_error_reporter( tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, void *user_data); +/** + * @brief Set optimization level (from 0 to 2) + * + * @param ctx [inout] context object + * @param level [in] optimization level + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t +tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, int32_t level); + /** * @brief Report an error and augment the error with source context * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index a14c1b02..bbb6ffda 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -494,6 +494,14 @@ class compiler_context : public shared_handle { inline void set_error_reporter(error_reporter_t reporter, void *user_data) { CHECK_STATUS(tinytc_compiler_context_set_error_reporter(obj_, reporter, user_data)); } + /** + * @brief Set optimization level + * + * @param level optimization level + */ + inline void set_optimization_level(std::int32_t level) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_level(obj_, level)); + } /** * @brief Enhance error message with compiler context; useful when builder is used * diff --git a/src/compiler.cpp b/src/compiler.cpp index 8e35d6ea..71b13cfc 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -85,13 +85,17 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ } return exception_to_status_code( [&] { + const auto opt_level = prg->get_context()->opt_level(); + // passes run_function_pass(check_ir_pass{}, *prg); - // We run constant propagation + dead code elimination early to capture dead allocas - // (later on they are maybe "in use" due to the lifetime_stop instruction) - run_function_pass(constant_propagation_pass{}, *prg); - run_function_pass(dead_code_elimination_pass{}, *prg); + if (opt_level >= 1) { + // We run constant propagation + dead code elimination early to capture dead allocas + // (later on they are maybe "in use" due to the lifetime_stop instruction) + run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } run_function_pass(insert_lifetime_stop_pass{}, *prg); run_function_pass(set_stack_ptr_pass{}, *prg); @@ -99,9 +103,10 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ run_function_pass(work_group_size_pass{info}, *prg); run_function_pass(lower_linalg_pass{info}, *prg); - run_function_pass(constant_propagation_pass{}, *prg); - run_function_pass(dead_code_elimination_pass{}, *prg); - run_function_pass(dump_ir_pass{std::cout}, *prg); + if (opt_level >= 1) { + run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } // opencl auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp index be8f54c9..505da8e9 100644 --- a/src/compiler_context.cpp +++ b/src/compiler_context.cpp @@ -71,6 +71,15 @@ tinytc_status_t tinytc_compiler_context_set_error_reporter(tinytc_compiler_conte return exception_to_status_code([&] { ctx->set_error_reporter(reporter, user_data); }); } +tinytc_status_t tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, + int32_t level) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + ctx->opt_level(level); + return tinytc_status_success; +} + tinytc_status_t tinytc_compiler_context_report_error(tinytc_compiler_context_t ctx, const tinytc_location_t *location, char const *what) { diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp index 5a19a258..dd14aa99 100644 --- a/src/compiler_context.hpp +++ b/src/compiler_context.hpp @@ -48,6 +48,9 @@ struct tinytc_compiler_context : tinytc::reference_counted { auto source_text(std::int32_t source_id) -> std::pair; void report_error(tinytc_location const &l, char const *what); + auto opt_level() const noexcept -> std::int32_t { return opt_level_; } + void opt_level(std::int32_t level) noexcept { opt_level_ = level; } + private: struct source_input { std::string name, text; @@ -61,6 +64,7 @@ struct tinytc_compiler_context : tinytc::reference_counted { tinytc::error_reporter_t reporter_ = &tinytc::default_error_reporter; void *user_data_ = nullptr; std::vector sources_; + std::int32_t opt_level_ = 2; }; #endif // COMPILER_CONTEXT_20240924_HPP diff --git a/test/codegen/dope_vector0.ir b/test/codegen/dope_vector0.ir index 67ed61c7..71d6bcae 100644 --- a/test/codegen/dope_vector0.ir +++ b/test/codegen/dope_vector0.ir @@ -1,18 +1,18 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel(%K0: memref, %offset: index, %size: index) { %0 = subview %K0[4:%size, %offset] : memref ; CHECK: void kernel({{.*}} -; CHECK-NEXT: global float* x0 = K0 + 4ll * 1 + offset * K0_stride1; -; CHECK-NEXT: long x0_shape0 = size; +; CHECK-NEXT: global float* x = K0 + 4ll * 1 + offset * K0_stride1; +; CHECK-NEXT: long x_shape0 = size; } func @kernel2(%K0: memref, %offset: index, %size: index) { %0 = subview %K0[%offset, 4:%size] : memref ; CHECK: void kernel2({{.*}} -; CHECK-NEXT: global float* x0 = K0 + offset * 1 + 4ll * K0_stride1; -; CHECK-NEXT: long x0_shape0 = size; -; CHECK-NEXT: long x0_stride0 = K0_stride1; +; CHECK-NEXT: global float* x = K0 + offset * 1 + 4ll * K0_stride1; +; CHECK-NEXT: long x_shape0 = size; +; CHECK-NEXT: long x_stride0 = K0_stride1; } diff --git a/test/codegen/dope_vector_group0.ir b/test/codegen/dope_vector_group0.ir index acdbf875..81886f91 100644 --- a/test/codegen/dope_vector_group0.ir +++ b/test/codegen/dope_vector_group0.ir @@ -1,15 +1,15 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel1(%in: group>) { ; CHECK: void kernel1(global float*global* in, global long* in_shape1, global long* in_stride2) %c5 = constant 5 -> index %0 = load %in[%c5] : group> ; CHECK-NEXT: long c5 = 5ll; - ; CHECK-NEXT: global float* x0 = *(in + c5) + 0; - ; CHECK-NEXT: long x0_shape1 = in_shape1[c5]; - ; CHECK-NEXT: long x0_stride2 = in_stride2[c5]; + ; CHECK-NEXT: global float* x = *(in + c5) + 0; + ; CHECK-NEXT: long x_shape1 = in_shape1[c5]; + ; CHECK-NEXT: long x_stride2 = in_stride2[c5]; } func @kernel2(%in: group, offset: ?>) { @@ -17,6 +17,6 @@ func @kernel2(%in: group, offset: ?>) { %c5 = constant 5 -> index %0 = load %in[%c5] : group, offset: ?> ; CHECK-NEXT: long c5 = 5ll; - ; CHECK-NEXT: global float* x0 = *(in + c5) + in_offset; - ; CHECK-NEXT: long x0_shape0 = in_shape0[c5]; + ; CHECK-NEXT: global float* x = *(in + c5) + in_offset; + ; CHECK-NEXT: long x_shape0 = in_shape0[c5]; } diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir index 217f5e97..188651af 100644 --- a/test/codegen/expand.ir +++ b/test/codegen/expand.ir @@ -1,13 +1,13 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { %z = constant 0 -> index %1 = expand %0[1->2x8] : memref %2 = load %1[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t1( -; CHECK: global float* x1 = x0; +; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 512); } func @t2(%0: memref) { @@ -15,7 +15,7 @@ func @t2(%0: memref) { %1 = expand %0[1->2x2x2x2] : memref %2 = load %1[%z,%z,%z,%z,%z,%z] : memref ; CHECK-LABEL: void t2( -; CHECK: global float* x1 = x0; +; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 128 + z * 256 + z * 512); } func @t3(%0: memref, %1: index) { @@ -23,18 +23,18 @@ func @t3(%0: memref, %1: index) { %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t3( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape11 = x1; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t4(%0: memref, %1: index) { %z = constant 0 -> index %2 = expand %0[1->2 x %1] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t4( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape2 = x1; +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape2 = x1; ; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * 64); } func @t5(%0: memref, %1: index) { @@ -42,73 +42,73 @@ func @t5(%0: memref, %1: index) { %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t5( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape1 = x1; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t6(%0: memref, %1: index) { %z = constant 0 -> index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t6( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x2_stride2); +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape11 = x1; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t7(%0: memref, %1: index, %2: index) { %z = constant 0 -> index %3 = expand %0[1->%1 x %2 x 2] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t7( -; CHECK: global float* x3 = x0; -; CHECK-NEXT: long x3_shape1 = x1; -; CHECK-NEXT: long x3_shape2 = x2; -; CHECK-NEXT: long x3_stride2 = 32 * x1; -; CHECK-NEXT: long x3_stride3 = 32 * x1 * x2; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2 + z * x3_stride3); +; CHECK: global float* x3 = x; +; CHECK-NEXT: long x_shape1 = x1; +; CHECK-NEXT: long x_shape2 = x2; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: long x_stride3 = 32 * x1 * x2; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); } func @t8(%0: memref, %1: index, %2: index) { %z = constant 0 -> index %3 = expand %0[1->%2 x 2 x %1] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t8( -; CHECK: global float* x3 = x0; -; CHECK-NEXT: long x3_shape1 = x2; -; CHECK-NEXT: long x3_stride2 = 32 * x2; -; CHECK-NEXT: long x3_shape3 = x1; -; CHECK-NEXT: long x3_stride3 = 32 * x2 * 2ll; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2 + z * x3_stride3); +; CHECK: global float* x3 = x; +; CHECK-NEXT: long x_shape1 = x2; +; CHECK-NEXT: long x_stride2 = 32 * x2; +; CHECK-NEXT: long x_shape3 = x1; +; CHECK-NEXT: long x_stride3 = 32 * x2 * 2ll; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); } func @t9(%0: memref, %1: index, %2: index) { %z = constant 0 -> index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t9( -; CHECK: global float* x3 = x0; -; CHECK-NEXT: long x3_shape1 = x1; -; CHECK-NEXT: long x3_shape2 = x2; -; CHECK-NEXT: long x3_stride2 = 32 * x1; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2); +; CHECK: global float* x3 = x; +; CHECK-NEXT: long x_shape11 = x1; +; CHECK-NEXT: long x_shape2 = x2; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); } func @t10(%0: memref, %1: index, %2: index) { %z = constant 0 -> index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t10( -; CHECK: global float* x3 = x0; -; CHECK-NEXT: long x3_shape1 = x1; -; CHECK-NEXT: long x3_shape2 = x2; -; CHECK-NEXT: long x3_stride2 = 32 * x1; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x3_stride2); +; CHECK: global float* x3 = x; +; CHECK-NEXT: long x_shape1 = x1; +; CHECK-NEXT: long x_shape2 = x2; +; CHECK-NEXT: long x_stride2 = 32 * x1; +; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); } func @t11(%0: memref>) { %z = constant 0 -> index %1 = expand %0[0->4 x 8] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK-LABEL: void t11( -; CHECK: global float* x1 = x0; +; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 2 + z * 8 + z * 64); } func @t12(%0: memref>, %1: index) { @@ -116,19 +116,19 @@ func @t12(%0: memref>, %1: index) { %2 = expand %0[0->%1 x 4] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t12( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape0 = x1; -; CHECK-NEXT: long x2_stride1 = 2 * x1; -; CHECK-NEXT: long x2_stride2 = x0_stride1; -; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * x2_stride1 + z * x2_stride2); +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape01 = x1; +; CHECK-NEXT: long x_stride11 = 2 * x1; +; CHECK-NEXT: long x_stride2 = x_stride1; +; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * x_stride11 + z * x_stride2); } func @t13(%0: memref>, %1: index) { %z = constant 0 -> index %2 = expand %0[0->4 x %1] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t13( -; CHECK: global float* x2 = x0; -; CHECK-NEXT: long x2_shape1 = x1; -; CHECK-NEXT: long x2_stride2 = x0_stride1; -; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * 8 + z * x2_stride2); +; CHECK: global float* x2 = x; +; CHECK-NEXT: long x_shape1 = x1; +; CHECK-NEXT: long x_stride2 = x_stride1; +; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * 8 + z * x_stride2); } diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 1fda98d9..3b6564f6 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -6,11 +6,11 @@ func @for1() { %lb0 = constant 0 -> index %ub0 = constant 10 -> index for %0 = %lb0,%ub0 { - ; CHECK: for (long x0 = lb0; x0 < ub0; ++x0) + ; CHECK: for (long x = lb0; x < ub0; ++x) } %lb1 = constant -2 -> i16 %ub1 = constant 2 -> i16 for %1 = %lb1,%ub1 : i16 { - ; CHECK: for (short x1 = lb1; x1 < ub1; ++x1) + ; CHECK: for (short x = lb1; x < ub1; ++x) } } diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir index 2873ccbc..a6c4e96e 100644 --- a/test/codegen/fuse.ir +++ b/test/codegen/fuse.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { %z = constant 0 -> index %1 = fuse %0[1,3] : memref @@ -12,9 +12,9 @@ func @t2(%0: memref) { %z = constant 0 -> index %1 = fuse %0[1,3] : memref %2 = load %1[%z,%z,%z] : memref> -; CHECK: long x1_shape1 = 16 * x0_shape2 * 4; -; CHECK-NEXT: long x1_stride2 = x0_stride4; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x1_stride2); +; CHECK: long x_shape1 = 16 * x_shape2 * 4; +; CHECK-NEXT: long x_stride2 = x_stride4; +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x_stride2); } func @t3(%0: memref>) { %z = constant 0 -> index @@ -26,7 +26,7 @@ func @t4(%0: memref>) { %z = constant 0 -> index %1 = fuse %0[0,1] : memref> %2 = load %1[%z,%z] : memref> -; CHECK: long x1_shape0 = 8 * x0_shape1; -; CHECK-NEXT: long x1_stride1 = x0_stride2; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * x1_stride1); +; CHECK: long x_shape0 = 8 * x_shape1; +; CHECK-NEXT: long x_stride11 = x_stride2; +; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * x_stride11); } diff --git a/test/codegen/if.ir b/test/codegen/if.ir index 4accd465..2efba14f 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -11,8 +11,8 @@ func @if0(%0: i32) { if %3 { } else { } -; CHECK: bool x1 = x0 < c16; -; CHECK: bool x2 = x0 >= c0; +; CHECK: bool x1 = x < c16; +; CHECK: bool x2 = x >= c0; ; CHECK: bool x3 = x1 && x2; ; CHECK: if (x3) { ; CHECK-NEXT: } @@ -49,11 +49,11 @@ func @if3(%0: i32) { } else { yield %c16 : i32 } -; CHECK: int x; +; CHECK: int x2; ; CHECK-NEXT: if (x1) { -; CHECK-NEXT: x = x0; +; CHECK-NEXT: x2 = x; ; CHECK-NEXT: } else { -; CHECK-NEXT: x = c16; +; CHECK-NEXT: x2 = c16; ; CHECK-NEXT: } } @@ -75,13 +75,13 @@ func @if4(%0: i32) { } yield %c16, %z : i32, f32 } -; CHECK: int x; +; CHECK: int x2; ; CHECK-NEXT: float y; ; CHECK-NEXT: if (x1) { ; CHECK-NEXT: if (x1) { ; CHECK-NEXT: } ; CHECK-NEXT: float one = 0x1p+0f; -; CHECK-NEXT: x = x0; +; CHECK-NEXT: x2 = x; ; CHECK-NEXT: y = one; ; CHECK-NEXT: } else { ; CHECK-NEXT: float z; @@ -92,7 +92,7 @@ func @if4(%0: i32) { ; CHECK-NEXT: float zero = 0x0p+0f; ; CHECK-NEXT: z = zero; ; CHECK-NEXT: } -; CHECK-NEXT: x = c16; +; CHECK-NEXT: x2 = c16; ; CHECK-NEXT: y = z; ; CHECK-NEXT: } } diff --git a/test/codegen/load.ir b/test/codegen/load.ir index 53288b52..98fdeec2 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -1,14 +1,14 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel1(%a: memref, %b: memref, %c: group>) { %c5 = constant 5 -> index %0 = load %a[] : memref %1 = group_id %2 = load %b[%c5, %1] : memref %3 = load %c[%1] : group> - ; CHECK: float x0 = *a; + ; CHECK: float x = *a; ; CHECK-NEXT: long x1 = get_global_id(2); ; CHECK-NEXT: float x2 = *(b + c5 * 1 + x1 * 10); ; CHECK-NEXT: global float* x3 = *(c + x1) + 0; @@ -17,5 +17,5 @@ func @kernel1(%a: memref, %b: memref, %c: group>) func @kernel2(%c: group, offset: 21>) { %0 = group_id %1 = load %c[%0] : group, offset: 21> - ; CHECK: global float* x1 = *(c + x0) + 21; + ; CHECK: global float* x1 = *(c + x) + 21; } diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir index 4334d329..4a8b7f48 100644 --- a/test/codegen/scalar_arithmetic.ir +++ b/test/codegen/scalar_arithmetic.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { %1 = arith.add %a, %b : i32 %2 = arith.sub %a, %b : i32 @@ -18,21 +18,21 @@ func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { %13 = arith.neg %a : i32 %14 = arith.not %a : i32 %15 = arith.not %a1 : i1 -; CHECK: int x1 = a + b; -; CHECK-NEXT: int x2 = a - b; -; CHECK-NEXT: int x3 = a * b; -; CHECK-NEXT: int x4 = a / b; -; CHECK-NEXT: int x5 = a % b; -; CHECK-NEXT: int x6 = a << b; -; CHECK-NEXT: int x7 = a >> b; -; CHECK-NEXT: int x8 = a & b; -; CHECK-NEXT: bool x9 = a1 && b1; -; CHECK-NEXT: int x10 = a | b; -; CHECK-NEXT: bool x11 = a1 || b1; -; CHECK-NEXT: int x12 = a ^ b; -; CHECK-NEXT: int x13 = -a; -; CHECK-NEXT: int x14 = ~a; -; CHECK-NEXT: bool x15 = !a1; +; CHECK: int x = a + b; +; CHECK-NEXT: int x1 = a - b; +; CHECK-NEXT: int x2 = a * b; +; CHECK-NEXT: int x3 = a / b; +; CHECK-NEXT: int x4 = a % b; +; CHECK-NEXT: int x5 = a << b; +; CHECK-NEXT: int x6 = a >> b; +; CHECK-NEXT: int x7 = a & b; +; CHECK-NEXT: bool x8 = a1 && b1; +; CHECK-NEXT: int x9 = a | b; +; CHECK-NEXT: bool x10 = a1 || b1; +; CHECK-NEXT: int x11 = a ^ b; +; CHECK-NEXT: int x12 = -a; +; CHECK-NEXT: int x13 = ~a; +; CHECK-NEXT: bool x14 = !a1; } func @t2(%a: i32, %b: i32) { %1 = cmp.eq %a, %b : i32 @@ -41,12 +41,12 @@ func @t2(%a: i32, %b: i32) { %4 = cmp.ge %a, %b : i32 %5 = cmp.lt %a, %b : i32 %6 = cmp.le %a, %b : i32 -; CHECK: bool x1 = a == b; -; CHECK-NEXT: bool x2 = a != b; -; CHECK-NEXT: bool x3 = a > b; -; CHECK-NEXT: bool x4 = a >= b; -; CHECK-NEXT: bool x5 = a < b; -; CHECK-NEXT: bool x6 = a <= b; +; CHECK: bool x = a == b; +; CHECK-NEXT: bool x1 = a != b; +; CHECK-NEXT: bool x2 = a > b; +; CHECK-NEXT: bool x3 = a >= b; +; CHECK-NEXT: bool x4 = a < b; +; CHECK-NEXT: bool x5 = a <= b; } func @t3(%a: f32, %b: f32) { %1 = arith.add %a, %b : f32 @@ -55,12 +55,12 @@ func @t3(%a: f32, %b: f32) { %4 = arith.div %a, %b : f32 %5 = arith.rem %a, %b : f32 %6 = arith.neg %a : f32 -; CHECK: float x1 = a + b; -; CHECK-NEXT: float x2 = a - b; -; CHECK-NEXT: float x3 = a * b; -; CHECK-NEXT: float x4 = a / b; -; CHECK-NEXT: float x5 = fmod(a, b); -; CHECK-NEXT: float x6 = -a; +; CHECK: float x = a + b; +; CHECK-NEXT: float x1 = a - b; +; CHECK-NEXT: float x2 = a * b; +; CHECK-NEXT: float x3 = a / b; +; CHECK-NEXT: float x4 = fmod(a, b); +; CHECK-NEXT: float x5 = -a; } func @t4(%a: f32, %b: f32) { %1 = cmp.eq %a, %b : f32 @@ -69,12 +69,12 @@ func @t4(%a: f32, %b: f32) { %4 = cmp.ge %a, %b : f32 %5 = cmp.lt %a, %b : f32 %6 = cmp.le %a, %b : f32 -; CHECK: bool x1 = a == b; -; CHECK-NEXT: bool x2 = a != b; -; CHECK-NEXT: bool x3 = a > b; -; CHECK-NEXT: bool x4 = a >= b; -; CHECK-NEXT: bool x5 = a < b; -; CHECK-NEXT: bool x6 = a <= b; +; CHECK: bool x = a == b; +; CHECK-NEXT: bool x1 = a != b; +; CHECK-NEXT: bool x2 = a > b; +; CHECK-NEXT: bool x3 = a >= b; +; CHECK-NEXT: bool x4 = a < b; +; CHECK-NEXT: bool x5 = a <= b; } func @t5(%a: i32) { %b = cast %a : i32 -> index diff --git a/test/codegen/size.ir b/test/codegen/size.ir index 4eee84d6..c8c8d01d 100644 --- a/test/codegen/size.ir +++ b/test/codegen/size.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { %1 = size %0[0] : memref %2 = size %0[1] : memref @@ -11,6 +11,6 @@ func @t1(%0: memref) { func @t2(%0: memref) { %1 = size %0[0] : memref %2 = size %0[1] : memref -; CHECK: long x1 = x0_shape0; -; CHECK-NEXT: long x2 = x0_shape1; +; CHECK: long x1 = x_shape0; +; CHECK-NEXT: long x2 = x_shape1; } diff --git a/test/codegen/store.ir b/test/codegen/store.ir index cca632c5..82161a12 100644 --- a/test/codegen/store.ir +++ b/test/codegen/store.ir @@ -8,5 +8,5 @@ func @kernel(%a: memref, %b: memref, %c: f32) { store %c, %a[] : memref store %c, %b[%c5, %1] : memref ; CHECK: *a = c; - ; CHECK-NEXT: *(b + c5 * 1 + x1 * 10) = c; + ; CHECK-NEXT: *(b + c5 * 1 + x * 10) = c; } diff --git a/test/codegen/subgroup.ir b/test/codegen/subgroup.ir index a834a55b..d6094a0a 100644 --- a/test/codegen/subgroup.ir +++ b/test/codegen/subgroup.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1() { parallel { %0 = num_subgroups @@ -9,7 +9,7 @@ func @t1() { %2 = subgroup_local_id %3 = subgroup_size } -; CHECK: int x0 = get_num_sub_groups(); +; CHECK: int x = get_num_sub_groups(); ; CHECK-NEXT: int x1 = get_sub_group_id(); ; CHECK-NEXT: int x2 = get_sub_group_local_id(); ; CHECK-NEXT: int x3 = get_sub_group_size(); diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index c9f78e8f..c6711a9a 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -162,7 +162,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %c4 = constant 4 -> i32 parallel { %0 = subgroup_id - %1 = cmp.eq %1, %c0 : i32 + %1 = cmp.eq %0, %c0 : i32 if %1 { %2 = load %A[%c3,%c4] : memref store %2, %A[%c3,%c4] : memref @@ -172,7 +172,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { ; CHECK-NEXT: %0 = subgroup_id -; CHECK-NEXT: %1 = cmp.eq %1, %c0 : i32 +; CHECK-NEXT: %1 = cmp.eq %0, %c0 : i32 ; CHECK-NEXT: if %1 { ; CHECK-NEXT: %2 = load %A[%c3,%c4] : memref ; CHECK-NEXT: store %2, %A[%c3,%c4] : memref diff --git a/tools/offline_compiler/args.cpp b/tools/offline_compiler/args.cpp index 0f580656..9848a043 100644 --- a/tools/offline_compiler/args.cpp +++ b/tools/offline_compiler/args.cpp @@ -24,6 +24,7 @@ auto make_core_info_from_string(char const *name) -> core_info { args arg_parser::parse_args(int argc, char **argv) { args a = {}; a.filename = nullptr; + a.opt_level = 2; int npos = 0; for (int i = 1; i < argc; ++i) { @@ -34,6 +35,12 @@ args arg_parser::parse_args(int argc, char **argv) { }; if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { a.help = true; + } else if (std::strcmp(argv[i], "-O0") == 0) { + a.opt_level = 0; + } else if (std::strcmp(argv[i], "-O1") == 0) { + a.opt_level = 1; + } else if (std::strcmp(argv[i], "-O2") == 0) { + a.opt_level = 2; } else if (i + 1 < argc) { if (std::strcmp(argv[i], "-d") == 0 || std::strcmp(argv[i], "--device") == 0) { a.info = make_core_info_from_string(argv[++i]); diff --git a/tools/offline_compiler/args.hpp b/tools/offline_compiler/args.hpp index 6d23aa61..4e135ace 100644 --- a/tools/offline_compiler/args.hpp +++ b/tools/offline_compiler/args.hpp @@ -6,12 +6,14 @@ #include "tinytc/tinytc.hpp" +#include #include struct args { char const *filename; tinytc::core_info info; bool help; + std::int32_t opt_level; }; class arg_parser { diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 7662a0f4..8fc1f71e 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -32,6 +32,7 @@ int main(int argc, char **argv) { auto ctx = compiler_context{}; try { ctx = make_compiler_context(); + ctx.set_optimization_level(a.opt_level); auto p = prog{}; if (!a.filename) { p = parse_stdin(ctx); diff --git a/tools/opt/args.cpp b/tools/opt/args.cpp index 65cbb4ac..b0358946 100644 --- a/tools/opt/args.cpp +++ b/tools/opt/args.cpp @@ -91,6 +91,7 @@ positional arguments: file-name Path to source code; leave empty to read from stdin optional arguments: + -O0,-O1,-O2 Optimization level, default is -O2 -d, --device Device name (cf. intel_gpu_architecture enum), default is "pvc" -h, --help Show help text and exit From d5bbc5876f25d88639d948a082cbd7aa1de22809 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 8 Oct 2024 11:47:20 +0200 Subject: [PATCH 044/297] Removed unused code Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 92 +++++-------------------------------------- src/codegen_tools.hpp | 11 ------ 2 files changed, 9 insertions(+), 94 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 3fb396a0..5c003978 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -433,39 +433,6 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, value sg_id, sgs_loop_body_builder_new const &body) { - tile_loop_by_sgs_new_dynamic(bb, std::move(loop_trip_count), sgs, num_tiles, std::move(sg_id), - body); -} - -void tile_loop_by_sgs_new_constant(region_builder &bb, std::int64_t loop_trip_count, int sgs, - int num_tiles, value sg_id, - sgs_loop_body_builder_new const &body) { - auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); - std::int64_t blocks = loop_trip_count / sgs; - std::int64_t rem = loop_trip_count % sgs; - - auto c_sgs = bb.add(make_constant(sgs, index_ty)); - auto c_sgs_blocks = bb.add(make_constant(sgs * blocks, index_ty)); - auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); - auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); - auto c_rem = bb.add(make_constant(rem, index_ty)); - - auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); - if (blocks > 0) { - auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); - bb.for_loop(std::move(block_start), c_sgs_blocks, c_sgs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); - } - - if (rem > 0) { - auto condition = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); - bb.if_condition(condition, - [&](region_builder &bb) { body(bb, c_sgs_blocks, true, c_rem); }); - } -} - -void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, - value sg_id, sgs_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); auto c_sgs = bb.add(make_constant(sgs, index_ty)); auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); @@ -476,10 +443,13 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, c_sgs)); auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); - auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); - auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); - bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); + auto is_blocks_gt_0 = bb.add(make_cmp(cmp_condition::gt, blocks, c0)); + bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { + auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); + auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); + bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); + }); auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { @@ -494,50 +464,6 @@ void tile_loop_by_sgs_new_dynamic(region_builder &bb, value loop_trip_count, int void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { - tile_loop_uniformly_new_dynamic(bb, std::move(loop_trip_count), block_size, num_tiles, - std::move(sg_id), body); -} - -void tile_loop_uniformly_new_constant(region_builder &bb, std::int64_t loop_trip_count, - int block_size, int num_tiles, value sg_id, - uniform_loop_body_builder_new const &body) { - auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); - // Find minimum number of blocks such that the block sizes are smaller or equal block_size - std::int64_t blocks = 1 + (loop_trip_count - 1) / block_size; - // Increase the number of blocks if such that the number of blocks is a multiple - // of the number of tiles - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - std::int64_t bs = loop_trip_count / blocks; - std::int64_t bs_1 = bs + 1; - std::int64_t rem = loop_trip_count % blocks; - - auto c_bs = bb.add(make_constant(bs, index_ty)); - auto c_bs_tiles = bb.add(make_constant(bs * num_tiles, index_ty)); - auto c_bs_1 = bb.add(make_constant(bs_1, index_ty)); - auto c_bs_1_rem = bb.add(make_constant(bs_1 * rem, index_ty)); - auto c_bs_1_tiles = bb.add(make_constant(bs_1 * num_tiles, index_ty)); - auto c_rem_mod_tiles = bb.add(make_constant(rem % num_tiles, index_ty)); - auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); - auto c_loop_trip_count = bb.add(make_constant(loop_trip_count, index_ty)); - - auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); - if (rem > 0) { - auto block_start = bb.add(make_arith(arithmetic::mul, c_bs_1, sg_id_index)); - bb.for_loop(std::move(block_start), c_bs_1_rem, c_bs_1_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, c_bs_1); }); - } - - auto tmp = bb.add(make_arith(arithmetic::add, sg_id_index, c_rem_mod_tiles)); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp, c_tiles)); - auto tmp2 = bb.add(make_arith(arithmetic::mul, c_bs, sg_id_1)); - auto block_start = bb.add(make_arith(arithmetic::add, c_bs_1_rem, tmp2)); - bb.for_loop(std::move(block_start), c_loop_trip_count, c_bs_tiles, index_ty, - [&](region_builder &bb, value block) { body(bb, block, c_bs); }); -} - -void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, int block_size, - int num_tiles, value sg_id, - uniform_loop_body_builder_new const &body) { auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); auto c0 = bb.add(make_constant(0, index_ty)); auto c1 = bb.add(make_constant(1, index_ty)); @@ -561,8 +487,8 @@ void tile_loop_uniformly_new_dynamic(region_builder &bb, value loop_trip_count, // The following if makes it easy to eliminate the remainder handler in optimization if rem == 0 // is known at compile time. Without the if, we would need to prove that block_start_1 is // non-negative to eliminate the for-loop. - auto is_rem_0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); - bb.if_condition(is_rem_0, [&](region_builder &bb) { + auto is_rem_gt_0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); + bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 2241d7cc..15792ddf 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -127,20 +127,9 @@ using uniform_loop_body_builder_new = std::function Date: Tue, 8 Oct 2024 11:49:12 +0200 Subject: [PATCH 045/297] Review and update casting rules Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 16 +++++++++ include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/error.cpp | 2 ++ src/node/inst_node.cpp | 12 +++++-- src/pass/constant_propagation_helper.hpp | 12 +++++-- src/pass/convert_to_opencl.cpp | 21 ++++++++++- test/codegen/cast.ir | 46 ++++++++++++++++++++++++ test/opt/check-ir/cast_forbidden.ir | 9 +++++ test/opt/constant-propagation.ir | 30 +++++++++------- 10 files changed, 131 insertions(+), 19 deletions(-) create mode 100644 test/codegen/cast.ir create mode 100644 test/opt/check-ir/cast_forbidden.ir diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index df183779..bfd7d2bd 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -645,6 +645,22 @@ Overview ~~~~~~~~ Cast scalar values. +Casts from complex types to non-complex types are forbidden. +The following table summarizes the casts and the mapping to SPIR-V: + +============= ============= ================================================== +Operand type Result type SPIR-V Op +============= ============= ================================================== +integer-type integer-type OpSConvert +floating-type floating-type OpFConvert +complex-type complex-type OpFConvert (on vector2) +integer-type floating-type OpConvertSToF +floating-type integer-type OpConvertFToS +floating-type complex-type OpFConvert on real part, imaginary part is zero +integer-type complex-type OpConvertSToF on real part, imaginary part is zero +complex-type integer-type Forbidden +complex-type floating-type Forbidden +============= ============= ================================================== Comparison .......... diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 046b03d9..080b98ac 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -67,6 +67,7 @@ typedef enum { tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset tinytc_status_ir_i1_unsupported = 0x117, ///< Instruction does not support i1 type tinytc_status_ir_complex_unsupported = 0x118, ///< Instruction does not support complex type + tinytc_status_ir_forbidden_cast = 0x119, ///< Forbidden cast // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index d1fba78d..bf7b757b 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -77,6 +77,7 @@ enum class status { ir_invalid_offset = tinytc_status_ir_invalid_offset, ir_i1_unsupported = tinytc_status_ir_i1_unsupported, ir_complex_unsupported = tinytc_status_ir_complex_unsupported, + ir_forbidden_cast = tinytc_status_ir_forbidden_cast, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index 819301b8..c3d95cef 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -169,6 +169,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "i1 type unsupported by instruction"; case tinytc_status_ir_complex_unsupported: return "complex type unsupported by instruction"; + case tinytc_status_ir_forbidden_cast: + return "Forbidden cast"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 0d2f89c2..1b096be3 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -194,15 +194,21 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 result(0) = value_node{at, this, lc}; } -cast_inst::cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc) +cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const &lc) : standard_inst{IK::cast} { - op(op_a, a); + op(op_a, a0); loc(lc); - if (!isa(*to_ty)) { + auto rt = dyn_cast(to_ty); + if (rt == nullptr) { throw compilation_error(lc, status::ir_expected_scalar); } + auto at = get_scalar_type(loc(), a()); + if (is_complex_type(at->ty()) && !is_complex_type(rt->ty())) { + throw compilation_error(lc, status::ir_forbidden_cast); + } + result(0) = value_node{to_ty, this, lc}; } diff --git a/src/pass/constant_propagation_helper.hpp b/src/pass/constant_propagation_helper.hpp index c215830c..8373423b 100644 --- a/src/pass/constant_propagation_helper.hpp +++ b/src/pass/constant_propagation_helper.hpp @@ -201,16 +201,24 @@ template struct value_cast_impl { auto operator()(U const &u) { return static_cast(u); } }; +template struct value_cast_impl, U> { + auto operator()(U const &u) { return std::complex{static_cast(u), static_cast(0)}; } +}; + +template struct value_cast_impl, std::complex> { + auto operator()(std::complex const &u) { return static_cast>(u); } +}; + template struct value_cast_impl { auto operator()(U const &u) { return u != U{}; } }; template struct value_cast_impl> { - auto operator()(std::complex const &u) { return u != std::complex{}; } + auto operator()(std::complex const &) -> bool { throw status::ir_forbidden_cast; } }; template struct value_cast_impl> { - auto operator()(std::complex const &u) { return static_cast(u.real()); } + auto operator()(std::complex const &) -> T { throw status::ir_forbidden_cast; } }; template auto value_cast(U const &u) { return value_cast_impl{}(u); } diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index db1e43f4..20fb2e41 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -382,8 +382,27 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { auto v = declare(*c.result()); + auto aty = get_scalar_type(c.a().ty()); + auto rty = get_scalar_type(c.result(0).ty()); + auto av = val(c.a()); + auto cst = clir::expr{}; auto result_ty = visit(*this, *c.result()->ty()); - auto cst = cast(result_ty, val(c.a())); + if (is_complex_type(aty) && is_complex_type(rty)) { + switch (rty) { + case scalar_type::c32: + cst = clir::call("convert_float2", {std::move(av)}); + break; + case scalar_type::c64: + cst = clir::call("convert_double2", {std::move(av)}); + break; + default: + throw status::internal_compiler_error; + } + } else if (is_complex_type(rty)) { + cst = clir::init_vector(result_ty, {std::move(av), 0}); + } else { + cst = cast(result_ty, std::move(av)); + } return {declaration_assignment(std::move(result_ty), std::move(v), std::move(cst))}; } diff --git a/test/codegen/cast.ir b/test/codegen/cast.ir new file mode 100644 index 00000000..d417eb8f --- /dev/null +++ b/test/codegen/cast.ir @@ -0,0 +1,46 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @cast_ii() { + %0 = constant 2 -> index + %1 = cast %0 : index -> i32 +; CHECK-LABEL: void cast_ii() { +; CHECK: int x1 = (int) x; +} +func @cast_ff() { + %0 = constant 2.0 -> f32 + %1 = cast %0 : f32 -> f64 +; CHECK-LABEL: void cast_ff() { +; CHECK: double x1 = (double) x; +} +func @cast_cc() { + %0 = constant [2.0, 0.0] -> c32 + %1 = cast %0 : c32 -> c64 +; CHECK-LABEL: void cast_cc() { +; CHECK: double2 x1 = convert_double2(x); +} +func @cast_if() { + %0 = constant 2 -> i32 + %1 = cast %0 : i32 -> f32 +; CHECK-LABEL: void cast_if() { +; CHECK: float x1 = (float) x; +} +func @cast_fi() { + %0 = constant 2.0 -> f32 + %1 = cast %0 : f32 -> i16 +; CHECK-LABEL: void cast_fi() { +; CHECK: short x1 = (short) x; +} +func @cast_ic() { + %0 = constant 2 -> i8 + %1 = cast %0 : i8 -> c32 +; CHECK-LABEL: void cast_ic() { +; CHECK: float2 x1 = (float2) (x, 0); +} +func @cast_fc() { + %0 = constant 2.0 -> f64 + %1 = cast %0 : f64 -> c32 +; CHECK-LABEL: void cast_fc() { +; CHECK: float2 x1 = (float2) (x, 0); +} diff --git a/test/opt/check-ir/cast_forbidden.ir b/test/opt/check-ir/cast_forbidden.ir new file mode 100644 index 00000000..f9665793 --- /dev/null +++ b/test/opt/check-ir/cast_forbidden.ir @@ -0,0 +1,9 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +func @cast_cf() { + %0 = constant [2.0, 1.0] -> c32 + %1 = cast %0 : c32 -> i32 +; CHECK: :7.8-27: Forbidden cast +} diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 3e862bf9..7c133e04 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -51,20 +51,24 @@ func @known_arith() { } func @known_cast() { - %0 = constant 32768 -> i32 - %1 = cast %0 : i32 -> i16 - %2 = cast %0 : i32 -> f32 - %3 = cast %0 : i32 -> c32 - %4 = cast %0 : i32 -> i1 - %5 = cast %3 : c32 -> i1 - %6 = cast %5 : i1 -> c32 + %c0 = constant 32768 -> i32 + %c1 = constant [3.0, -2.0] -> c32 + %0 = cast %c0 : i32 -> i16 + %1 = cast %c0 : i32 -> f32 + %2 = cast %c0 : i32 -> c32 + %3 = cast %c0 : i32 -> i1 + %4 = cast %c0 : i32 -> c32 + %5 = cast %c1 : c32 -> c64 + %6 = cast %3 : i1 -> c32 ; CHECK-LABEL: func @known_cast({{.*}} -; CHECK: %0 = constant 32768 -> i32 -; CHECK-NEXT: %1 = constant -32768 -> i16 -; CHECK-NEXT: %2 = constant 0x1p+15 -> f32 -; CHECK-NEXT: %3 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %4 = constant 1 -> i1 -; CHECK-NEXT: %5 = constant 1 -> i1 +; CHECK: %c0 = constant 32768 -> i32 +; CHECK: %c1 = constant [0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %0 = constant -32768 -> i16 +; CHECK-NEXT: %1 = constant 0x1p+15 -> f32 +; CHECK-NEXT: %2 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %3 = constant 1 -> i1 +; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %5 = constant [0x1.8p+1,-0x1p+1] -> c64 ; CHECK-NEXT: %6 = constant [0x1p+0,0x0p+0] -> c32 } From 5a02f00a440b8235e017a1682c1a43549675f210 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 8 Oct 2024 14:54:01 +0200 Subject: [PATCH 046/297] Add unary instructions for complex Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 23 ++++--- include/tinytc/types.h | 15 +++-- include/tinytc/types.hpp | 9 ++- src/codegen_tools.cpp | 11 ++++ src/codegen_tools.hpp | 1 + src/error.cpp | 2 + src/node/inst_node.cpp | 30 +++++++-- src/parser/lexer.re | 4 ++ src/pass/constant_propagation_helper.hpp | 81 +++++++++++++++++++++--- src/pass/convert_to_opencl.cpp | 17 ++++- test/codegen/scalar_arithmetic.ir | 24 +++++++ test/opt/constant-propagation.ir | 28 ++++++++ 12 files changed, 216 insertions(+), 29 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index bfd7d2bd..8d24ef50 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -588,23 +588,30 @@ Arithmetic (unary) .. code:: abnf - arith-unary-type = ".neg" / ".not" + arith-unary-type = ".abs" / ".neg" / ".not" / ".conj" / ".im" / ".re" value-instruction =/ "arith" arith-unary-type local-identifier ":" scalar-type Overview ~~~~~~~~ Unary arithmetic operation on scalars. -The returned value has the same type as the operand. +For integer and floating point input, the returned value has the same type as the operand. +For complex input, the returned value has the underlying floating point type +for ".abs", ".im", and ".re", and the returned value has the same type as the operand +for ".neg" and ".conj". The following table shows the operations' description and the types that are allowed for the operation. -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.neg scalar-type Negation -.not integer-type Bitwise not -==== ============ ============================================================================== +===== ============ ============================================================================== +Op Allowed type Description +===== ============ ============================================================================== +.abs scalar-type Compute absolute value +.neg scalar-type Negation +.not integer-type Bitwise not +.conj complex-type Complex conjugate +.im complex-type Extract imaginary part +.re complex-type Extract real part +===== ============ ============================================================================== Barrier ....... diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 080b98ac..bd1a934d 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -65,9 +65,10 @@ typedef enum { tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset - tinytc_status_ir_i1_unsupported = 0x117, ///< Instruction does not support i1 type - tinytc_status_ir_complex_unsupported = 0x118, ///< Instruction does not support complex type - tinytc_status_ir_forbidden_cast = 0x119, ///< Forbidden cast + tinytc_status_ir_int_unsupported = 0x117, ///< Instruction does not support int type + tinytc_status_ir_i1_unsupported = 0x118, ///< Instruction does not support i1 type + tinytc_status_ir_complex_unsupported = 0x119, ///< Instruction does not support complex type + tinytc_status_ir_forbidden_cast = 0x11a, ///< Forbidden cast // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST @@ -250,8 +251,12 @@ typedef enum { //! Arithmetic operations (unary) typedef enum { - tinytc_arithmetic_unary_neg = 0, ///< negation - tinytc_arithmetic_unary_not = 1 ///< bitwise not + tinytc_arithmetic_unary_neg = 0, ///< negation + tinytc_arithmetic_unary_not = 1, ///< bitwise not + tinytc_arithmetic_unary_abs = 2, ///< absolute value + tinytc_arithmetic_unary_conj = 3, ///< complex conjugate + tinytc_arithmetic_unary_im = 4, ///< imaginary part + tinytc_arithmetic_unary_re = 5 ///< real part } tinytc_arithmetic_unary_t; //! Compare operation diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index bf7b757b..3550982c 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -75,6 +75,7 @@ enum class status { ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, ir_invalid_offset = tinytc_status_ir_invalid_offset, + ir_int_unsupported = tinytc_status_ir_int_unsupported, ir_i1_unsupported = tinytc_status_ir_i1_unsupported, ir_complex_unsupported = tinytc_status_ir_complex_unsupported, ir_forbidden_cast = tinytc_status_ir_forbidden_cast, @@ -233,8 +234,12 @@ enum class arithmetic { //! Arithmetic operations (unary) enum class arithmetic_unary { - neg = tinytc_arithmetic_unary_neg, ///< negation - not_ = tinytc_arithmetic_unary_not ///< bitwise not + neg = tinytc_arithmetic_unary_neg, ///< negation + not_ = tinytc_arithmetic_unary_not, ///< bitwise not + abs = tinytc_arithmetic_unary_abs, ///< absolute value + conj = tinytc_arithmetic_unary_conj, ///< complex conjugate + im = tinytc_arithmetic_unary_im, ///< imaginary part + re = tinytc_arithmetic_unary_re ///< real part }; //! Compare operation diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 5c003978..6f4801f6 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -39,6 +39,17 @@ expr multiply(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { } return a * b; } +expr divide(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { + if (is_complex_type(ty_a) && is_complex_type(ty_b)) { + return (a * b.s(0) - init_vector(to_clir_ty(ty_a), {-a.s(1), a.s(0)}) * b.s(1)) / + (b.s(0) * b.s(0) + b.s(1) * b.s(1)); + } + if (is_complex_type(ty_b)) { + return a * init_vector(to_clir_ty(ty_b), {b.s(0), -b.s(1)}) / + (b.s(0) * b.s(0) + b.s(1) * b.s(1)); + } + return a / b; +} expr vload_helper(short vec_size, expr offset, expr ptr) { switch (vec_size) { diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 15792ddf..1ce6d989 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -25,6 +25,7 @@ short bits(scalar_type ty); clir::expr constant(scalar_type ty, std::int64_t value); clir::expr constant(scalar_type ty, double value); clir::expr multiply(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); +clir::expr divide(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, scalar_type ty, diff --git a/src/error.cpp b/src/error.cpp index c3d95cef..0c760c5e 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -165,6 +165,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "A memref with global address space is expected"; case tinytc_status_ir_invalid_offset: return "Offset must be non-negative or dynamic"; + case tinytc_status_ir_int_unsupported: + return "int type unsupported by instruction"; case tinytc_status_ir_i1_unsupported: return "i1 type unsupported by instruction"; case tinytc_status_ir_complex_unsupported: diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 1b096be3..4ac8de8d 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -174,24 +174,46 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 op(op_a, a0); loc(lc); - auto at = get_scalar_type(loc(), a()); + auto a_ty = get_scalar_type(loc(), a()); + tinytc_data_type_t to_ty = a_ty; + + bool inst_supports_int = true; bool inst_supports_fp = true; bool inst_supports_complex = true; switch (operation) { + case arithmetic_unary::abs: case arithmetic_unary::neg: break; case arithmetic_unary::not_: inst_supports_fp = false; inst_supports_complex = false; break; + case arithmetic_unary::conj: + case arithmetic_unary::im: + case arithmetic_unary::re: + inst_supports_int = false; + inst_supports_fp = false; + break; } - if (!inst_supports_fp && is_floating_type(at->ty())) { + if (!inst_supports_int && is_integer_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_int_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty->ty())) { throw compilation_error(loc(), status::ir_fp_unsupported); } - if (!inst_supports_complex && is_complex_type(at->ty())) { + if (!inst_supports_complex && is_complex_type(a_ty->ty())) { throw compilation_error(loc(), status::ir_complex_unsupported); } - result(0) = value_node{at, this, lc}; + switch (operation) { + case arithmetic_unary::abs: + case arithmetic_unary::im: + case arithmetic_unary::re: + to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); + break; + default: + break; + } + result(0) = value_node{to_ty, this, lc}; } cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const &lc) diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 38d96dd1..d2056cbd 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -176,8 +176,12 @@ lex: ".xor" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::xor_, loc_); } // unary op + ".abs" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::abs, loc_); } ".neg" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::neg, loc_); } ".not" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::not_, loc_); } + ".conj" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::conj, loc_); } + ".im" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::im, loc_); } + ".re" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::re, loc_); } // comparison condition ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, diff --git a/src/pass/constant_propagation_helper.hpp b/src/pass/constant_propagation_helper.hpp index 8373423b..04cc552b 100644 --- a/src/pass/constant_propagation_helper.hpp +++ b/src/pass/constant_propagation_helper.hpp @@ -5,13 +5,21 @@ #define CONSTANT_PROPAGATION_HELPER_20241002_HPP #include "scalar_type.hpp" +#include "support/casting.hpp" #include "tinytc/tinytc.hpp" +#include #include #include namespace tinytc { +template struct is_complex : public std::false_type {}; +template +requires(std::is_floating_point_v) +struct is_complex> : public std::true_type {}; +template inline constexpr bool is_complex_v = is_complex::value; + struct compute_unary_op { arithmetic_unary operation; data_type ty; @@ -22,6 +30,9 @@ struct compute_unary_op { auto operator()(T a) { T val = 0; switch (operation) { + case arithmetic_unary::abs: + val = a < 0 ? -a : a; + break; case arithmetic_unary::neg: val = -a; break; @@ -32,28 +43,80 @@ struct compute_unary_op { val = ~a; } break; + default: + throw compilation_error(loc, status::ir_int_unsupported); } return make_constant(val, ty, loc); } - template - requires(!std::is_integral_v) - auto operator()(U const &A) -> inst { - const auto a = static_cast(A); - T val = {}; + template + requires(std::is_floating_point_v) + auto operator()(T a) -> inst { + T val = 0; switch (operation) { + case arithmetic_unary::abs: + val = a < T{0} ? -a : a; + break; case arithmetic_unary::neg: val = -a; break; default: - if constexpr (!std::is_floating_point_v) { - throw compilation_error(loc, status::ir_complex_unsupported); - } throw compilation_error(loc, status::ir_fp_unsupported); - break; } return make_constant(val, ty, loc); } + + template + requires(is_complex_v) + auto operator()(U const &A) -> inst { + const auto neg_conj = [&](T const &a) { + T val = {}; + switch (operation) { + case arithmetic_unary::neg: + val = -a; + break; + case arithmetic_unary::conj: + val = std::conj(a); + break; + default: + return inst{nullptr}; + } + return make_constant(val, ty, loc); + }; + const auto abs_im_re = [&](T const &a) -> inst { + typename T::value_type val = {}; + switch (operation) { + case arithmetic_unary::abs: + val = std::abs(a); + break; + case arithmetic_unary::im: + val = std::imag(a); + break; + case arithmetic_unary::re: + val = std::real(a); + break; + default: + return inst{nullptr}; + } + scalar_data_type *sty = dyn_cast(ty); + if (!sty) { + throw compilation_error(loc, status::ir_expected_scalar); + } + auto cst_ty = scalar_data_type::get(sty->context(), element_type(sty->ty())); + return make_constant(val, cst_ty, loc); + }; + + const auto a = static_cast(A); + auto result = neg_conj(a); + if (result) { + return result; + } + result = abs_im_re(a); + if (result) { + return result; + } + throw compilation_error(loc, status::ir_complex_unsupported); + } }; struct compute_binary_op { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 20fb2e41..1c884e52 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -330,7 +330,7 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) case arithmetic::mul: return multiply(sty, sty, std::move(a), std::move(b)); case arithmetic::div: - return std::move(a) / std::move(b); + return divide(sty, sty, std::move(a), std::move(b)); case arithmetic::rem: if (is_floating_type(sty)) { return clir::fmod(std::move(a), std::move(b)); @@ -364,6 +364,15 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { auto const make = [](arithmetic_unary op, clir::expr a, scalar_type sty) -> clir::expr { switch (op) { + case arithmetic_unary::abs: + if (is_complex_type(sty)) { + return clir::call_builtin(clir::builtin_function::sqrt, + {a.s(0) * a.s(0) + a.s(1) * a.s(1)}); + } + if (is_floating_type(sty)) { + return clir::call_builtin(clir::builtin_function::fabs, {std::move(a)}); + } + return clir::call_builtin(clir::builtin_function::abs, {std::move(a)}); case arithmetic_unary::neg: return -std::move(a); case arithmetic_unary::not_: @@ -371,6 +380,12 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons return !std::move(a); } return ~std::move(a); + case arithmetic_unary::conj: + return clir::init_vector(to_clir_ty(sty), {a.s(0), -a.s(1)}); + case arithmetic_unary::im: + return std::move(a).s(1); + case arithmetic_unary::re: + return std::move(a).s(0); } return {}; }; diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir index 4a8b7f48..1198b34b 100644 --- a/test/codegen/scalar_arithmetic.ir +++ b/test/codegen/scalar_arithmetic.ir @@ -18,6 +18,7 @@ func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { %13 = arith.neg %a : i32 %14 = arith.not %a : i32 %15 = arith.not %a1 : i1 + %16 = arith.abs %a : i32 ; CHECK: int x = a + b; ; CHECK-NEXT: int x1 = a - b; ; CHECK-NEXT: int x2 = a * b; @@ -33,6 +34,7 @@ func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { ; CHECK-NEXT: int x12 = -a; ; CHECK-NEXT: int x13 = ~a; ; CHECK-NEXT: bool x14 = !a1; +; CHECK-NEXT: int x15 = abs(a); } func @t2(%a: i32, %b: i32) { %1 = cmp.eq %a, %b : i32 @@ -55,12 +57,14 @@ func @t3(%a: f32, %b: f32) { %4 = arith.div %a, %b : f32 %5 = arith.rem %a, %b : f32 %6 = arith.neg %a : f32 + %7 = arith.abs %a : f32 ; CHECK: float x = a + b; ; CHECK-NEXT: float x1 = a - b; ; CHECK-NEXT: float x2 = a * b; ; CHECK-NEXT: float x3 = a / b; ; CHECK-NEXT: float x4 = fmod(a, b); ; CHECK-NEXT: float x5 = -a; +; CHECK-NEXT: float x6 = fabs(a); } func @t4(%a: f32, %b: f32) { %1 = cmp.eq %a, %b : f32 @@ -80,3 +84,23 @@ func @t5(%a: i32) { %b = cast %a : i32 -> index ; CHECK: long b = (long) a; } +func @t6(%a: c32, %b: c32) { + %0 = arith.add %a, %b : c32 + %1 = arith.sub %a, %b : c32 + %2 = arith.mul %a, %b : c32 + %3 = arith.div %a, %b : c32 + %4 = arith.neg %a : c32 + %5 = arith.abs %a : c32 + %6 = arith.conj %a : c32 + %7 = arith.im %a : c32 + %8 = arith.re %a : c32 +; CHECK: float2 x = a + b; +; CHECK-NEXT: float2 x1 = a - b; +; CHECK-NEXT: float2 x2 = a * b.x + (float2) (-a.y, a.x) * b.y; +; CHECK-NEXT: float2 x3 = (a * b.x - (float2) (-a.y, a.x) * b.y) / (b.x * b.x + b.y * b.y); +; CHECK-NEXT: float2 x4 = -a; +; CHECK-NEXT: float x5 = sqrt(a.x * a.x + a.y * a.y); +; CHECK-NEXT: float2 x6 = (float2) (a.x, -a.y); +; CHECK-NEXT: float x7 = a.y; +; CHECK-NEXT: float x8 = a.x; +} diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 7c133e04..7add13c2 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -39,6 +39,7 @@ func @known_arith() { %5 = arith.neg %4 : f32 %6 = constant [1.0, -1.0] -> c32 %7 = arith.add %6, %6 : c32 + %8 = arith.abs %4 : f32 ; CHECK-LABEL: func @known_arith({{.*}} ; CHECK: %0 = constant 1 -> i64 ; CHECK-NEXT: %1 = constant -2 -> i64 @@ -48,6 +49,7 @@ func @known_arith() { ; CHECK-NEXT: %5 = constant 0x1p+1 -> f32 ; CHECK-NEXT: %6 = constant [0x1p+0,-0x1p+0] -> c32 ; CHECK-NEXT: %7 = constant [0x1p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %8 = constant 0x1p+1 -> f32 } func @known_cast() { @@ -83,3 +85,29 @@ func @known_compare() { ; CHECK-NEXT: %2 = constant 1 -> i1 ; CHECK-NEXT: %3 = constant 0 -> i1 } + +func @known_arith_complex() { + %a = constant [3.0, 2.0] -> c32 + %b = constant [-1.0, 5.0] -> c32 + %0 = arith.add %a, %b : c32 + %1 = arith.sub %a, %b : c32 + %2 = arith.mul %a, %b : c32 + %3 = arith.div %a, %b : c32 + %4 = arith.neg %a : c32 + %5 = arith.conj %a : c32 + %6 = arith.abs %a : c32 + %7 = arith.im %a : c32 + %8 = arith.re %a : c32 +; CHECK-LABEL: func @known_arith_complex({{.*}} +; CHECK: %a = constant [0x1.8p+1,0x1p+1] -> c32 +; CHECK: %b = constant [-0x1p+0,0x1.4p+2] -> c32 +; CHECK-NEXT: %0 = constant [0x1p+1,0x1.cp+2] -> c32 +; CHECK-NEXT: %1 = constant [0x1p+2,-0x1.8p+1] -> c32 +; CHECK-NEXT: %2 = constant [-0x1.ap+3,0x1.ap+3] -> c32 +; CHECK-NEXT: %3 = constant [0x1.13b13cp-2,-0x1.4ec4eep-1] -> c32 +; CHECK-NEXT: %4 = constant [-0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %5 = constant [0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %6 = constant 0x1.cd82b4p+1 -> f32 +; CHECK-NEXT: %7 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %8 = constant 0x1.8p+1 -> f32 +} From 7614a5beba352f7abda3b11c7cca2a495b3d689e Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 10 Oct 2024 19:00:20 +0200 Subject: [PATCH 047/297] Added tiny argument parser library; mixed precision arithmetic Signed-off-by: Carsten Uphoff --- docs/api/core_capi.yaml | 3 + docs/api/core_cxxapi.yaml | 1 + include/tinytc/tinytc.h | 27 ++- include/tinytc/tinytc.hpp | 27 +++ include/tinytc/types.h | 6 + include/tinytc/types.hpp | 6 + src/codegen_tools.cpp | 22 ++ src/codegen_tools.hpp | 7 + src/compiler.cpp | 27 ++- src/compiler_context.cpp | 23 +- src/compiler_context.hpp | 17 +- src/compiler_context_cache.hpp | 4 +- src/device_info.cpp | 25 +++ src/inst.cpp | 12 +- src/node/data_type_node.cpp | 13 +- src/pass/constant_propagation.cpp | 7 +- src/pass/constant_propagation.hpp | 6 + src/pass/lower_linalg.cpp | 11 +- src/passes.def | 2 +- src/scalar_type.cpp | 7 + src/scalar_type.hpp | 1 + src/support/fnv1a.hpp | 45 ++++ src/support/fnv1a_array_view.hpp | 22 ++ src/support/util.hpp | 23 -- test/generator.cpp | 77 +++++++ test/opt/check-ir/cast_forbidden.ir | 2 +- test/opt/check-ir/nesting0.ir | 2 +- test/opt/check-ir/nesting1.ir | 2 +- test/opt/check-ir/nesting2.ir | 2 +- test/opt/check-ir/nesting3.ir | 2 +- test/opt/constant-propagation.ir | 2 +- test/opt/dead-code-elimination.ir | 2 +- test/opt/dump-def-use.ir | 2 +- test/opt/insert-barrier.ir | 2 +- test/opt/insert-lifetime-stop.ir | 2 +- test/opt/work-group-size.ir | 2 +- tools/CMakeLists.txt | 1 + tools/argparser/CMakeLists.txt | 12 + tools/argparser/argparser.cpp | 301 ++++++++++++++++++++++++++ tools/argparser/argparser.hpp | 239 ++++++++++++++++++++ tools/argparser/argparser_common.hpp | 30 +++ tools/argparser/test.cpp | 71 ++++++ tools/offline_compiler/CMakeLists.txt | 4 +- tools/offline_compiler/args.cpp | 83 ------- tools/offline_compiler/args.hpp | 25 --- tools/offline_compiler/main.cpp | 47 +++- tools/opt/CMakeLists.txt | 4 +- tools/opt/args.cpp | 106 --------- tools/opt/args.hpp | 26 --- tools/opt/main.cpp | 65 +++++- 50 files changed, 1132 insertions(+), 325 deletions(-) create mode 100644 src/support/fnv1a.hpp create mode 100644 src/support/fnv1a_array_view.hpp create mode 100644 tools/argparser/CMakeLists.txt create mode 100644 tools/argparser/argparser.cpp create mode 100644 tools/argparser/argparser.hpp create mode 100644 tools/argparser/argparser_common.hpp create mode 100644 tools/argparser/test.cpp delete mode 100644 tools/offline_compiler/args.cpp delete mode 100644 tools/offline_compiler/args.hpp delete mode 100644 tools/opt/args.cpp delete mode 100644 tools/opt/args.hpp diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index f5a1da3d..4a4d2594 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -40,6 +40,7 @@ Core C-API: Compiler: enum: - tinytc_bundle_format_t + - tinytc_optflag_t function: - tinytc_run_function_pass - tinytc_list_function_passes @@ -49,6 +50,7 @@ Core C-API: - tinytc_compiler_context_create - tinytc_compiler_context_add_source - tinytc_compiler_context_set_error_reporter + - tinytc_compiler_context_set_optimization_flag - tinytc_compiler_context_set_optimization_level - tinytc_compiler_context_report_error - tinytc_compiler_context_release @@ -65,6 +67,7 @@ Core C-API: - tinytc_core_info_generic_create - tinytc_core_info_intel_create - tinytc_core_info_intel_create_from_arch + - tinytc_core_info_intel_create_from_name - tinytc_core_info_release - tinytc_core_info_retain typedef: diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index a9bd4ec0..fbca9723 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -43,6 +43,7 @@ Core C++-API: - tinytc::make_core_info_generic - tinytc::make_core_info_intel - tinytc::make_core_info_intel_from_arch + - tinytc::make_core_info_intel_from_name class: - tinytc::core_info Parser: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index f16b731d..75b4f9a8 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1027,6 +1027,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_generic_create(tinytc_core_info_t TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_arch( tinytc_core_info_t *info, tinytc_intel_gpu_architecture_t arch); +/** + * @brief Look up core info for Intel GPU architecture + * + * @param info [out] pointer to the core_info object created + * @param name [in] architecture name + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name); + /** * @brief Create core_info for Intel GPUs * @@ -1185,7 +1196,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler * * Error reporting function that is called whenever an error occurs in the parser or the builder. * - * @param ctx [in] context object + * @param ctx [inout] context object * @param reporter [in] error reporting callback; set to nullptr to disable reporting * @param user_data [in][optional] pointer to user data that is passed to the callback; can be * nullptr @@ -1195,6 +1206,20 @@ TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_error_reporter( tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, void *user_data); +/** + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization level). + * + * @param ctx [inout] context object + * @param flag [in] optimization flag + * @param state [in] flag state + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_optimization_flag( + tinytc_compiler_context_t ctx, tinytc_optflag_t flag, int32_t state); + /** * @brief Set optimization level (from 0 to 2) * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index bbb6ffda..5b1c53c6 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -494,6 +494,20 @@ class compiler_context : public shared_handle { inline void set_error_reporter(error_reporter_t reporter, void *user_data) { CHECK_STATUS(tinytc_compiler_context_set_error_reporter(obj_, reporter, user_data)); } + + /** + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization + * level). + * + * @param flag optimization flag + * @param state flag state + */ + inline void set_optimization_flag(optflag flag, std::int32_t state) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_flag( + obj_, static_cast(flag), state)); + } /** * @brief Set optimization level * @@ -1757,6 +1771,19 @@ inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_ return core_info{info}; } +/** + * @brief Get core info for Intel GPUs from lookup table + * + * @param name architecture name + * + * @return Core info + */ +inline auto make_core_info_intel_from_name(char const *name) -> core_info { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_intel_create_from_name(&info, name)); + return core_info{info}; +} + /** * @brief Create core info for Intel GPUs manually * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index bd1a934d..4af6cb0f 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -316,6 +316,12 @@ typedef enum { tinytc_bundle_format_native = 1 ///< Native device binary } tinytc_bundle_format_t; +//! Flags for optimizer +typedef enum { + tinytc_optflag_unsafe_fp_math = 0 ///< Unsafe floating point math (e.g. 0.0 * x = 0.0) +} tinytc_optflag_t; +#define TINYTC_NUMBER_OF_OPTFLAGS 10 // @todo Keep up to date with tinytc_optflag_t + //! Memory object type typedef enum { tinytc_mem_type_buffer = 0x0, ///< Buffer object (e.g. cl_mem) diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 3550982c..fb714ebf 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -278,6 +278,12 @@ enum class bundle_format { native = tinytc_bundle_format_native ///< Native device binary }; +//! Flags for optimizer +enum class optflag { + unsafe_fp_math = + tinytc_optflag_unsafe_fp_math ///< Unsafe floating point math (e.g. 0.0 * x = 0.0) +}; + //! Memory object type enum class mem_type { buffer = tinytc_mem_type_buffer, ///< Buffer object (e.g. cl_mem) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 6f4801f6..be822845 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -6,6 +6,7 @@ #include "node/data_type_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" +#include "support/casting.hpp" #include "support/visit.hpp" #include "tinytc/types.h" @@ -518,4 +519,25 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc [&](region_builder &bb, value block) { body(bb, block, bs); }); } +auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value a, value b, + location const &loc) -> value { + scalar_data_type *at = dyn_cast(a->ty()); + scalar_data_type *bt = dyn_cast(b->ty()); + if (at == nullptr || bt == nullptr) { + throw compilation_error(loc, status::ir_expected_scalar); + } + if (at->ty() != bt->ty()) { + auto compatible_scalar_ty = compatible_type(at->ty(), bt->ty()); + auto compatible_ty = scalar_data_type::get(at->context(), compatible_scalar_ty); + + if (at->ty() != compatible_scalar_ty) { + a = bb.add(make_cast(a, compatible_ty, loc)); + } + if (bt->ty() != compatible_scalar_ty) { + b = bb.add(make_cast(b, compatible_ty, loc)); + } + } + return bb.add(make_arith(operation, a, b)); +} + } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 1ce6d989..fb47e9e5 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -21,6 +21,8 @@ namespace tinytc { +// tools for OpenCL codegen + short bits(scalar_type ty); clir::expr constant(scalar_type ty, std::int64_t value); clir::expr constant(scalar_type ty, double value); @@ -123,6 +125,8 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, clir::expr beta, core_config const &core_cfg); +// tools for tinytc lowering + using sgs_loop_body_builder_new = std::function; using uniform_loop_body_builder_new = std::function; @@ -132,6 +136,9 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); +auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value a, value b, + location const &loc) -> value; + } // namespace tinytc #endif // CODEGEN_TOOLS_20240229_HPP diff --git a/src/compiler.cpp b/src/compiler.cpp index 71b13cfc..2197a615 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -36,6 +36,15 @@ using namespace tinytc; +template struct optflag_setter { + PassT &pass; + tinytc_compiler_context_t ctx; + + template void operator()(Flags &&...flags) { + (pass.set_opt_flag(flags, ctx->opt_flag(flags)), ...); + } +}; + extern "C" { tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, @@ -45,9 +54,11 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr } return exception_to_status_code( [&] { -#define FUNCTION_PASS(NAME, CREATE_PASS) \ +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) \ if (strcmp(NAME, pass_name) == 0) { \ - return run_function_pass(CREATE_PASS, *prg); \ + auto pass = CREATE_PASS; \ + optflag_setter{pass, prg->get_context()}(__VA_ARGS__); \ + return run_function_pass(std::move(pass), *prg); \ } #define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) \ if (strcmp(NAME, pass_name) == 0) { \ @@ -65,7 +76,7 @@ tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *co if (names_size == nullptr || names == nullptr) { return tinytc_status_invalid_arguments; } -#define FUNCTION_PASS(NAME, CREATE_PASS) NAME, +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) NAME, #define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) NAME, static char const *const pass_names[] = { #include "passes.def" @@ -85,15 +96,19 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ } return exception_to_status_code( [&] { - const auto opt_level = prg->get_context()->opt_level(); + const auto ctx = prg->get_context(); + const auto opt_level = ctx->opt_level(); // passes + auto cpp = constant_propagation_pass{}; + optflag_setter{cpp, ctx}(tinytc::optflag::unsafe_fp_math); + run_function_pass(check_ir_pass{}, *prg); if (opt_level >= 1) { // We run constant propagation + dead code elimination early to capture dead allocas // (later on they are maybe "in use" due to the lifetime_stop instruction) - run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(cpp, *prg); run_function_pass(dead_code_elimination_pass{}, *prg); } @@ -104,7 +119,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ run_function_pass(lower_linalg_pass{info}, *prg); if (opt_level >= 1) { - run_function_pass(constant_propagation_pass{}, *prg); + run_function_pass(cpp, *prg); run_function_pass(dead_code_elimination_pass{}, *prg); } diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp index 505da8e9..604ff9e4 100644 --- a/src/compiler_context.cpp +++ b/src/compiler_context.cpp @@ -8,6 +8,7 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include namespace tinytc { @@ -21,7 +22,9 @@ using namespace tinytc; extern "C" { tinytc_compiler_context::tinytc_compiler_context() - : cache_{std::make_unique(this)} {} + : cache_{std::make_unique(this)} { + opt_flags_.fill(-1); +} auto tinytc_compiler_context::source_name(std::int32_t source_id) -> std::pair { @@ -46,6 +49,15 @@ void tinytc_compiler_context::report_error(location const &l, char const *what) reporter_(err.c_str(), &l, user_data_); } +auto tinytc_compiler_context::opt_flag(tinytc_optflag_t flag) const -> bool { + const auto state = opt_flags_[flag]; + if (state >= 0) { + return state > 0; + } + const auto clamped_opt_level = std::min(2, std::max(0, opt_level_)); + return default_opt_flags[clamped_opt_level][flag]; +} + tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx) { if (ctx == nullptr) { return tinytc_status_invalid_arguments; @@ -71,6 +83,15 @@ tinytc_status_t tinytc_compiler_context_set_error_reporter(tinytc_compiler_conte return exception_to_status_code([&] { ctx->set_error_reporter(reporter, user_data); }); } +tinytc_status_t tinytc_compiler_context_set_optimization_flag(tinytc_compiler_context_t ctx, + tinytc_optflag_t flag, + int32_t state) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->opt_flag(flag, state); }); +} + tinytc_status_t tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, int32_t level) { if (ctx == nullptr) { diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp index dd14aa99..1bb2a568 100644 --- a/src/compiler_context.hpp +++ b/src/compiler_context.hpp @@ -8,6 +8,7 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include @@ -25,6 +26,8 @@ class compiler_context_cache; struct tinytc_compiler_context : tinytc::reference_counted { public: constexpr static const char unavailable_source_name[] = "Source name unavailable"; + constexpr static std::array, 3u> default_opt_flags = + {{{false}, {false}, {true}}}; tinytc_compiler_context(); @@ -48,8 +51,17 @@ struct tinytc_compiler_context : tinytc::reference_counted { auto source_text(std::int32_t source_id) -> std::pair; void report_error(tinytc_location const &l, char const *what); - auto opt_level() const noexcept -> std::int32_t { return opt_level_; } - void opt_level(std::int32_t level) noexcept { opt_level_ = level; } + auto opt_flag(tinytc_optflag_t flag) const -> bool; + inline void opt_flag(tinytc_optflag_t flag, std::int32_t state) { opt_flags_[flag] = state; } + inline auto opt_flag(tinytc::optflag flag) const -> bool { + return opt_flag(static_cast(flag)); + } + inline void opt_flag(tinytc::optflag flag, std::int32_t state) { + opt_flag(static_cast(flag), state); + } + + inline auto opt_level() const noexcept -> std::int32_t { return opt_level_; } + inline void opt_level(std::int32_t level) noexcept { opt_level_ = level; } private: struct source_input { @@ -64,6 +76,7 @@ struct tinytc_compiler_context : tinytc::reference_counted { tinytc::error_reporter_t reporter_ = &tinytc::default_error_reporter; void *user_data_ = nullptr; std::vector sources_; + std::array opt_flags_; std::int32_t opt_level_ = 2; }; diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp index ca85b8a4..679bf182 100644 --- a/src/compiler_context_cache.hpp +++ b/src/compiler_context_cache.hpp @@ -5,7 +5,7 @@ #define COMPILER_CONTEXT_CACHE_20240925_HPP #include "node/data_type_node.hpp" -#include "support/util.hpp" +#include "support/fnv1a.hpp" #include "tinytc/types.h" #include @@ -20,7 +20,7 @@ namespace std { template <> class hash> { public: auto operator()(std::pair const &key) const -> std::size_t { - return tinytc::fnv1a(key.first, key.second); + return tinytc::fnv1a_combine(key.first, key.second); } }; } // namespace std diff --git a/src/device_info.cpp b/src/device_info.cpp index 5d239378..5fb16074 100644 --- a/src/device_info.cpp +++ b/src/device_info.cpp @@ -3,11 +3,14 @@ #include "device_info.hpp" #include "error.hpp" +#include "support/fnv1a.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" #include +#include #include #include #include @@ -155,6 +158,28 @@ tinytc_status_t tinytc_core_info_intel_create_from_arch(tinytc_core_info_t *info }); } +tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + switch (fnv1a(name, std::strlen(name))) { + case "tgl"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_tgl)); + break; + case "pvc"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_pvc)); + break; + default: + *info = nullptr; + throw status::invalid_arguments; + } + }); +} + tinytc_status_t tinytc_core_info_intel_create(tinytc_core_info_t *info, uint32_t ip_version, int32_t num_eus_per_subslice, int32_t num_threads_per_eu, uint32_t sgs_size, diff --git a/src/inst.cpp b/src/inst.cpp index 0749735f..acd570b1 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -63,10 +63,18 @@ char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op) { char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op) { switch (op) { - case tinytc_arithmetic_unary_neg: - return "neg"; + case tinytc_arithmetic_unary_abs: + return "abs"; case tinytc_arithmetic_unary_not: return "not"; + case tinytc_arithmetic_unary_neg: + return "neg"; + case tinytc_arithmetic_unary_conj: + return "conj"; + case tinytc_arithmetic_unary_im: + return "im"; + case tinytc_arithmetic_unary_re: + return "re"; } return "unknown"; } diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index f917a8c9..5a196b30 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -5,7 +5,8 @@ #include "compiler_context_cache.hpp" #include "error.hpp" #include "support/casting.hpp" -#include "support/util.hpp" +#include "support/fnv1a.hpp" +#include "support/fnv1a_array_view.hpp" #include "tinytc/types.hpp" #include @@ -109,15 +110,7 @@ auto memref_data_type::canonical_stride(array_view shape) } auto memref_data_type_key::hash() -> std::uint64_t { - std::uint64_t hash = fnv1a0(); - hash = fnv1a_step(hash, element_ty); - for (auto &s : shape) { - hash = fnv1a_step(hash, s); - } - for (auto &s : stride) { - hash = fnv1a_step(hash, s); - } - return fnv1a_step(hash, addrspace); + return fnv1a_combine(element_ty, shape, stride, addrspace); } auto memref_data_type_key::operator==(memref_data_type const &mt) -> bool { diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index 64c34527..555970c0 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -15,7 +15,6 @@ #include "support/ilist_base.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" #include #include @@ -278,4 +277,10 @@ void constant_propagation_pass::run_on_region(region_node ®) { } } +void constant_propagation_pass::set_opt_flag(tinytc::optflag flag, bool enabled) { + if (flag == tinytc::optflag::unsafe_fp_math) { + enable_unsafe_fp_math_ = enabled; + } +} + } // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp index 82964539..2d18090e 100644 --- a/src/pass/constant_propagation.hpp +++ b/src/pass/constant_propagation.hpp @@ -5,6 +5,7 @@ #define CONSTANT_PROPAGATION_20240807_HPP #include "tinytc/types.h" +#include "tinytc/types.hpp" namespace tinytc { @@ -12,6 +13,11 @@ class constant_propagation_pass { public: void run_on_function(::tinytc_func &fn); void run_on_region(::tinytc_region ®); + + void set_opt_flag(tinytc::optflag flag, bool enabled); + + private: + bool enable_unsafe_fp_math_; }; } // namespace tinytc diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index b27b8ba9..be8e953e 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -77,8 +77,15 @@ auto linalg_generator::operator()(ger_inst &g) -> inst { [&](region_builder &bb, value block, bool, value) { auto mm = bb.add(make_arith(arithmetic::add, block, m_index, g.loc())); auto a = bb.add(make_load(&g.A(), {mm}, g.loc())); - auto ab = bb.add(make_arith(arithmetic::mul, a, b, g.loc())); - bb.add(make_store(ab, &g.C(), {mm, nn}, g.loc())); + auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, g.loc()); + auto alpha_ab = mixed_precision_arithmetic(bb, arithmetic::mul, &g.alpha(), + ab, g.loc()); + auto c = bb.add(make_load(&g.C(), {mm, nn}, g.loc())); + auto beta_c = + mixed_precision_arithmetic(bb, arithmetic::mul, &g.beta(), c, g.loc()); + auto alpha_ab_plus_beta_c = mixed_precision_arithmetic( + bb, arithmetic::add, alpha_ab, beta_c, g.loc()); + bb.add(make_store(alpha_ab_plus_beta_c, &g.C(), {mm, nn}, g.loc())); }); }); }); diff --git a/src/passes.def b/src/passes.def index ad833076..eda7d20a 100644 --- a/src/passes.def +++ b/src/passes.def @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause FUNCTION_PASS("check-ir", check_ir_pass{}) -FUNCTION_PASS("constant-propagation", constant_propagation_pass{}) +FUNCTION_PASS("constant-propagation", constant_propagation_pass{}, tinytc::optflag::unsafe_fp_math) FUNCTION_PASS("dead-code-elimination", dead_code_elimination_pass{}) FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) FUNCTION_PASS("dump-def-use", dump_def_use_pass{std::cout}) diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index ac879872..6ee734cc 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: BSD-3-Clause #include "scalar_type.hpp" +#include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include namespace tinytc { @@ -59,6 +61,11 @@ scalar_type element_type(scalar_type ty) { return ty; } +scalar_type compatible_type(scalar_type a_ty, scalar_type b_ty) { + int max = std::max(static_cast(a_ty), static_cast(b_ty)); + return enum_cast(max); +} + clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { return to_clir_ty(ty, 1, as, q); } diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index 6f88e771..47fe0b97 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -19,6 +19,7 @@ bool is_floating_type(scalar_type ty); bool is_complex_type(scalar_type ty); bool is_integer_type(scalar_type ty); scalar_type element_type(scalar_type ty); +scalar_type compatible_type(scalar_type a_ty, scalar_type b_ty); clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); clir::data_type to_clir_ty(scalar_type ty, short size, diff --git a/src/support/fnv1a.hpp b/src/support/fnv1a.hpp new file mode 100644 index 00000000..e0bb5ba5 --- /dev/null +++ b/src/support/fnv1a.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FNV1A_20241009_HPP +#define FNV1A_20241009_HPP + +#include +#include +#include + +namespace tinytc { + +constexpr auto fnv1a0() -> std::uint64_t { return 0xcbf29ce484222325; } +constexpr auto fnv1a_step(std::uint64_t hash, char ch) -> std::uint64_t { + return (hash ^ ch) * 0x00000100000001b3; +} +constexpr auto fnv1a_steps(std::uint64_t hash, char const *s, std::size_t len) -> std::uint64_t { + for (std::size_t i = 0; i < len; ++i) { + hash = fnv1a_step(hash, s[i]); + } + return hash; +} +constexpr auto fnv1a(char const *s, std::size_t len) -> std::uint64_t { + return fnv1a_steps(fnv1a0(), s, len); +} +constexpr auto operator""_fnv1a(char const *s, std::size_t len) -> std::uint64_t { + return fnv1a_steps(fnv1a0(), s, len); +} + +template +requires(std::is_trivial_v) +constexpr auto fnv1a_step(std::uint64_t hash, T const &data) -> std::uint64_t { + char buf[sizeof(T)]; + std::memcpy(buf, &data, sizeof(T)); + return fnv1a_steps(hash, buf, sizeof(T)); +} + +template constexpr auto fnv1a_combine(T const &...t) -> std::uint64_t { + auto impl = [hash = fnv1a0()](auto const &ti) mutable { return hash = fnv1a_step(hash, ti); }; + return (..., impl(t)); +} + +} // namespace tinytc + +#endif // FNV1A_20241009_HPP diff --git a/src/support/fnv1a_array_view.hpp b/src/support/fnv1a_array_view.hpp new file mode 100644 index 00000000..d95ee615 --- /dev/null +++ b/src/support/fnv1a_array_view.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FNV1A_ARRAY_VIEW_20241010_HPP +#define FNV1A_ARRAY_VIEW_20241010_HPP + +#include "support/fnv1a.hpp" +#include "tinytc/tinytc.hpp" + +namespace tinytc { + +template +constexpr auto fnv1a_step(std::uint64_t hash, array_view const &data) -> std::uint64_t { + for (auto const &i : data) { + hash = fnv1a_step(hash, i); + } + return hash; +} + +} // namespace tinytc + +#endif // FNV1A_ARRAY_VIEW_20241010_HPP diff --git a/src/support/util.hpp b/src/support/util.hpp index 1f1672a6..ff4bc722 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -15,29 +15,6 @@ template auto enum_cast(V val) { return T{std::underlying_type_t(val)}; } -constexpr auto fnv1a0() -> std::uint64_t { return 0xcbf29ce484222325; } -constexpr auto fnv1a_step(std::uint64_t hash, char ch) -> std::uint64_t { - return (hash ^ ch) * 0x00000100000001b3; -} -template constexpr auto fnv1a_step(std::uint64_t hash, T &&t) -> std::uint64_t { - char buf[sizeof(T)]; - std::memcpy(buf, &t, sizeof(T)); - for (std::size_t i = 0; i < sizeof(T); ++i) { - hash = fnv1a_step(hash, buf[i]); - } - return hash; -} - -template -constexpr auto fnv1a_step(std::uint64_t hash, Head &&head, Tail &&...tail) -> std::uint64_t { - return fnv1a_step(fnv1a_step(hash, std::forward(tail)...), std::forward(head)); -} - -template -constexpr auto fnv1a(Head &&head, Tail &&...tail) -> std::uint64_t { - return fnv1a_step(fnv1a_step(fnv1a0(), std::forward(tail)...), std::forward(head)); -} - template class iterator_range_wrapper { public: iterator_range_wrapper(ItT begin, ItT end) : begin_(std::move(begin)), end_(std::move(end)) {} diff --git a/test/generator.cpp b/test/generator.cpp index 94302ecc..8b88d2dc 100644 --- a/test/generator.cpp +++ b/test/generator.cpp @@ -4,6 +4,8 @@ #include "device_info.hpp" #include "gemm_generator.hpp" #include "reference_counted.hpp" +#include "scalar_type.hpp" +#include "support/util.hpp" #include "tiling.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" @@ -81,3 +83,78 @@ TEST_CASE("max register block") { CHECK(d2.first == 2); CHECK(d2.second == 19); } + +TEST_CASE("compatible scalar type") { + for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { + const auto si = enum_cast(i); + for (int j = 0; j < TINYTC_NUMBER_OF_SCALAR_TYPES; ++j) { + const auto sj = enum_cast(j); + CHECK(compatible_type(si, sj) == compatible_type(sj, si)); + } + } + + CHECK(compatible_type(scalar_type::i1, scalar_type::i1) == scalar_type::i1); + CHECK(compatible_type(scalar_type::i1, scalar_type::i8) == scalar_type::i8); + CHECK(compatible_type(scalar_type::i1, scalar_type::i16) == scalar_type::i16); + CHECK(compatible_type(scalar_type::i1, scalar_type::i32) == scalar_type::i32); + CHECK(compatible_type(scalar_type::i1, scalar_type::i64) == scalar_type::i64); + CHECK(compatible_type(scalar_type::i1, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::i1, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::i1, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::i1, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::i1, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::i8, scalar_type::i8) == scalar_type::i8); + CHECK(compatible_type(scalar_type::i8, scalar_type::i16) == scalar_type::i16); + CHECK(compatible_type(scalar_type::i8, scalar_type::i32) == scalar_type::i32); + CHECK(compatible_type(scalar_type::i8, scalar_type::i64) == scalar_type::i64); + CHECK(compatible_type(scalar_type::i8, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::i8, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::i8, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::i8, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::i8, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::i16, scalar_type::i16) == scalar_type::i16); + CHECK(compatible_type(scalar_type::i16, scalar_type::i32) == scalar_type::i32); + CHECK(compatible_type(scalar_type::i16, scalar_type::i64) == scalar_type::i64); + CHECK(compatible_type(scalar_type::i16, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::i16, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::i16, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::i16, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::i16, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::i32, scalar_type::i32) == scalar_type::i32); + CHECK(compatible_type(scalar_type::i32, scalar_type::i64) == scalar_type::i64); + CHECK(compatible_type(scalar_type::i32, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::i32, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::i32, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::i32, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::i32, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::i64, scalar_type::i64) == scalar_type::i64); + CHECK(compatible_type(scalar_type::i64, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::i64, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::i64, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::i64, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::i64, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::index, scalar_type::index) == scalar_type::index); + CHECK(compatible_type(scalar_type::index, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::index, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::index, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::index, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::f32, scalar_type::f32) == scalar_type::f32); + CHECK(compatible_type(scalar_type::f32, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::f32, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::f32, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::f64, scalar_type::f64) == scalar_type::f64); + CHECK(compatible_type(scalar_type::f64, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::f64, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::c32, scalar_type::c32) == scalar_type::c32); + CHECK(compatible_type(scalar_type::c32, scalar_type::c64) == scalar_type::c64); + + CHECK(compatible_type(scalar_type::c64, scalar_type::c64) == scalar_type::c64); +} diff --git a/test/opt/check-ir/cast_forbidden.ir b/test/opt/check-ir/cast_forbidden.ir index f9665793..d09c0fe4 100644 --- a/test/opt/check-ir/cast_forbidden.ir +++ b/test/opt/check-ir/cast_forbidden.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @cast_cf() { %0 = constant [2.0, 1.0] -> c32 %1 = cast %0 : c32 -> i32 diff --git a/test/opt/check-ir/nesting0.ir b/test/opt/check-ir/nesting0.ir index f2627b77..aa30d186 100644 --- a/test/opt/check-ir/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting(%c: f32, %A: memref, %B: memref, %C: memref) { %lb = constant 1 -> index %ub = constant 16 -> index diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index 8247c30b..a3b13333 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { %lb = constant 1 -> index %ub = constant 16 -> index diff --git a/test/opt/check-ir/nesting2.ir b/test/opt/check-ir/nesting2.ir index 35f302e5..f9aa0714 100644 --- a/test/opt/check-ir/nesting2.ir +++ b/test/opt/check-ir/nesting2.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { %0 = subgroup_id ; CHECK: 6.10-20: SPMD instruction must not be called from collective region diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index ab1558c4..28005dcc 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-opt --check-ir < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { %lb = constant 1 -> index %ub = constant 16 -> index diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 7add13c2..0d58463a 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt --constant-propagation < %s | filecheck %s +; RUN: %tinytc-opt -pconstant-propagation < %s | filecheck %s func @known_size(%a: memref) { %0 = size %a[0] : memref %1 = size %a[1] : memref diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index 66fec77e..5f63d6d1 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt --dead-code-elimination < %s | filecheck %s +; RUN: %tinytc-opt -pdead-code-elimination < %s | filecheck %s func @dead_if(%a: memref) { %c0 = constant 0 -> i1 if %c0 { diff --git a/test/opt/dump-def-use.ir b/test/opt/dump-def-use.ir index d0e88a9c..23fe73a9 100644 --- a/test/opt/dump-def-use.ir +++ b/test/opt/dump-def-use.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt --dump-def-use < %s | filecheck %s +; RUN: %tinytc-opt -pdump-def-use < %s | filecheck %s func @foobar() { %one = constant 1 -> index diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index c6711a9a..824ada77 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt --insert-barrier < %s | filecheck %s +; RUN: %tinytc-opt -pinsert-barrier < %s | filecheck %s func @rar(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref gemm.n.n %a, %A, %B, %b, %D : f32, memref, memref, f32, memref diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 7a196e7e..d06605f0 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt --insert-lifetime-stop < %s | filecheck %s +; RUN: %tinytc-opt -pinsert-lifetime-stop < %s | filecheck %s func @basic() { %0 = alloca -> memref ; CHECK: %0 = alloca -> memref diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir index be976dc2..98086c96 100644 --- a/test/opt/work-group-size.ir +++ b/test/opt/work-group-size.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-opt -d pvc --work-group-size < %s | filecheck %s +; RUN: %tinytc-opt -dpvc -pwork-group-size < %s | filecheck %s func @default_pvc() { ; CHECK: func @default_pvc() subgroup_size(32) work_group_size(32,1) { } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index cb910352..01f3fe59 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,5 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause +add_subdirectory(argparser) add_subdirectory(offline_compiler) add_subdirectory(opt) diff --git a/tools/argparser/CMakeLists.txt b/tools/argparser/CMakeLists.txt new file mode 100644 index 00000000..d23245c2 --- /dev/null +++ b/tools/argparser/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +include(CommonOptions) + +add_library(argparser STATIC argparser.cpp) +target_include_directories(argparser PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/src) +set_cxx_common_options(argparser) + +add_executable(test-argparser test.cpp) +target_link_libraries(test-argparser PRIVATE argparser) +set_cxx_common_options(test-argparser) diff --git a/tools/argparser/argparser.cpp b/tools/argparser/argparser.cpp new file mode 100644 index 00000000..c9e980c7 --- /dev/null +++ b/tools/argparser/argparser.cpp @@ -0,0 +1,301 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "argparser.hpp" +#include "support/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::cmd { + +auto to_string(parser_status status) -> char const * { + switch (status) { + case parser_status::invalid_short_opt: + return "Short options must be alphanumeric"; + case parser_status::unknown_short_opt: + return "Unknown short option"; + case parser_status::invalid_long_opt: + return "Long options must be lowercase alphanumeric words, optionally separated by hyphens"; + case parser_status::unknown_long_opt: + return "Unknown long option"; + case parser_status::unknown_positional_arg: + return "Unknown positional argument"; + case parser_status::required_argument_missing: + return "Required argument missing"; + case parser_status::flag_does_not_take_argument: + return "Flag does not take argument"; + case parser_status::converter_functional_missing: + return "Non-default convertible type need converter functional"; + case parser_status::invalid_argument: + return "Invalid argument"; + case parser_status::argument_out_of_range: + return "Argument is out of range"; + case parser_status::required_must_not_follow_optional: + return "Required positional argument must not follow optional positional argument"; + case parser_status::positional_must_not_follow_multiarg: + return "Positional argument must not follow positional ellipsis argument"; + case parser_status::hash_conflict: + return "Long option hash conflict, please rename one of the long options"; + case parser_status::success: + break; + } + return ""; +} + +arg_parser::arg_parser() : short_(2 * 26 + 10) {} + +void arg_parser::parse(int argc, char **argv) { + auto const parse_short = [&](int pos, int subpos) -> int { + char const *str = argv[pos]; + do { + if (!std::isalnum(str[subpos])) { + throw arg_parser_error(argc, argv, pos, subpos, parser_status::invalid_short_opt); + } + auto &shortopt = short_[short_index(argv[pos][subpos])]; + if (shortopt.par) { + ++subpos; + if (shortopt.par->is_flag()) { + shortopt.par->set(nullptr); + } else { + auto status = parser_status::success; + if (str[subpos] != 0) { + status = shortopt.par->set(str + subpos); + } else if (shortopt.par->is_argument_required()) { + status = shortopt.par->set(pos + 1 < argc ? argv[++pos] : nullptr); + } else { + status = shortopt.par->set(nullptr); + } + if (status != parser_status::success) { + throw arg_parser_error(argc, argv, pos, subpos, status); + } + break; + } + } else { + throw arg_parser_error(argc, argv, pos, subpos, parser_status::unknown_short_opt); + } + } while (str[subpos]); + return pos; + }; + + auto const parse_long = [&](int pos, int subpos) -> int { + char const *str = argv[pos]; + + int key_end = subpos; + while (str[key_end] != '=' && str[key_end] != 0) { + ++key_end; + } + const auto key = fnv1a(str + subpos, key_end - subpos); + if (long_.find(key) == long_.end()) { + throw arg_parser_error(argc, argv, pos, subpos, parser_status::unknown_long_opt); + } + + auto &longopt = long_[key]; + if (longopt.par->is_flag()) { + longopt.par->set(nullptr); + if (str[key_end] != 0) { + throw arg_parser_error(argc, argv, pos, subpos, + parser_status::flag_does_not_take_argument); + } + } else { + auto status = parser_status::success; + if (str[key_end] != 0) { + status = longopt.par->set(str + key_end + 1); + } else { + status = longopt.par->set(nullptr); + } + if (status != parser_status::success) { + throw arg_parser_error(argc, argv, pos, key_end + 1, status); + } + } + return pos; + }; + + std::size_t positional_arg_index = 0; + auto parse_positional = [&](int pos, int subpos) -> int { + if (positional_arg_index >= positional_.size()) { + throw arg_parser_error(argc, argv, pos, subpos, parser_status::unknown_positional_arg); + } + auto &arg = positional_[positional_arg_index]; + arg.par->set(argv[pos]); + if (!arg.par->does_store_multiple()) { + ++positional_arg_index; + } + return pos; + }; + + int pos = 1; + for (; pos < argc; ++pos) { + int subpos = 0; + if (argv[pos][subpos] == '-') { + ++subpos; + if (argv[pos][subpos] == '-') { + ++subpos; + if (argv[pos][subpos] == 0) { + ++pos; + pos = parse_positional(pos, subpos); + } else { + pos = parse_long(pos, subpos); + } + } else { + pos = parse_short(pos, subpos); + } + } else { + pos = parse_positional(pos, subpos); + } + } + + if (positional_arg_index < positional_.size() && + positional_[positional_arg_index].par->is_argument_required()) { + throw arg_parser_error(argc, argv, pos, 0, parser_status::required_argument_missing); + } +} + +void arg_parser::print_help(std::ostream &os, char const *name, char const *description) { + constexpr int optwidth = 20; + + const auto print = [&](auto const &key, auto const &par, char const *init, char const *sep_req, + char const *sep_nonreq) { + if (par) { + os << '[' << init << key; + if (!par->is_flag()) { + const bool req = par->is_argument_required(); + if (!req) { + os << '['; + } + os << (req ? sep_req : sep_nonreq); + os << "arg"; + if (!req) { + os << ']'; + } + } + os << ']'; + if (par->does_store_multiple()) { + os << "..."; + } + } + }; + const auto print_short = [&](char i) { + auto const &par = short_[short_index(i)].par; + print(i, par, "-", " ", ""); + }; + const auto print_short_help = [&](char i) { + auto const &opt = short_[short_index(i)]; + if (opt.par) { + os << " -" << std::left << std::setw(optwidth) << i << opt.help << std::endl; + } + }; + os << "Usage: " << name; + for (char i = '0'; i < '9'; ++i) { + print_short(i); + } + for (char i = 'a'; i < 'z'; ++i) { + print_short(i); + print_short(std::toupper(i)); + } + auto long_opts = std::vector{}; + long_opts.reserve(long_.size()); + for (auto it = long_.begin(); it != long_.end(); ++it) { + long_opts.emplace_back(&it->second); + } + std::sort(long_opts.begin(), long_opts.end(), + [&](long_opt *a, long_opt *b) { return std::strcmp(a->opt, b->opt) < 0; }); + for (auto const &opt : long_opts) { + print(opt->opt, opt->par, "--", "=", "="); + } + for (auto const &pos : positional_) { + const bool req = pos.par->is_argument_required(); + os << (req ? ' ' : '['); + os << pos.opt; + if (pos.par->does_store_multiple()) { + os << "..."; + } + os << (req ? ' ' : ']'); + } + os << std::endl << description << std::endl << std::endl; + + os << "Positional arguments:" << std::endl; + for (auto const &pos : positional_) { + os << " " << std::left << std::setw(optwidth) << pos.opt << pos.help << std::endl; + } + + os << std::endl << "Options:" << std::endl; + for (char i = '0'; i < '9'; ++i) { + print_short_help(i); + } + for (char i = 'a'; i < 'z'; ++i) { + print_short_help(i); + print_short_help(std::toupper(i)); + } + for (auto const &opt : long_opts) { + os << " --" << std::left << std::setw(optwidth) << opt->opt << opt->help << std::endl; + } +} + +auto arg_parser::short_index(char opt) const -> std::size_t { + if (opt >= 'a') { + return 10 + 26 + (opt - 'a'); + } else if (opt >= 'A') { + return 10 + (opt - 'A'); + } + return opt - '0'; +} + +void arg_parser::set_short_opt(char key, short_opt value) { + if (!std::isalnum(key)) { + throw std::logic_error(to_string(parser_status::invalid_short_opt)); + } + short_[short_index(key)] = std::move(value); +} + +void arg_parser::set_long_opt(long_opt value) { + const auto opt_len = std::strlen(value.opt); + if (!std::all_of(value.opt, value.opt + opt_len, + [](char c) { return std::islower(c) || std::isdigit(c) || c == '-'; })) { + throw std::logic_error(to_string(parser_status::invalid_long_opt)); + } + const auto key = fnv1a(value.opt, opt_len); + if (auto it = long_.find(key); + it != long_.end() && std::strcmp(value.opt, it->second.opt) != 0) { + throw std::logic_error(to_string(parser_status::hash_conflict)); + } + long_[key] = std::move(value); +} + +void arg_parser::add_positional_arg(positional_arg value) { + if (!positional_.empty()) { + if (value.par->is_argument_required() && !positional_.back().par->is_argument_required()) { + throw std::logic_error(to_string(parser_status::required_must_not_follow_optional)); + } + if (positional_.back().par->does_store_multiple()) { + throw std::logic_error(to_string(parser_status::positional_must_not_follow_multiarg)); + } + } + positional_.emplace_back(std::move(value)); +} + +arg_parser_error::arg_parser_error(int argc, char **argv, int pos, int subpos, + parser_status status) { + auto oss = std::ostringstream{}; + oss << "==> Error in" << std::endl; + int offset = 0; + for (int i = 0; i < pos; ++i) { + oss << argv[i] << ' '; + offset += std::strlen(argv[i]) + 1; + } + if (pos < argc) { + oss << argv[pos]; + } + oss << std::endl; + const auto offset_str = std::string(offset + subpos, ' '); + oss << offset_str << '^' << std::endl; + oss << offset_str << to_string(status); + + what_ = std::move(oss).str(); +} + +} // namespace tinytc::cmd diff --git a/tools/argparser/argparser.hpp b/tools/argparser/argparser.hpp new file mode 100644 index 00000000..9c4a62a2 --- /dev/null +++ b/tools/argparser/argparser.hpp @@ -0,0 +1,239 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ARGPARSER_20241008_HPP +#define ARGPARSER_20241008_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::cmd { + +enum class parser_status { + success, + invalid_short_opt, + unknown_short_opt, + invalid_long_opt, + unknown_long_opt, + unknown_positional_arg, + required_argument_missing, + flag_does_not_take_argument, + converter_functional_missing, + invalid_argument, + argument_out_of_range, + required_must_not_follow_optional, + positional_must_not_follow_multiarg, + hash_conflict, +}; +auto to_string(parser_status status) -> char const *; + +template struct default_converter; +template struct default_converter { + auto operator()(char const *str, T &val) const -> parser_status { + long v = strtol(str, nullptr, 0); + if (errno == ERANGE || v < std::numeric_limits::min() || + v > std::numeric_limits::max()) { + return parser_status::argument_out_of_range; + } + val = v; + return parser_status::success; + } +}; +template <> struct default_converter { + auto operator()(char const *str, char const *&val) const -> parser_status { + val = str; + return parser_status::success; + } +}; + +template struct is_defined : std::false_type {}; +template +requires(sizeof(T) > 0) +struct is_defined : std::true_type {}; +template constexpr bool is_defined_v = is_defined::value; + +class par_concept { + public: + virtual ~par_concept() = default; + virtual auto set(char const *optional_argument) -> parser_status = 0; + virtual auto is_flag() const -> bool = 0; + virtual auto is_argument_required() const -> bool = 0; + virtual auto does_store_multiple() const -> bool = 0; +}; + +template class par_model : public par_concept { + public: + using value_type = T; + + par_model(T *ptr, std::optional default_argument) + : ptr_(ptr), default_argument_(std::move(default_argument)) {} + auto set(char const *optional_argument) -> parser_status override { + auto status = parser_status::success; + if (optional_argument != nullptr) { + if (converter_) { + status = converter_(optional_argument, *ptr_); + } else { + if constexpr (is_defined_v>) { + status = default_converter{}(optional_argument, *ptr_); + } else { + status = parser_status::converter_functional_missing; + } + } + } else if (default_argument_.has_value()) { + *ptr_ = *default_argument_; + } else { + status = parser_status::required_argument_missing; + } + if (validator_ && !validator_(*ptr_)) { + status = parser_status::invalid_argument; + } + return status; + } + auto is_flag() const -> bool override { return false; } + auto is_argument_required() const -> bool override { return !default_argument_.has_value(); } + auto does_store_multiple() const -> bool override { return false; } + + template auto converter(F &&fun) { converter_ = std::forward(fun); } + template auto validator(F &&fun) { validator_ = std::forward(fun); } + + protected: + T *ptr_; + + private: + std::optional default_argument_; + std::function converter_; + std::function validator_; +}; + +template <> class par_model : public par_concept { + public: + par_model(bool *ptr) : ptr_(ptr) {} + auto set(char const *) -> parser_status override { + *ptr_ = true; + return parser_status::success; + } + auto is_flag() const -> bool override { return true; } + auto is_argument_required() const -> bool override { return false; } + auto does_store_multiple() const -> bool override { return false; } + + protected: + bool *ptr_; +}; + +template class par_model> : public par_model { + public: + par_model(std::vector *ptr, std::optional default_argument) + : par_model{nullptr, std::move(default_argument)}, vptr_(ptr) {} + + auto set(char const *optional_argument) -> parser_status override { + vptr_->emplace_back(T{}); + this->ptr_ = &vptr_->back(); + return this->template par_model::set(optional_argument); + } + auto does_store_multiple() const -> bool override { return true; } + + private: + std::vector *vptr_; +}; + +class arg_parser { + public: + arg_parser(); + + inline void set_short_opt(char opt, bool *ptr, char const *help = nullptr) { + set_short_opt(opt, {help, std::make_unique>(ptr)}); + } + + template + auto set_short_opt(char opt, T *ptr, char const *help = nullptr, + std::optional::value_type> default_argument = + std::nullopt) -> par_model & { + auto model = std::make_unique>(ptr, std::move(default_argument)); + auto model_ptr = model.get(); + set_short_opt(opt, {help, std::move(model)}); + return *model_ptr; + } + + inline void set_long_opt(char const *opt, bool *ptr, char const *help = nullptr) { + set_long_opt({opt, help, std::make_unique>(ptr)}); + } + + template + auto set_long_opt(char const *opt, T *ptr, char const *help = nullptr, + std::optional::value_type> default_argument = + std::nullopt) -> par_model & { + auto model = std::make_unique>(ptr, std::move(default_argument)); + auto model_ptr = model.get(); + set_long_opt({opt, help, std::move(model)}); + return *model_ptr; + } + + template + auto add_positional_arg(char const *opt, T *ptr, char const *help = nullptr, + bool required = false) { + add_positional_arg( + positional_arg{opt, help, + std::make_unique>( + ptr, required ? std::nullopt : std::make_optional(*ptr))}); + } + + template + auto add_positional_arg(char const *opt, std::vector *ptr, char const *help = nullptr) { + add_positional_arg( + {opt, help, std::make_unique>>(ptr, std::make_optional(T{}))}); + } + + void parse(int argc, char **argv); + void print_help(std::ostream &os, char const *name, char const *description); + + private: + struct short_opt { + char const *help; + std::unique_ptr par; + }; + struct long_opt { + char const *opt; + char const *help; + std::unique_ptr par; + }; + struct positional_arg { + char const *opt; + char const *help; + std::unique_ptr par; + }; + + auto short_index(char opt) const -> std::size_t; + void set_short_opt(char key, short_opt value); + void set_long_opt(long_opt value); + void add_positional_arg(positional_arg value); + + std::vector short_; + std::unordered_map long_; + std::vector positional_; +}; + +class arg_parser_error : public std::exception { + public: + arg_parser_error(int argc, char **argv, int pos, int subpos, parser_status status); + inline char const *what() const noexcept override { return what_.c_str(); } + + private: + std::string what_; +}; + +} // namespace tinytc::cmd + +#endif // ARGPARSER_20241008_HPP diff --git a/tools/argparser/argparser_common.hpp b/tools/argparser/argparser_common.hpp new file mode 100644 index 00000000..dca7dfe6 --- /dev/null +++ b/tools/argparser/argparser_common.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ARGPARSER_COMMON_20241010_HPP +#define ARGPARSER_COMMON_20241010_HPP + +#include "argparser.hpp" +#include "tinytc/tinytc.hpp" + +namespace tinytc::cmd { + +struct optflag_states { + std::int32_t unsafe_fp_math = -1; +}; + +inline void add_optflag_states(arg_parser &parser, optflag_states &flags) { + auto const validator = [](std::int32_t value) { return -1 <= value && value <= 1; }; + parser + .set_long_opt("unsafe-fp-math", &flags.unsafe_fp_math, + "Enable unsafe floating point math (e.g. 0.0 * x = 0.0)", 1) + .validator(validator); +} + +inline void set_optflags(compiler_context &ctx, optflag_states const &flags) { + ctx.set_optimization_flag(optflag::unsafe_fp_math, flags.unsafe_fp_math); +} + +} // namespace tinytc::cmd + +#endif // ARGPARSER_COMMON_20241010_HPP diff --git a/tools/argparser/test.cpp b/tools/argparser/test.cpp new file mode 100644 index 00000000..c73dd071 --- /dev/null +++ b/tools/argparser/test.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "argparser.hpp" + +#include + +using namespace tinytc; + +struct args { + bool f = false; + int a = -1; + short b = -1; + bool foo = false; + short bar = -1; + bool help = false; + int c = 0; + int d = 0; + std::vector m = {}; + std::vector m2 = {}; +}; + +int main(int argc, char **argv) { + auto a = args{}; + auto parser = cmd::arg_parser{}; + + try { + parser.set_short_opt('f', &a.f, "f opt"); + parser.set_short_opt('a', &a.a, "a opt", 5).validator([](int a) { return a > 0; }); + parser.set_short_opt('b', &a.b, "b opt"); + parser.set_short_opt('h', &a.help, "show help"); + parser.set_long_opt("help", &a.help, "show help"); + parser.set_long_opt("foo", &a.foo, "foo opt"); + parser.set_long_opt("bar", &a.bar, "bar opt"); + parser.set_long_opt("bar2", &a.bar, "bar opt", 5); + parser.add_positional_arg("c", &a.c, "c arg", true); + parser.add_positional_arg("d", &a.d, "d arg"); + parser.set_short_opt('m', &a.m, "m arg"); + parser.add_positional_arg("m2", &a.m2, "m2 arg"); + parser.parse(argc, argv); + } catch (std::exception const &e) { + if (!a.help) { + std::cerr << e.what() << std::endl; + return -1; + } + } + if (a.help) { + parser.print_help(std::cout, "test-argparser", "Test of libargparser"); + return 0; + } + + std::cout << "f: " << a.f << std::endl; + std::cout << "a: " << a.a << std::endl; + std::cout << "b: " << a.b << std::endl; + std::cout << "foo: " << a.foo << std::endl; + std::cout << "bar: " << a.bar << std::endl; + std::cout << "c: " << a.c << std::endl; + std::cout << "d: " << a.d << std::endl; + std::cout << "m: "; + for (auto const &mm : a.m) { + std::cout << mm << ' '; + } + std::cout << std::endl; + std::cout << "m2: "; + for (auto const &mm : a.m2) { + std::cout << mm << ' '; + } + std::cout << std::endl; + + return 0; +} diff --git a/tools/offline_compiler/CMakeLists.txt b/tools/offline_compiler/CMakeLists.txt index 3ea53353..36acd618 100644 --- a/tools/offline_compiler/CMakeLists.txt +++ b/tools/offline_compiler/CMakeLists.txt @@ -4,8 +4,8 @@ include(CommonOptions) include(GNUInstallDirs) -add_executable(tinytc-oc main.cpp args.cpp) -target_link_libraries(tinytc-oc PRIVATE tinytc) +add_executable(tinytc-oc main.cpp) +target_link_libraries(tinytc-oc PRIVATE tinytc argparser) set_target_properties(tinytc-oc PROPERTIES OUTPUT_NAME "tinytc") set_cxx_common_options(tinytc-oc) diff --git a/tools/offline_compiler/args.cpp b/tools/offline_compiler/args.cpp deleted file mode 100644 index 9848a043..00000000 --- a/tools/offline_compiler/args.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include - -using tinytc::core_info; -using tinytc::intel_gpu_architecture; -using tinytc::make_core_info_intel_from_arch; - -auto make_core_info_from_string(char const *name) -> core_info { - if (std::strcmp(name, "pvc") == 0) { - return make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - } else if (std::strcmp(name, "tgl") == 0) { - return make_core_info_intel_from_arch(intel_gpu_architecture::tgl); - } - return core_info{}; -} - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.filename = nullptr; - a.opt_level = 2; - - int npos = 0; - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error( - (std::ostringstream{} << "==> Unrecognized argument: " << argv[i]).str()); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "-O0") == 0) { - a.opt_level = 0; - } else if (std::strcmp(argv[i], "-O1") == 0) { - a.opt_level = 1; - } else if (std::strcmp(argv[i], "-O2") == 0) { - a.opt_level = 2; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-d") == 0 || std::strcmp(argv[i], "--device") == 0) { - a.info = make_core_info_from_string(argv[++i]); - if (!a.info) { - throw std::runtime_error( - (std::ostringstream{} << "==> Unknown device: " << argv[i]).str()); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - if (npos == 0) { - a.filename = argv[i]; - ++npos; - } else { - throw std::runtime_error("==> At most a single positional argument is expected"); - } - } - } - if (!a.info) { - a.info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tinytc [-d ] [file-name]" << std::endl - << R"HELP( -positional arguments: - file-name Path to source code; leave empty to read from stdin - -optional arguments: - -d, --device Device name (cf. intel_gpu_architecture enum), default is "pvc" - -h, --help Show help text and exit -)HELP"; -} diff --git a/tools/offline_compiler/args.hpp b/tools/offline_compiler/args.hpp deleted file mode 100644 index 4e135ace..00000000 --- a/tools/offline_compiler/args.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20240516_HPP -#define ARGS_20240516_HPP - -#include "tinytc/tinytc.hpp" - -#include -#include - -struct args { - char const *filename; - tinytc::core_info info; - bool help; - std::int32_t opt_level; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20240516_HPP diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 8fc1f71e..3cd02aa4 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -1,10 +1,12 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "argparser.hpp" +#include "argparser_common.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include #include @@ -14,9 +16,35 @@ using namespace tinytc; int main(int argc, char **argv) { - auto a = args{}; + char const *filename = nullptr; + auto info = core_info{}; + std::int32_t opt_level = 2; + auto flags = cmd::optflag_states{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); + info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); + + parser.set_short_opt('O', &opt_level, "Optimization level, default is -O2") + .validator([](std::int32_t level) { return 0 <= level; }); + parser + .set_short_opt('d', &info, + "Device name (cf. intel_gpu_architecture enum), default is \"pvc\"") + .converter([](char const *str, core_info &val) -> cmd::parser_status { + val = make_core_info_intel_from_name(str); + if (!val) { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_long_opt("help", &help, "Show help"); + parser.add_positional_arg("file-name", &filename, + "Path to source code; leave empty to read from stdin"); + add_optflag_states(parser, flags); + + parser.parse(argc, argv); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return -1; @@ -24,23 +52,24 @@ int main(int argc, char **argv) { std::cerr << e.what() << std::endl; return -1; } - if (a.help) { - arg_parser::show_help(std::cout); + if (help) { + parser.print_help(std::cout, "tinytc", ""); return 0; } auto ctx = compiler_context{}; try { ctx = make_compiler_context(); - ctx.set_optimization_level(a.opt_level); + ctx.set_optimization_level(opt_level); + set_optflags(ctx, flags); auto p = prog{}; - if (!a.filename) { + if (!filename) { p = parse_stdin(ctx); } else { - p = parse_file(a.filename, ctx); + p = parse_file(filename, ctx); } - auto src = compile_to_opencl(std::move(p), a.info); + auto src = compile_to_opencl(std::move(p), info); std::cout << src.get_code(); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; diff --git a/tools/opt/CMakeLists.txt b/tools/opt/CMakeLists.txt index 6368f8bd..c4a15853 100644 --- a/tools/opt/CMakeLists.txt +++ b/tools/opt/CMakeLists.txt @@ -4,8 +4,8 @@ include(CommonOptions) include(GNUInstallDirs) -add_executable(tinytc-opt main.cpp args.cpp) -target_link_libraries(tinytc-opt PRIVATE tinytc) +add_executable(tinytc-opt main.cpp) +target_link_libraries(tinytc-opt PRIVATE tinytc argparser) set_cxx_common_options(tinytc-opt) set_target_properties(tinytc-opt PROPERTIES INSTALL_RPATH_USE_LINK_PATH True) diff --git a/tools/opt/args.cpp b/tools/opt/args.cpp deleted file mode 100644 index b0358946..00000000 --- a/tools/opt/args.cpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include -#include - -using tinytc::core_info; -using tinytc::intel_gpu_architecture; -using tinytc::list_function_passes; -using tinytc::make_core_info_intel_from_arch; - -auto make_core_info_from_string(char const *name) -> core_info { - if (std::strcmp(name, "pvc") == 0) { - return make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - } else if (std::strcmp(name, "tgl") == 0) { - return make_core_info_intel_from_arch(intel_gpu_architecture::tgl); - } - return core_info{}; -} - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.filename = nullptr; - - std::uint32_t names_size = 0; - char const *const *names = nullptr; - list_function_passes(names_size, names); - - auto const has_function_pass = [&names_size, names](char const *pass_name) -> bool { - for (std::uint32_t i = 0; i < names_size; ++i) { - if (std::strcmp(pass_name, names[i]) == 0) { - return true; - } - } - return false; - }; - - int npos = 0; - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error( - (std::ostringstream{} << "==> Unrecognized argument: " << argv[i]).str()); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (argv[i][1] == '-' && has_function_pass(argv[i] + 2)) { - a.pass_names.emplace_back(std::string(argv[i] + 2)); - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-d") == 0 || std::strcmp(argv[i], "--device") == 0) { - a.info = make_core_info_from_string(argv[++i]); - if (!a.info) { - throw std::runtime_error( - (std::ostringstream{} << "==> Unknown device: " << argv[i]).str()); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - if (npos == 0) { - a.filename = argv[i]; - ++npos; - } else { - throw std::runtime_error("==> At most a single positional argument is expected"); - } - } - } - if (!a.info) { - a.info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - } - - if (a.pass_names.empty() || std::strncmp(a.pass_names.back().c_str(), "dump", 4) != 0) { - a.pass_names.emplace_back(std::string("dump-ir")); - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tinytc-opt [-d ] [file-name]" << std::endl - << R"HELP( -positional arguments: - file-name Path to source code; leave empty to read from stdin - -optional arguments: - -O0,-O1,-O2 Optimization level, default is -O2 - -d, --device Device name (cf. intel_gpu_architecture enum), default is "pvc" - -h, --help Show help text and exit - -passes: -)HELP"; - std::uint32_t names_size = 0; - char const *const *names = nullptr; - list_function_passes(names_size, names); - for (std::uint32_t i = 0; i < names_size; ++i) { - os << " --" << names[i] << std::endl; - } -} diff --git a/tools/opt/args.hpp b/tools/opt/args.hpp deleted file mode 100644 index 8ddd7754..00000000 --- a/tools/opt/args.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20240911_HPP -#define ARGS_20240911_HPP - -#include "tinytc/tinytc.hpp" - -#include -#include -#include - -struct args { - std::vector pass_names; - char const *filename; - tinytc::core_info info; - bool help; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20240911_HPP diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index f0112144..69788556 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -1,22 +1,52 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "argparser.hpp" +#include "argparser_common.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include +#include #include #include #include -#include #include using namespace tinytc; int main(int argc, char **argv) { - auto a = args{}; + auto pass_names = std::vector{}; + char const *filename = nullptr; + auto info = core_info{}; + std::int32_t opt_level = 2; + auto flags = cmd::optflag_states{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); + info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); + + parser.set_short_opt('O', &opt_level, "Optimization level, default is -O2") + .validator([](std::int32_t level) { return 0 <= level; }); + parser + .set_short_opt('d', &info, + "Device name (cf. intel_gpu_architecture enum), default is \"pvc\"") + .converter([](char const *str, core_info &val) -> cmd::parser_status { + val = make_core_info_intel_from_name(str); + if (!val) { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }); + parser.set_short_opt('p', &pass_names, "Run pass"); + parser.set_short_opt('h', &help, "Show help"); + parser.set_long_opt("help", &help, "Show help"); + parser.add_positional_arg("file-name", &filename, + "Path to source code; leave empty to read from stdin"); + add_optflag_states(parser, flags); + + parser.parse(argc, argv); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return -1; @@ -24,23 +54,38 @@ int main(int argc, char **argv) { std::cerr << e.what() << std::endl; return -1; } - if (a.help) { - arg_parser::show_help(std::cout); + if (help) { + parser.print_help(std::cout, "tinytc-opt", ""); + + std::uint32_t names_size = 0; + char const *const *names = nullptr; + list_function_passes(names_size, names); + + std::cout << std::endl << "Passes:" << std::endl; + for (std::uint32_t i = 0; i < names_size; ++i) { + std::cout << " " << names[i] << std::endl; + } return 0; } + if (pass_names.empty() || std::strncmp(pass_names.back(), "dump", 4) != 0) { + pass_names.emplace_back("dump-ir"); + } + auto ctx = compiler_context{}; try { ctx = make_compiler_context(); + ctx.set_optimization_level(opt_level); + set_optflags(ctx, flags); auto p = prog{}; - if (!a.filename) { + if (!filename) { p = parse_stdin(ctx); } else { - p = parse_file(a.filename, ctx); + p = parse_file(filename, ctx); } - for (auto const &pass_name : a.pass_names) { - run_function_pass(pass_name.c_str(), p, a.info); + for (auto const &pass_name : pass_names) { + run_function_pass(pass_name, p, info); } } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; From 016f2c4e1a8544da54d039057a32f945f6071c46 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 11 Oct 2024 12:18:11 +0200 Subject: [PATCH 048/297] Add folding of arithmetic identities Signed-off-by: Carsten Uphoff --- src/node/inst_node.hpp | 2 + ...helper.hpp => constant_folding_helper.hpp} | 124 ++++++++++++++-- src/pass/constant_propagation.cpp | 140 +++++++++++------- src/pass/constant_propagation.hpp | 2 +- test/opt/constant-propagation-safe.ir | 110 ++++++++++++++ test/opt/constant-propagation-unsafe.ir | 60 ++++++++ test/opt/constant-propagation.ir | 121 ++++++++------- tools/argparser/CMakeLists.txt | 4 + tools/argparser/argparser.cpp | 20 ++- tools/argparser/argparser.hpp | 3 + tools/argparser/argparser_common.cpp | 53 +++++++ tools/argparser/argparser_common.hpp | 29 ++-- tools/offline_compiler/CMakeLists.txt | 2 +- tools/offline_compiler/main.cpp | 8 +- tools/opt/CMakeLists.txt | 2 +- tools/opt/main.cpp | 13 +- 16 files changed, 546 insertions(+), 147 deletions(-) rename src/pass/{constant_propagation_helper.hpp => constant_folding_helper.hpp} (71%) create mode 100644 test/opt/constant-propagation-safe.ir create mode 100644 test/opt/constant-propagation-unsafe.ir create mode 100644 tools/argparser/argparser_common.cpp diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 61f8729d..39b77a72 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -408,7 +408,9 @@ class arith_inst : public standard_inst<2, 1> { arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline arithmetic operation() const { return operation_; } + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } inline auto b() const -> tinytc_value const & { return op(op_b); } private: diff --git a/src/pass/constant_propagation_helper.hpp b/src/pass/constant_folding_helper.hpp similarity index 71% rename from src/pass/constant_propagation_helper.hpp rename to src/pass/constant_folding_helper.hpp index 04cc552b..4b66ea8a 100644 --- a/src/pass/constant_propagation_helper.hpp +++ b/src/pass/constant_folding_helper.hpp @@ -1,8 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#ifndef CONSTANT_PROPAGATION_HELPER_20241002_HPP -#define CONSTANT_PROPAGATION_HELPER_20241002_HPP +#ifndef CONSTANT_FOLDING_HELPER_20241011_HPP +#define CONSTANT_FOLDING_HELPER_20241011_HPP #include "scalar_type.hpp" #include "support/casting.hpp" @@ -11,6 +11,7 @@ #include #include #include +#include namespace tinytc { @@ -20,6 +21,8 @@ requires(std::is_floating_point_v) struct is_complex> : public std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; +using fold_result = std::variant; + struct compute_unary_op { arithmetic_unary operation; data_type ty; @@ -27,7 +30,7 @@ struct compute_unary_op { template requires(std::is_integral_v) - auto operator()(T a) { + auto operator()(T a) -> fold_result { T val = 0; switch (operation) { case arithmetic_unary::abs: @@ -51,7 +54,7 @@ struct compute_unary_op { template requires(std::is_floating_point_v) - auto operator()(T a) -> inst { + auto operator()(T a) -> fold_result { T val = 0; switch (operation) { case arithmetic_unary::abs: @@ -68,7 +71,7 @@ struct compute_unary_op { template requires(is_complex_v) - auto operator()(U const &A) -> inst { + auto operator()(U const &A) -> fold_result { const auto neg_conj = [&](T const &a) { T val = {}; switch (operation) { @@ -126,7 +129,7 @@ struct compute_binary_op { template requires(std::is_integral_v) - auto operator()(T a, T b) { + auto operator()(T a, T b) -> fold_result { T val = 0; switch (operation) { case arithmetic::add: @@ -173,7 +176,7 @@ struct compute_binary_op { template requires(!std::is_integral_v) - auto operator()(U const &A, U const &B) -> inst { + auto operator()(U const &A, U const &B) -> fold_result { const auto a = static_cast(A); const auto b = static_cast(B); T val = {}; @@ -208,6 +211,104 @@ struct compute_binary_op { } }; +struct compute_binop_identities { + bool unsafe_fp_math; + arithmetic operation; + tinytc_value &operand; + bool is_second_operand; + location const &loc; + + template + requires(std::is_integral_v) + auto operator()(T a) -> fold_result { + switch (operation) { + case arithmetic::add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case arithmetic::sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case arithmetic::mul: + if (a == T{0}) { // operand * 0 or 0 * operand + return make_constant(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case arithmetic::div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + case arithmetic::rem: + if (a == T{1} && !is_second_operand) { // operand % 1 + return make_constant(T{0}, operand.ty(), loc); + } + break; + case arithmetic::shl: + case arithmetic::shr: + if (a == T{0}) { + if (is_second_operand) { // 0 << operand + return make_constant(T{0}, operand.ty(), loc); + } else { // operand << 0 + return &operand; + } + } + case arithmetic::and_: + if (a == T{0}) { + return make_constant(T{0}, operand.ty(), loc); + } + break; + case arithmetic::or_: + case arithmetic::xor_: + if (a == T{0}) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A) -> fold_result { + const auto a = static_cast(A); + switch (operation) { + case arithmetic::add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case arithmetic::sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case arithmetic::mul: + if (unsafe_fp_math && a == T{0}) { // operand * 0 or 0 * operand + return make_constant(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case arithmetic::div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } +}; + struct compute_compare { cmp_condition cond; data_type ty; @@ -215,7 +316,7 @@ struct compute_compare { template requires(std::is_integral_v || std::is_floating_point_v) - auto operator()(T a, T b) { + auto operator()(T a, T b) -> fold_result { bool val = false; switch (cond) { case cmp_condition::eq: @@ -241,7 +342,7 @@ struct compute_compare { } template - auto operator()(std::complex const &A, std::complex const &B) { + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { const auto a = static_cast(A); const auto b = static_cast(B); bool val = false; @@ -286,7 +387,8 @@ template struct value_cast_impl> { template auto value_cast(U const &u) { return value_cast_impl{}(u); } -template auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> inst { +template +auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> fold_result { switch (to_ty->ty()) { case scalar_type::i1: return make_constant(value_cast(A), to_ty, loc); @@ -314,4 +416,4 @@ template auto compute_cast(scalar_data_type *to_ty, T A, location c } // namespace tinytc -#endif // CONSTANT_PROPAGATION_HELPER_20241002_HPP +#endif // CONSTANT_FOLDING_HELPER_20241011_HPP diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index 555970c0..e14df6b6 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -8,7 +8,7 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" -#include "pass/constant_propagation_helper.hpp" +#include "pass/constant_folding_helper.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" #include "support/ilist.hpp" @@ -33,7 +33,7 @@ template class unary_op_dispatcher { unary_op_dispatcher(scalar_type sw_ty, F &&f) : switch_ty{sw_ty}, computer{std::forward(f)} {} - auto operator()(std::int64_t const &A) -> inst { + auto operator()(std::int64_t const &A) -> fold_result { switch (switch_ty) { case scalar_type::i1: return computer.template operator()(A); @@ -52,7 +52,7 @@ template class unary_op_dispatcher { break; }; } - auto operator()(double const &A) -> inst { + auto operator()(double const &A) -> fold_result { switch (switch_ty) { case scalar_type::f32: return computer.template operator()(A); @@ -63,7 +63,7 @@ template class unary_op_dispatcher { break; } } - auto operator()(std::complex const &A) -> inst { + auto operator()(std::complex const &A) -> fold_result { switch (switch_ty) { case scalar_type::c32: return computer.template operator()>(A); @@ -85,7 +85,7 @@ template class binary_op_dispatcher { binary_op_dispatcher(scalar_type sw_ty, F &&f) : switch_ty{sw_ty}, computer{std::forward(f)} {} - auto operator()(std::int64_t const &A, std::int64_t const &B) -> inst { + auto operator()(std::int64_t const &A, std::int64_t const &B) -> fold_result { switch (switch_ty) { case scalar_type::i1: return computer.template operator()(A, B); @@ -104,7 +104,7 @@ template class binary_op_dispatcher { break; }; } - auto operator()(double const &A, double const &B) -> inst { + auto operator()(double const &A, double const &B) -> fold_result { switch (switch_ty) { case scalar_type::f32: return computer.template operator()(A, B); @@ -115,7 +115,7 @@ template class binary_op_dispatcher { break; } } - auto operator()(std::complex const &A, std::complex const &B) -> inst { + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { switch (switch_ty) { case scalar_type::c32: return computer.template operator()>(A, B); @@ -126,25 +126,31 @@ template class binary_op_dispatcher { break; } } - template auto operator()(T const &, U const &) -> inst { + template auto operator()(T const &, U const &) -> fold_result { throw compilation_error(computer.loc, status::ir_scalar_mismatch); } }; -class constant_evaluator { +class constant_folding { public: - auto operator()(inst_node &) -> inst; - auto operator()(arith_inst &) -> inst; - auto operator()(arith_unary_inst &) -> inst; - auto operator()(cast_inst &) -> inst; - auto operator()(compare_inst &) -> inst; - auto operator()(size_inst &in) -> inst; + constant_folding(bool unsafe_fp_math); + + auto operator()(inst_node &) -> fold_result; + auto operator()(arith_inst &) -> fold_result; + auto operator()(arith_unary_inst &) -> fold_result; + auto operator()(cast_inst &) -> fold_result; + auto operator()(compare_inst &) -> fold_result; + auto operator()(size_inst &in) -> fold_result; private: auto get_memref_type(value_node const &v) const -> const memref_data_type *; + + bool unsafe_fp_math_; }; -auto constant_evaluator::get_memref_type(value_node const &v) const -> const memref_data_type * { +constant_folding::constant_folding(bool unsafe_fp_math) : unsafe_fp_math_(unsafe_fp_math) {} + +auto constant_folding::get_memref_type(value_node const &v) const -> const memref_data_type * { auto t = dyn_cast(v.ty()); if (t == nullptr) { throw compilation_error(v.loc(), status::ir_expected_memref); @@ -152,34 +158,44 @@ auto constant_evaluator::get_memref_type(value_node const &v) const -> const mem return t; } -auto constant_evaluator::operator()(inst_node &) -> inst { return {}; } +auto constant_folding::operator()(inst_node &) -> fold_result { return {}; } -auto constant_evaluator::operator()(arith_inst &in) -> inst { +auto constant_folding::operator()(arith_inst &in) -> fold_result { auto &op_a = in.a(); auto &op_b = in.b(); constant_inst *a_const = dyn_cast(op_a.defining_inst()); constant_inst *b_const = dyn_cast(op_b.defining_inst()); - if (a_const == nullptr || b_const == nullptr) { - return inst{}; - } auto at = dyn_cast(op_a.ty()); if (at == nullptr) { throw compilation_error(op_a.loc(), status::ir_expected_scalar); } - auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; - auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); + if (a_const != nullptr && b_const != nullptr) { + auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); + } else if (a_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); + } else if (b_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), b_const->value()); + } + return tinytc_value_t{}; } -auto constant_evaluator::operator()(arith_unary_inst &in) -> inst { +auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { auto &op_a = in.a(); constant_inst *a_const = dyn_cast(op_a.defining_inst()); if (a_const == nullptr) { - return inst{}; + return tinytc_value_t{}; } auto at = dyn_cast(op_a.ty()); @@ -192,12 +208,12 @@ auto constant_evaluator::operator()(arith_unary_inst &in) -> inst { return std::visit(std::move(dispatcher), a_const->value()); } -auto constant_evaluator::operator()(cast_inst &in) -> inst { +auto constant_folding::operator()(cast_inst &in) -> fold_result { auto &op_a = in.a(); constant_inst *a_const = dyn_cast(op_a.defining_inst()); if (a_const == nullptr) { - return inst{}; + return tinytc_value_t{}; } auto rt = dyn_cast(in.result(0).ty()); @@ -205,18 +221,19 @@ auto constant_evaluator::operator()(cast_inst &in) -> inst { throw compilation_error(in.result(0).loc(), status::ir_expected_scalar); } - return std::visit(overloaded{[&](auto A) -> inst { return compute_cast(rt, A, in.loc()); }}, - a_const->value()); + return std::visit( + overloaded{[&](auto A) -> fold_result { return compute_cast(rt, A, in.loc()); }}, + a_const->value()); } -auto constant_evaluator::operator()(compare_inst &in) -> inst { +auto constant_folding::operator()(compare_inst &in) -> fold_result { auto &op_a = in.a(); auto &op_b = in.b(); constant_inst *a_const = dyn_cast(op_a.defining_inst()); constant_inst *b_const = dyn_cast(op_b.defining_inst()); if (a_const == nullptr || b_const == nullptr) { - return inst{}; + return tinytc_value_t{}; } auto at = dyn_cast(op_a.ty()); @@ -229,7 +246,7 @@ auto constant_evaluator::operator()(compare_inst &in) -> inst { return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); } -auto constant_evaluator::operator()(size_inst &in) -> inst { +auto constant_folding::operator()(size_inst &in) -> fold_result { auto ct = get_memref_type(in.operand()); auto mode_size = ct->shape(in.mode()); @@ -238,7 +255,7 @@ auto constant_evaluator::operator()(size_inst &in) -> inst { mode_size, scalar_data_type::get(in.operand().context(), scalar_type::index), in.loc()); } - return inst{}; + return tinytc_value_t{}; } void constant_propagation_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } @@ -249,37 +266,46 @@ void constant_propagation_pass::run_on_region(region_node ®) { run_on_region(subreg); } - auto known_constant = visit(constant_evaluator{}, *it); - if (known_constant) { - // update uses - if (it->num_results() != known_constant->num_results()) { + const auto update_uses = [&it](tinytc_value_t with) { + if (it->num_results() != 1) { throw status::internal_compiler_error; } - auto r_old = it->result_begin(); - auto r_new = known_constant->result_begin(); - for (; r_old != it->result_end() && r_new != known_constant->result_end(); - ++r_old, ++r_new) { - r_new->name(r_old->name()); - auto u = r_old->use_begin(); - while (r_old->has_uses()) { - u->set(&*r_new); - u = r_old->use_begin(); - } - if (r_old->has_uses()) { - throw status::internal_compiler_error; - } + auto r = it->result_begin(); + auto u = r->use_begin(); + while (r->has_uses()) { + u->set(with); + u = r->use_begin(); } - // delete old instruction - it = reg.insts().erase(it); - // insert new instruction - it = reg.insts().insert(it, known_constant.release()); - } + if (r->has_uses()) { + throw status::internal_compiler_error; + } + }; + + fold_result fr = visit(constant_folding{unsafe_fp_math_}, *it); + std::visit(overloaded{[&](tinytc_value_t val) { + if (val) { + update_uses(val); + } + }, + [&](inst &new_constant) { + if (new_constant) { + if (new_constant->num_results() != 1) { + throw status::internal_compiler_error; + } + update_uses(&*new_constant->result_begin()); + // insert new constant + it = reg.insts().insert(it, new_constant.release()); + // skip over constant + ++it; + } + }}, + fr); } } void constant_propagation_pass::set_opt_flag(tinytc::optflag flag, bool enabled) { if (flag == tinytc::optflag::unsafe_fp_math) { - enable_unsafe_fp_math_ = enabled; + unsafe_fp_math_ = enabled; } } diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp index 2d18090e..0f25bce1 100644 --- a/src/pass/constant_propagation.hpp +++ b/src/pass/constant_propagation.hpp @@ -17,7 +17,7 @@ class constant_propagation_pass { void set_opt_flag(tinytc::optflag flag, bool enabled); private: - bool enable_unsafe_fp_math_; + bool unsafe_fp_math_ = false; }; } // namespace tinytc diff --git a/test/opt/constant-propagation-safe.ir b/test/opt/constant-propagation-safe.ir new file mode 100644 index 00000000..50447cd7 --- /dev/null +++ b/test/opt/constant-propagation-safe.ir @@ -0,0 +1,110 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt -pconstant-propagation -fno-unsafe-fp-math < %s | filecheck %s + +func @identity_iadd(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.add %a, %c0 : i32 + %1 = arith.add %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_iadd({{.*}} +; CHECK: %2 = arith.add %a, %a : i32 +} + +func @identity_isub(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.sub %a, %c0 : i32 + %1 = arith.sub %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_isub({{.*}} +; CHECK: %2 = arith.add %a, %1 : i32 +} + +func @identity_imul0(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.mul %a, %c0 : i32 + %1 = arith.mul %c0, %a : i32 +; CHECK-LABEL: func @identity_imul0({{.*}} +; CHECK: %0 = constant 0 -> i32 +; CHECK-NEXT: %1 = arith.mul %a, %c0 : i32 +; CHECK-NEXT: %2 = constant 0 -> i32 +; CHECK-NEXT: %3 = arith.mul %c0, %a : i32 +} + +func @identity_imul1(%a: i32) { + %c1 = constant 1 -> i32 + %0 = arith.mul %a, %c1 : i32 + %1 = arith.mul %c1, %a : i32 + %2 = arith.mul %0, %1 : i32 +; CHECK-LABEL: func @identity_imul1({{.*}} +; CHECK: %2 = arith.mul %a, %a : i32 +} + +func @identity_idiv(%a: i32) { + %c1 = constant 1 -> i32 + %0 = arith.div %a, %c1 : i32 + %1 = arith.div %c1, %a : i32 + %2 = arith.mul %0, %1 : i32 +; CHECK-LABEL: func @identity_idiv({{.*}} +; CHECK: %2 = arith.mul %a, %1 : i32 +} + +func @identity_irem(%a: i32) { + %c1 = constant 1 -> i32 + %0 = arith.rem %a, %c1 : i32 + %1 = arith.rem %c1, %a : i32 + %2 = arith.mul %0, %1 : i32 +; CHECK-LABEL: func @identity_irem({{.*}} +; CHECK: %0 = constant 0 -> i32 +; CHECK: %4 = arith.mul %0, %2 : i32 +} + +func @identity_ishl(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.shl %a, %c0 : i32 + %1 = arith.shl %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_ishl({{.*}} +; CHECK: %1 = constant 0 -> i32 +; CHECK: %3 = arith.add %a, %1 : i32 +} + +func @identity_ishr(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.shr %a, %c0 : i32 + %1 = arith.shr %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_ishr({{.*}} +; CHECK: %1 = constant 0 -> i32 +; CHECK: %3 = arith.add %a, %1 : i32 +} + +func @identity_iand(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.and %a, %c0 : i32 + %1 = arith.and %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_iand({{.*}} +; CHECK: %0 = constant 0 -> i32 +; CHECK: %2 = constant 0 -> i32 +; CHECK: %5 = arith.add %0, %2 : i32 +} + +func @identity_ior(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.or %a, %c0 : i32 + %1 = arith.or %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_ior({{.*}} +; CHECK: %2 = arith.add %a, %a : i32 +} + +func @identity_ixor(%a: i32) { + %c0 = constant 0 -> i32 + %0 = arith.xor %a, %c0 : i32 + %1 = arith.xor %c0, %a : i32 + %2 = arith.add %0, %1 : i32 +; CHECK-LABEL: func @identity_ixor({{.*}} +; CHECK: %2 = arith.add %a, %a : i32 +} diff --git a/test/opt/constant-propagation-unsafe.ir b/test/opt/constant-propagation-unsafe.ir new file mode 100644 index 00000000..53b4dc5c --- /dev/null +++ b/test/opt/constant-propagation-unsafe.ir @@ -0,0 +1,60 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt -pconstant-propagation -funsafe-fp-math < %s | filecheck %s + +func @identity_fadd(%a: f32) { + %c0 = constant 0.0 -> f32 + %0 = arith.add %a, %c0 : f32 + %1 = arith.add %c0, %a : f32 + %2 = arith.add %0, %1 : f32 +; CHECK-LABEL: func @identity_fadd({{.*}} +; CHECK: %2 = arith.add %a, %a : f32 +} + +func @identity_fsub(%a: f32) { + %c0 = constant 0.0 -> f32 + %0 = arith.sub %a, %c0 : f32 + %1 = arith.sub %c0, %a : f32 + %2 = arith.add %0, %1 : f32 +; CHECK-LABEL: func @identity_fsub({{.*}} +; CHECK: %2 = arith.add %a, %1 : f32 +} + +func @identity_fmul0(%a: f32) { + %c0 = constant 0.0 -> f32 + %0 = arith.mul %a, %c0 : f32 + %1 = arith.mul %c0, %a : f32 +; CHECK-LABEL: func @identity_fmul0({{.*}} +; CHECK: %0 = constant 0x0p+0 -> f32 +; CHECK-NEXT: %1 = arith.mul %a, %c0 : f32 +; CHECK-NEXT: %2 = constant 0x0p+0 -> f32 +; CHECK-NEXT: %3 = arith.mul %c0, %a : f32 +} + +func @identity_fmul1(%a: f32) { + %c1 = constant 1.0 -> f32 + %0 = arith.mul %a, %c1 : f32 + %1 = arith.mul %c1, %a : f32 + %2 = arith.mul %0, %1 : f32 +; CHECK-LABEL: func @identity_fmul1({{.*}} +; CHECK: %2 = arith.mul %a, %a : f32 +} + +func @identity_fdiv(%a: f32) { + %c1 = constant 1.0 -> f32 + %0 = arith.div %a, %c1 : f32 + %1 = arith.div %c1, %a : f32 + %2 = arith.mul %0, %1 : f32 +; CHECK-LABEL: func @identity_fdiv({{.*}} +; CHECK: %2 = arith.mul %a, %1 : f32 +} + +func @identity_cmul1(%a: c32) { + %c1 = constant [1.0, 0.0] -> c32 + %0 = arith.mul %a, %c1 : c32 + %1 = arith.mul %c1, %a : c32 + %2 = arith.mul %0, %1 : c32 +; CHECK-LABEL: func @identity_cmul1({{.*}} +; CHECK: %2 = arith.mul %a, %a : c32 +} diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 0d58463a..702135b6 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -2,14 +2,19 @@ ; SPDX-License-Identifier: BSD-3-Clause ; RUN: %tinytc-opt -pconstant-propagation < %s | filecheck %s -func @known_size(%a: memref) { +func @known_size(%a: memref, %b: index) { %0 = size %a[0] : memref %1 = size %a[1] : memref %2 = arith.add %0, %1 : index + %3 = arith.add %2, %b : index ; CHECK-LABEL: func @known_size({{.*}} -; CHECK-NEXT: %0 = constant 64 -> index -; CHECK-NEXT: %1 = constant 32 -> index -; CHECK-NEXT: %2 = constant 96 -> index +; CHECK: %0 = constant 64 -> index +; CHECK-NEXT: %1 = size %a[0] : memref +; CHECK-NEXT: %2 = constant 32 -> index +; CHECK-NEXT: %3 = size %a[1] : memref +; CHECK-NEXT: %4 = constant 96 -> index +; CHECK-NEXT: %5 = arith.add %0, %2 : index +; CHECK-NEXT: %6 = arith.add %4, %b : index } func @known_loop_bounds() { @@ -21,35 +26,37 @@ func @known_loop_bounds() { for %i=%lb,%ub { } ; CHECK-LABEL: func @known_loop_bounds({{.*}} -; CHECK: %one = constant 1 -> index -; CHECK-NEXT: %lb = constant 5 -> index -; CHECK-NEXT: %size = constant 42 -> index -; CHECK-NEXT: %tmp = constant 37 -> index -; CHECK-NEXT: %ub = constant 36 -> index -; CHECK-NEXT: for %i=%lb,%ub : index { -; CHECK-NEXT: } +; CHECK-NEXT: %one = constant 1 -> index +; CHECK-NEXT: %lb = constant 5 -> index +; CHECK-NEXT: %size = constant 42 -> index +; CHECK-NEXT: %0 = constant 37 -> index +; CHECK-NEXT: %tmp = arith.sub %size, %lb : index +; CHECK-NEXT: %1 = constant 36 -> index +; CHECK-NEXT: %ub = arith.sub %0, %one : index +; CHECK-NEXT: for %i=%lb,%1 : index { } func @known_arith() { %0 = constant 1 -> i64 - %1 = arith.not %0 : i64 - %2 = constant 2 -> i64 - %3 = arith.add %0, %2 : i64 - %4 = constant -2.0 -> f32 - %5 = arith.neg %4 : f32 - %6 = constant [1.0, -1.0] -> c32 - %7 = arith.add %6, %6 : c32 - %8 = arith.abs %4 : f32 + %1 = constant 2 -> i64 + %3 = constant -2.0 -> f32 + %4 = constant [1.0, -1.0] -> c32 + %5 = arith.not %0 : i64 + %6 = arith.add %0, %1 : i64 + %7 = arith.neg %3 : f32 + %8 = arith.add %4, %4 : c32 + %9 = arith.abs %3 : f32 ; CHECK-LABEL: func @known_arith({{.*}} -; CHECK: %0 = constant 1 -> i64 -; CHECK-NEXT: %1 = constant -2 -> i64 -; CHECK-NEXT: %2 = constant 2 -> i64 -; CHECK-NEXT: %3 = constant 3 -> i64 -; CHECK-NEXT: %4 = constant -0x1p+1 -> f32 -; CHECK-NEXT: %5 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %6 = constant [0x1p+0,-0x1p+0] -> c32 -; CHECK-NEXT: %7 = constant [0x1p+1,-0x1p+1] -> c32 -; CHECK-NEXT: %8 = constant 0x1p+1 -> f32 +; CHECK: %4 = constant -2 -> i64 +; CHECK-NEXT: %5 = arith.not %0 : i64 +; CHECK-NEXT: %6 = constant 3 -> i64 +; CHECK-NEXT: %7 = arith.add %0, %1 : i64 +; CHECK-NEXT: %8 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %9 = arith.neg %2 : f32 +; CHECK-NEXT: %10 = constant [0x1p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %11 = arith.add %3, %3 : c32 +; CHECK-NEXT: %12 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %13 = arith.abs %2 : f32 } func @known_cast() { @@ -63,15 +70,20 @@ func @known_cast() { %5 = cast %c1 : c32 -> c64 %6 = cast %3 : i1 -> c32 ; CHECK-LABEL: func @known_cast({{.*}} -; CHECK: %c0 = constant 32768 -> i32 -; CHECK: %c1 = constant [0x1.8p+1,-0x1p+1] -> c32 -; CHECK-NEXT: %0 = constant -32768 -> i16 -; CHECK-NEXT: %1 = constant 0x1p+15 -> f32 -; CHECK-NEXT: %2 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %3 = constant 1 -> i1 +; CHECK: %0 = constant -32768 -> i16 +; CHECK-NEXT: %1 = cast %c0 : i32 -> i16 +; CHECK-NEXT: %2 = constant 0x1p+15 -> f32 +; CHECK-NEXT: %3 = cast %c0 : i32 -> f32 ; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %5 = constant [0x1.8p+1,-0x1p+1] -> c64 -; CHECK-NEXT: %6 = constant [0x1p+0,0x0p+0] -> c32 +; CHECK-NEXT: %5 = cast %c0 : i32 -> c32 +; CHECK-NEXT: %6 = constant 1 -> i1 +; CHECK-NEXT: %7 = cast %c0 : i32 -> i1 +; CHECK-NEXT: %8 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %9 = cast %c0 : i32 -> c32 +; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] -> c64 +; CHECK-NEXT: %11 = cast %c1 : c32 -> c64 +; CHECK-NEXT: %12 = constant [0x1p+0,0x0p+0] -> c32 +; CHECK-NEXT: %13 = cast %6 : i1 -> c32 } func @known_compare() { @@ -80,10 +92,10 @@ func @known_compare() { %2 = cmp.eq %0, %0 : f32 %3 = cmp.eq %0, %1 : f32 ; CHECK-LABEL: func @known_compare({{.*}} -; CHECK: %0 = constant 0x1p+0 -> f32 -; CHECK-NEXT: %1 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %2 = constant 1 -> i1 -; CHECK-NEXT: %3 = constant 0 -> i1 +; CHECK: %2 = constant 1 -> i1 +; CHECK-NEXT: %3 = cmp.eq %0, %0 : f32 +; CHECK-NEXT: %4 = constant 0 -> i1 +; CHECK-NEXT: %5 = cmp.eq %0, %1 : f32 } func @known_arith_complex() { @@ -99,15 +111,22 @@ func @known_arith_complex() { %7 = arith.im %a : c32 %8 = arith.re %a : c32 ; CHECK-LABEL: func @known_arith_complex({{.*}} -; CHECK: %a = constant [0x1.8p+1,0x1p+1] -> c32 -; CHECK: %b = constant [-0x1p+0,0x1.4p+2] -> c32 -; CHECK-NEXT: %0 = constant [0x1p+1,0x1.cp+2] -> c32 -; CHECK-NEXT: %1 = constant [0x1p+2,-0x1.8p+1] -> c32 -; CHECK-NEXT: %2 = constant [-0x1.ap+3,0x1.ap+3] -> c32 -; CHECK-NEXT: %3 = constant [0x1.13b13cp-2,-0x1.4ec4eep-1] -> c32 -; CHECK-NEXT: %4 = constant [-0x1.8p+1,-0x1p+1] -> c32 -; CHECK-NEXT: %5 = constant [0x1.8p+1,-0x1p+1] -> c32 -; CHECK-NEXT: %6 = constant 0x1.cd82b4p+1 -> f32 -; CHECK-NEXT: %7 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %8 = constant 0x1.8p+1 -> f32 +; CHECK: %0 = constant [0x1p+1,0x1.cp+2] -> c32 +; CHECK-NEXT: %1 = arith.add %a, %b : c32 +; CHECK-NEXT: %2 = constant [0x1p+2,-0x1.8p+1] -> c32 +; CHECK-NEXT: %3 = arith.sub %a, %b : c32 +; CHECK-NEXT: %4 = constant [-0x1.ap+3,0x1.ap+3] -> c32 +; CHECK-NEXT: %5 = arith.mul %a, %b : c32 +; CHECK-NEXT: %6 = constant [0x1.13b13cp-2,-0x1.4ec4eep-1] -> c32 +; CHECK-NEXT: %7 = arith.div %a, %b : c32 +; CHECK-NEXT: %8 = constant [-0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %9 = arith.neg %a : c32 +; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %11 = arith.conj %a : c32 +; CHECK-NEXT: %12 = constant 0x1.cd82b4p+1 -> f32 +; CHECK-NEXT: %13 = arith.abs %a : c32 +; CHECK-NEXT: %14 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %15 = arith.im %a : c32 +; CHECK-NEXT: %16 = constant 0x1.8p+1 -> f32 +; CHECK-NEXT: %17 = arith.re %a : c32 } diff --git a/tools/argparser/CMakeLists.txt b/tools/argparser/CMakeLists.txt index d23245c2..f62d0b43 100644 --- a/tools/argparser/CMakeLists.txt +++ b/tools/argparser/CMakeLists.txt @@ -7,6 +7,10 @@ add_library(argparser STATIC argparser.cpp) target_include_directories(argparser PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/src) set_cxx_common_options(argparser) +add_library(argparser_common STATIC argparser_common.cpp) +target_link_libraries(argparser_common PRIVATE argparser tinytc) +set_cxx_common_options(argparser_common) + add_executable(test-argparser test.cpp) target_link_libraries(test-argparser PRIVATE argparser) set_cxx_common_options(test-argparser) diff --git a/tools/argparser/argparser.cpp b/tools/argparser/argparser.cpp index c9e980c7..f0b10b16 100644 --- a/tools/argparser/argparser.cpp +++ b/tools/argparser/argparser.cpp @@ -47,6 +47,9 @@ auto to_string(parser_status status) -> char const * { return ""; } +int arg_parser::optindent = 4; +int arg_parser::optwidth = 20; + arg_parser::arg_parser() : short_(2 * 26 + 10) {} void arg_parser::parse(int argc, char **argv) { @@ -156,8 +159,6 @@ void arg_parser::parse(int argc, char **argv) { } void arg_parser::print_help(std::ostream &os, char const *name, char const *description) { - constexpr int optwidth = 20; - const auto print = [&](auto const &key, auto const &par, char const *init, char const *sep_req, char const *sep_nonreq) { if (par) { @@ -186,7 +187,10 @@ void arg_parser::print_help(std::ostream &os, char const *name, char const *desc const auto print_short_help = [&](char i) { auto const &opt = short_[short_index(i)]; if (opt.par) { - os << " -" << std::left << std::setw(optwidth) << i << opt.help << std::endl; + for (int i = 0; i < optindent - 1; ++i) { + os << ' '; + } + os << '-' << std::left << std::setw(optwidth) << i << opt.help << std::endl; } }; os << "Usage: " << name; @@ -220,7 +224,10 @@ void arg_parser::print_help(std::ostream &os, char const *name, char const *desc os << "Positional arguments:" << std::endl; for (auto const &pos : positional_) { - os << " " << std::left << std::setw(optwidth) << pos.opt << pos.help << std::endl; + for (int i = 0; i < optindent; ++i) { + os << ' '; + } + os << std::left << std::setw(optwidth) << pos.opt << pos.help << std::endl; } os << std::endl << "Options:" << std::endl; @@ -232,7 +239,10 @@ void arg_parser::print_help(std::ostream &os, char const *name, char const *desc print_short_help(std::toupper(i)); } for (auto const &opt : long_opts) { - os << " --" << std::left << std::setw(optwidth) << opt->opt << opt->help << std::endl; + for (int i = 0; i < optindent - 2; ++i) { + os << ' '; + } + os << "--" << std::left << std::setw(optwidth) << opt->opt << opt->help << std::endl; } } diff --git a/tools/argparser/argparser.hpp b/tools/argparser/argparser.hpp index 9c4a62a2..cd6aa56c 100644 --- a/tools/argparser/argparser.hpp +++ b/tools/argparser/argparser.hpp @@ -151,6 +151,9 @@ template class par_model> : public par_model { class arg_parser { public: + static int optindent; + static int optwidth; + arg_parser(); inline void set_short_opt(char opt, bool *ptr, char const *help = nullptr) { diff --git a/tools/argparser/argparser_common.cpp b/tools/argparser/argparser_common.cpp new file mode 100644 index 00000000..00bbd543 --- /dev/null +++ b/tools/argparser/argparser_common.cpp @@ -0,0 +1,53 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "argparser_common.hpp" +#include "support/fnv1a.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc::cmd { + +void add_optflag_states(arg_parser &parser, optflag_states &flags) { + auto const converter = [](char const *str, std::pair &val) { + optflag flag = {}; + std::int32_t state = 1; + constexpr char const disable_prefix[] = "no-"; + constexpr std::size_t disable_prefix_len = sizeof(disable_prefix) - 1; + if (std::strncmp(str, disable_prefix, disable_prefix_len) == 0) { + state = 0; + str = str + disable_prefix_len; + } + switch (fnv1a(str, std::strlen(str))) { + case "unsafe-fp-math"_fnv1a: + flag = optflag::unsafe_fp_math; + break; + default: + return parser_status::invalid_argument; + }; + val = std::make_pair(flag, state); + return parser_status::success; + }; + parser + .set_short_opt('f', &flags, + "Enable optimization flag; use \"no-\" prefix to disable optimization flag") + .converter(converter); +} + +void set_optflags(compiler_context &ctx, optflag_states const &flags) { + for (auto const &flag : flags) { + ctx.set_optimization_flag(flag.first, flag.second); + } +} + +void list_optimization_flags(std::ostream &os) { + os << "Optimization flags:" << std::endl; + for (int i = 0; i < arg_parser::optindent; ++i) { + os << ' '; + } + os << "unsafe-fp-math" << std::endl; +} + +} // namespace tinytc::cmd diff --git a/tools/argparser/argparser_common.hpp b/tools/argparser/argparser_common.hpp index dca7dfe6..24462cb9 100644 --- a/tools/argparser/argparser_common.hpp +++ b/tools/argparser/argparser_common.hpp @@ -5,25 +5,24 @@ #define ARGPARSER_COMMON_20241010_HPP #include "argparser.hpp" -#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" -namespace tinytc::cmd { - -struct optflag_states { - std::int32_t unsafe_fp_math = -1; -}; +#include +#include +#include -inline void add_optflag_states(arg_parser &parser, optflag_states &flags) { - auto const validator = [](std::int32_t value) { return -1 <= value && value <= 1; }; - parser - .set_long_opt("unsafe-fp-math", &flags.unsafe_fp_math, - "Enable unsafe floating point math (e.g. 0.0 * x = 0.0)", 1) - .validator(validator); +namespace tinytc { +class compiler_context; } -inline void set_optflags(compiler_context &ctx, optflag_states const &flags) { - ctx.set_optimization_flag(optflag::unsafe_fp_math, flags.unsafe_fp_math); -} +namespace tinytc::cmd { +class arg_parser; + +using optflag_states = std::vector>; + +void add_optflag_states(arg_parser &parser, optflag_states &flags); +void set_optflags(compiler_context &ctx, optflag_states const &flags); +void list_optimization_flags(std::ostream &os); } // namespace tinytc::cmd diff --git a/tools/offline_compiler/CMakeLists.txt b/tools/offline_compiler/CMakeLists.txt index 36acd618..e8dd226a 100644 --- a/tools/offline_compiler/CMakeLists.txt +++ b/tools/offline_compiler/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) include(GNUInstallDirs) add_executable(tinytc-oc main.cpp) -target_link_libraries(tinytc-oc PRIVATE tinytc argparser) +target_link_libraries(tinytc-oc PRIVATE tinytc argparser argparser_common) set_target_properties(tinytc-oc PROPERTIES OUTPUT_NAME "tinytc") set_cxx_common_options(tinytc-oc) diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 3cd02aa4..edb798be 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { parser.set_long_opt("help", &help, "Show help"); parser.add_positional_arg("file-name", &filename, "Path to source code; leave empty to read from stdin"); - add_optflag_states(parser, flags); + cmd::add_optflag_states(parser, flags); parser.parse(argc, argv); } catch (status const &st) { @@ -54,6 +54,10 @@ int main(int argc, char **argv) { } if (help) { parser.print_help(std::cout, "tinytc", ""); + + std::cout << std::endl; + cmd::list_optimization_flags(std::cout); + return 0; } @@ -61,7 +65,7 @@ int main(int argc, char **argv) { try { ctx = make_compiler_context(); ctx.set_optimization_level(opt_level); - set_optflags(ctx, flags); + cmd::set_optflags(ctx, flags); auto p = prog{}; if (!filename) { p = parse_stdin(ctx); diff --git a/tools/opt/CMakeLists.txt b/tools/opt/CMakeLists.txt index c4a15853..87721aeb 100644 --- a/tools/opt/CMakeLists.txt +++ b/tools/opt/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) include(GNUInstallDirs) add_executable(tinytc-opt main.cpp) -target_link_libraries(tinytc-opt PRIVATE tinytc argparser) +target_link_libraries(tinytc-opt PRIVATE tinytc argparser argparser_common) set_cxx_common_options(tinytc-opt) set_target_properties(tinytc-opt PROPERTIES INSTALL_RPATH_USE_LINK_PATH True) diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index 69788556..eb86ce67 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -44,7 +44,7 @@ int main(int argc, char **argv) { parser.set_long_opt("help", &help, "Show help"); parser.add_positional_arg("file-name", &filename, "Path to source code; leave empty to read from stdin"); - add_optflag_states(parser, flags); + cmd::add_optflag_states(parser, flags); parser.parse(argc, argv); } catch (status const &st) { @@ -63,8 +63,15 @@ int main(int argc, char **argv) { std::cout << std::endl << "Passes:" << std::endl; for (std::uint32_t i = 0; i < names_size; ++i) { - std::cout << " " << names[i] << std::endl; + for (int i = 0; i < cmd::arg_parser::optindent; ++i) { + std::cout << ' '; + } + std::cout << names[i] << std::endl; } + + std::cout << std::endl; + cmd::list_optimization_flags(std::cout); + return 0; } @@ -76,7 +83,7 @@ int main(int argc, char **argv) { try { ctx = make_compiler_context(); ctx.set_optimization_level(opt_level); - set_optflags(ctx, flags); + cmd::set_optflags(ctx, flags); auto p = prog{}; if (!filename) { p = parse_stdin(ctx); From 435a01b35fcb81b5ab68dfdd3778c5d6cd89a922 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 11 Oct 2024 14:53:29 +0200 Subject: [PATCH 049/297] Add atomic store and atomic fetch add Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 14 +++++++++++ docs/api/builder_capi.yaml | 2 ++ docs/api/builder_cxxapi.rst | 14 +++++++++++ docs/api/builder_cxxapi.yaml | 2 ++ docs/api/core_capi.rst | 21 ++++++++++++++++ docs/api/core_cxxapi.rst | 7 ++++++ docs/manual/tensor-ir.rst | 46 ++++++++++++++++++++++++++++++++-- include/tinytc/tinytc.h | 6 ++++- include/tinytc/tinytc.hpp | 19 ++++++++++++-- include/tinytc/types.h | 8 ++++++ include/tinytc/types.hpp | 8 ++++++ src/codegen_tools.cpp | 29 +++++++++++++++++++++ src/codegen_tools.hpp | 5 ++++ src/error.cpp | 2 ++ src/inst.cpp | 22 +++++++++++++--- src/node/inst_node.cpp | 11 ++++++-- src/node/inst_node.hpp | 10 ++++++-- src/parser/lexer.re | 1 + src/parser/parser_impl.yy | 12 +++++++-- src/pass/convert_to_opencl.cpp | 22 +++++++++++++++- src/pass/dump_ir.cpp | 6 ++++- src/pass/lower_linalg.cpp | 7 +----- test/codegen/atomic.ir | 29 ++++++++++++--------- 23 files changed, 268 insertions(+), 35 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index c4d32f69..55d62075 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -20,6 +20,8 @@ Common * :ref:`tinytc_scalar_type_t` + * :ref:`tinytc_store_flag_t` + * :ref:`tinytc_transpose_t` * Definitions @@ -40,6 +42,8 @@ Common * :ref:`tinytc_scalar_type_to_string` + * :ref:`tinytc_store_flag_to_string` + * :ref:`tinytc_transpose_to_string` * Structures @@ -106,6 +110,11 @@ tinytc_scalar_type_t .. doxygenenum:: tinytc_scalar_type_t +tinytc_store_flag_t +................... + +.. doxygenenum:: tinytc_store_flag_t + tinytc_transpose_t .................. @@ -152,6 +161,11 @@ tinytc_scalar_type_to_string .. doxygenfunction:: tinytc_scalar_type_to_string +tinytc_store_flag_to_string +........................... + +.. doxygenfunction:: tinytc_store_flag_to_string + tinytc_transpose_to_string .......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index b17ff7e1..fd76216f 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -8,6 +8,7 @@ Builder C-API: - tinytc_arithmetic_unary_t - tinytc_cmp_condition_t - tinytc_scalar_type_t + - tinytc_store_flag_t - tinytc_transpose_t define: - TINYTC_DYNAMIC @@ -18,6 +19,7 @@ Builder C-API: - tinytc_cmp_condition_to_string - tinytc_scalar_type_size - tinytc_scalar_type_to_string + - tinytc_store_flag_to_string - tinytc_transpose_to_string struct: - tinytc_position diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 38f1c75f..e53a21e0 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -20,6 +20,8 @@ Common * :ref:`scalar_type` + * :ref:`store_flag` + * :ref:`transpose` * Functions @@ -36,6 +38,8 @@ Common * :ref:`to_string(scalar_type)` + * :ref:`to_string(store_flag)` + * :ref:`to_string(transpose)` * :ref:`size` @@ -82,6 +86,11 @@ scalar_type .. doxygenenum:: tinytc::scalar_type +store_flag +.......... + +.. doxygenenum:: tinytc::store_flag + transpose ......... @@ -120,6 +129,11 @@ to_string(scalar_type) .. doxygenfunction:: tinytc::to_string(scalar_type) +to_string(store_flag) +..................... + +.. doxygenfunction:: tinytc::to_string(store_flag) + to_string(transpose) .................... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 9e3fa629..c3a9f077 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -8,6 +8,7 @@ Builder C++-API: - tinytc::arithmetic_unary - tinytc::cmp_condition - tinytc::scalar_type + - tinytc::store_flag - tinytc::transpose function: - tinytc::is_dynamic_value @@ -16,6 +17,7 @@ Builder C++-API: - tinytc::to_string(arithmetic_unary) - tinytc::to_string(cmp_condition) - tinytc::to_string(scalar_type) + - tinytc::to_string(store_flag) - tinytc::to_string(transpose) - tinytc::size class: diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index d83e05c0..093bbbda 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -246,6 +246,8 @@ Compiler * :ref:`tinytc_bundle_format_t` + * :ref:`tinytc_optflag_t` + * Functions * :ref:`tinytc_run_function_pass` @@ -262,6 +264,11 @@ tinytc_bundle_format_t .. doxygenenum:: tinytc_bundle_format_t +tinytc_optflag_t +................ + +.. doxygenenum:: tinytc_optflag_t + Compiler Functions ------------------ @@ -291,6 +298,8 @@ Compiler Context * :ref:`tinytc_compiler_context_set_error_reporter` + * :ref:`tinytc_compiler_context_set_optimization_flag` + * :ref:`tinytc_compiler_context_set_optimization_level` * :ref:`tinytc_compiler_context_report_error` @@ -317,6 +326,11 @@ tinytc_compiler_context_set_error_reporter .. doxygenfunction:: tinytc_compiler_context_set_error_reporter +tinytc_compiler_context_set_optimization_flag +............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_flag + tinytc_compiler_context_set_optimization_level .............................................. @@ -362,6 +376,8 @@ Device Info * :ref:`tinytc_core_info_intel_create_from_arch` + * :ref:`tinytc_core_info_intel_create_from_name` + * :ref:`tinytc_core_info_release` * :ref:`tinytc_core_info_retain` @@ -421,6 +437,11 @@ tinytc_core_info_intel_create_from_arch .. doxygenfunction:: tinytc_core_info_intel_create_from_arch +tinytc_core_info_intel_create_from_name +....................................... + +.. doxygenfunction:: tinytc_core_info_intel_create_from_name + tinytc_core_info_release ........................ diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index ede509f2..a2681cc6 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -224,6 +224,8 @@ Device Info * :ref:`make_core_info_intel_from_arch` + * :ref:`make_core_info_intel_from_name` + * Classes * :ref:`core_info` @@ -259,6 +261,11 @@ make_core_info_intel_from_arch .. doxygenfunction:: tinytc::make_core_info_intel_from_arch +make_core_info_intel_from_name +.............................. + +.. doxygenfunction:: tinytc::make_core_info_intel_from_name + Device Info Classes ------------------- diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 8d24ef50..29fcd893 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -333,6 +333,11 @@ The transpose modifier defines :math:`\text{op}` as following: The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) or 2 (matrix). +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + Foreach ....... @@ -395,6 +400,11 @@ defines :math:`\text{op}_2` as following: If :math:`\text{op}_1(A)` has the shape MxK and :math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + GEMV .... @@ -426,6 +436,11 @@ The transpose modifier for A as in GEMM. :math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + GER ... @@ -455,6 +470,11 @@ a, b, and C, respectively. a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + Hadamard product ................ @@ -486,6 +506,11 @@ a, b, and c, respectively. a, b, and c must be vectors and have equal shape. +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + Parallel ........ @@ -537,6 +562,11 @@ If A is a vector, then B must be a scalar memref. The transpose op is defined as in the axpby instruction. +Restrictions +~~~~~~~~~~~~ + +If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. + Mixed instructions @@ -901,7 +931,8 @@ Load .. code:: abnf - value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" ":" memref-or-group-type + value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" + ":" memref-or-group-type memref-or-group-type = memref-type / group-type Overview @@ -1078,7 +1109,10 @@ Store .. code:: abnf - instruction =/ "store" local-identifier "," local-identifier "[" [local-identifier-list] "]" ":" memref-type + instruction =/ "store" [store-flag] + local-identifier "," local-identifier "[" [local-identifier-list] "]" + ":" memref-type + store-flag = ".atomic" / ".atomic_add" Overview ~~~~~~~~ @@ -1086,6 +1120,14 @@ Overview Store a scalar value in a memref at the position given by the index list. The number of indices must match the order of the memref. +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add flag is set, the following steps are done atomically: +The value at the memory location is fetched, the scalar value is added to the fetched value, +and the resulting value is stored at the memory location. + +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the the real and imaginary separately. + *Note:* Store should only be used in SPMD regions as otherwise the same memory location is written from all work-items. diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 75b4f9a8..469e497b 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -154,6 +154,8 @@ TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); //! Convert cmp condition to string TINYTC_EXPORT char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond); +//! Convert store flag to string +TINYTC_EXPORT char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag); //! Convert transpose to string TINYTC_EXPORT char const *tinytc_transpose_to_string(tinytc_transpose_t t); @@ -619,6 +621,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( * @code store %val, %a[%index_list] : type(%a) @endcode * * @param instr [out] pointer to the inst object created + * @param flag [in] store flag * @param val [in] value to store * @param a [in] operand * @param index_list_size [in] number of indices @@ -628,7 +631,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, +TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, + tinytc_store_flag_t flag, tinytc_value_t val, tinytc_value_t a, uint32_t index_list_size, const tinytc_value_t *index_list, const tinytc_location_t *loc); diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 5b1c53c6..d14ed73c 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -684,6 +684,17 @@ inline char const *to_string(cmp_condition cond) { return ::tinytc_cmp_condition_to_string(static_cast<::tinytc_cmp_condition_t>(cond)); } +/** + * @brief Convert store flag to string + * + * @param flag Store flag + * + * @return C-string + */ +inline char const *to_string(store_flag flag) { + return ::tinytc_store_flag_to_string(static_cast(flag)); +} + /** * @brief Convert transpose to string * @@ -1257,6 +1268,7 @@ inline inst make_subview(value a, array_view static_offset_list, /** * @brief Make store instruction * + * @param flag store flag * @param val Value that is stored * @param a Target memref * @param index_list Vector of indices @@ -1264,14 +1276,17 @@ inline inst make_subview(value a, array_view static_offset_list, * * @return Instruction */ -inline inst make_store(value val, value a, array_view index_list, location const &loc = {}) { +inline inst make_store(store_flag flag, value val, value a, array_view index_list, + location const &loc = {}) { tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } const tinytc_value_t *il = reinterpret_cast(index_list.data()); - CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val, a, len, il, &loc), loc); + CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, static_cast(flag), val, + a, len, il, &loc), + loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 4af6cb0f..8989798a 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -69,6 +69,7 @@ typedef enum { tinytc_status_ir_i1_unsupported = 0x118, ///< Instruction does not support i1 type tinytc_status_ir_complex_unsupported = 0x119, ///< Instruction does not support complex type tinytc_status_ir_forbidden_cast = 0x11a, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x11b, ///< Invalid beta value // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST @@ -281,6 +282,13 @@ typedef enum { tinytc_address_space_local = 0x2 ///< Local memory, returned by alloca } tinytc_address_space_t; +//! Store flag +typedef enum { + tinytc_store_flag_regular = 0, ///< Non-atomic store + tinytc_store_flag_atomic = 1, ///< Atomic store + tinytc_store_flag_atomic_add = 2 ///< Atomic fetch add +} tinytc_store_flag_t; + //! Core features that may be optionally enabled typedef enum { /** diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index fb714ebf..4b20478b 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -79,6 +79,7 @@ enum class status { ir_i1_unsupported = tinytc_status_ir_i1_unsupported, ir_complex_unsupported = tinytc_status_ir_complex_unsupported, ir_forbidden_cast = tinytc_status_ir_forbidden_cast, + ir_invalid_beta = tinytc_status_ir_invalid_beta, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, @@ -263,6 +264,13 @@ enum class address_space { local = tinytc_address_space_local ///< Local memory, returned by alloca }; +//! Store flag +enum class store_flag { + regular = tinytc_store_flag_regular, ///< Non-atomic store + atomic = tinytc_store_flag_atomic, ///< Atomic store + atomic_add = tinytc_store_flag_atomic_add ///< Atomic fetch add +}; + //! @brief Cf. @ref tinytc_core_feature_flag_t enum class core_feature_flag { large_register_file = tinytc_core_feature_flag_large_register_file }; diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index be822845..8b1255ba 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -4,6 +4,7 @@ #include "codegen_tools.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/inst_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" @@ -540,4 +541,32 @@ auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value return bb.add(make_arith(operation, a, b)); } +auto get_atomic_store_flag(value beta) -> std::optional { + constant_inst *beta_cst = dyn_cast(beta->defining_inst()); + if (beta_cst) { + if (beta_cst->is_zero()) { + return store_flag::atomic; + } else if (beta_cst->is_identity()) { + return store_flag::atomic_add; + } + } + return std::nullopt; +} +void blas_update(region_builder &bb, bool atomic, value alpha_ab, value beta, value C, + array_view index_list, location const &loc) { + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); + } + bb.add(make_store(*flag, alpha_ab, C, index_list, loc)); + } else { + auto c = bb.add(make_load(C, index_list, loc)); + auto beta_c = mixed_precision_arithmetic(bb, arithmetic::mul, beta, c, loc); + auto alpha_ab_plus_beta_c = + mixed_precision_arithmetic(bb, arithmetic::add, alpha_ab, beta_c, loc); + bb.add(make_store(store_flag::regular, alpha_ab_plus_beta_c, C, index_list, loc)); + } +} + } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index fb47e9e5..8122bd88 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace tinytc { @@ -139,6 +140,10 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value a, value b, location const &loc) -> value; +auto get_atomic_store_flag(value beta) -> std::optional; +void blas_update(region_builder &bb, bool atomic, value alpha_ab, value beta, value C, + array_view index_list, location const &loc); + } // namespace tinytc #endif // CODEGEN_TOOLS_20240229_HPP diff --git a/src/error.cpp b/src/error.cpp index 0c760c5e..e9a9c6a6 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -173,6 +173,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "complex type unsupported by instruction"; case tinytc_status_ir_forbidden_cast: return "Forbidden cast"; + case tinytc_status_ir_invalid_beta: + return "beta must be constant and 0 or 1 for atomic linear algebra operations"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/inst.cpp b/src/inst.cpp index acd570b1..53a7904e 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -97,6 +97,18 @@ char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond) { return "unknown"; } +char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag) { + switch (flag) { + case tinytc_store_flag_regular: + return ""; + case tinytc_store_flag_atomic: + return "atomic"; + case tinytc_store_flag_atomic_add: + return "atomic_add"; + } + return "unknown"; +} + char const *tinytc_transpose_to_string(tinytc_transpose_t t) { switch (t) { case tinytc_transpose_T: @@ -402,16 +414,18 @@ tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t stat }); } -tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t a, +tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_store_flag_t flag, + tinytc_value_t val, tinytc_value_t a, uint32_t index_list_size, const tinytc_value_t *index_list, const tinytc_location_t *loc) { if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(val, a, array_view{index_list, index_list_size}, - get_optional(loc)) - .release(); + *instr = + std::make_unique(enum_cast(flag), val, a, + array_view{index_list, index_list_size}, get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 4ac8de8d..a7297672 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -288,6 +288,13 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc result(0) = value_node{ty, this, lc}; } +auto constant_inst::is_zero() const -> bool { + return std::visit([](auto const &v) { return v == decltype(v){0}; }, value_); +} +auto constant_inst::is_identity() const -> bool { + return std::visit([](auto const &v) { return v == decltype(v){1}; }, value_); +} + expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, array_view static_expand_shape0, array_view expand_shape0, location const &lc) @@ -598,9 +605,9 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o result(0) = value_node{result_ty, this, lc}; } -store_inst::store_inst(tinytc_value_t val0, tinytc_value_t op0, +store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, array_view index_list0, location const &lc) - : standard_inst{IK::store, static_cast(2 + index_list0.size())} { + : standard_inst{IK::store, static_cast(2 + index_list0.size())}, flag_{flag} { op(op_val, val0); op(op_operand, op0); { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 39b77a72..732ad859 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -478,6 +478,8 @@ class constant_inst : public standard_inst<0, 1> { constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc = {}); auto value() const -> value_type const & { return value_; } + auto is_zero() const -> bool; + auto is_identity() const -> bool; private: value_type value_; @@ -714,12 +716,16 @@ class store_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::store; } enum op_number { op_val = 0, op_operand = 1 }; - store_inst(tinytc_value_t val, tinytc_value_t op, array_view index_list, - location const &lc = {}); + store_inst(store_flag flag, tinytc_value_t val, tinytc_value_t op, + array_view index_list, location const &lc = {}); + inline auto flag() const -> store_flag { return flag_; } inline auto val() const -> tinytc_value const & { return op(op_val); } inline auto operand() const -> tinytc_value const & { return op(op_operand); } inline auto index_list() const { return operands() | std::views::drop(2); } + + private: + store_flag flag_; }; class sum_inst : public blas_a2_inst { diff --git a/src/parser/lexer.re b/src/parser/lexer.re index d2056cbd..1f570de9 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -95,6 +95,7 @@ lex: ".n" { adv_loc(); return parser::make_NOTRANS(loc_); } ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } + ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } "local" { adv_loc(); return parser::make_LOCAL(loc_); } "global" { adv_loc(); return parser::make_GLOBAL(loc_); } ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 97102cfd..394ab6fb 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -94,6 +94,7 @@ NOTRANS ".n" TRANS ".t" ATOMIC ".atomic" + ATOMIC_ADD ".atomic_add" LOCAL "local" GLOBAL "global" LOCAL_ATTR ".local" @@ -207,6 +208,7 @@ %nterm subgroup_local_id_inst %nterm subgroup_size_inst %nterm store_inst +%nterm store_flag %nterm subview_inst %nterm >> optional_slice_list %nterm >> slice_list @@ -920,7 +922,7 @@ load_inst: ; store_inst: - STORE var[a] COMMA var[b] LSQBR optional_value_list RSQBR COLON memref_type { + STORE store_flag var[a] COMMA var[b] LSQBR optional_value_list RSQBR COLON memref_type { if ($b->ty() != $memref_type) { auto loc = @b; loc.end = @memref_type.end; @@ -928,7 +930,7 @@ store_inst: } try { $$ = inst { - std::make_unique(std::move($a), std::move($b), + std::make_unique($store_flag, std::move($a), std::move($b), std::move($optional_value_list), @store_inst) .release() }; @@ -939,6 +941,12 @@ store_inst: } ; +store_flag: + %empty { $$ = store_flag::regular; } + | ATOMIC { $$ = store_flag::atomic; } + | ATOMIC_ADD { $$ = store_flag::atomic_add; } +; + group_id_inst: GROUP_ID { $$ = inst{std::make_unique(ctx.cctx().get(), @GROUP_ID).release()}; } ; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 1c884e52..b5b9b813 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -995,7 +995,27 @@ std::vector convert_to_opencl_pass::operator()(store_inst const &s) } auto rhs = val(s.val()); - auto st = assignment(dereference(std::move(lhs)), std::move(rhs)); + auto st = clir::expr{}; + auto atomic_pointer_ty = + pointer_to(to_clir_atomic_ty(ot->element_ty(), to_clir_address_space(ot->addrspace()), + clir::type_qualifier::volatile_t)); + switch (s.flag()) { + case store_flag::regular: + st = assignment(dereference(std::move(lhs)), std::move(rhs)); + break; + case store_flag::atomic: + lhs = cast(std::move(atomic_pointer_ty), std::move(lhs)); + st = call_builtin(clir::builtin_function::atomic_store_explicit, + {std::move(lhs), std::move(rhs), clir::memory_order::relaxed, + clir::memory_scope::work_group}); + break; + case store_flag::atomic_add: + lhs = cast(std::move(atomic_pointer_ty), std::move(lhs)); + st = call_builtin(clir::builtin_function::atomic_fetch_add_explicit, + {std::move(lhs), std::move(rhs), clir::memory_order::relaxed, + clir::memory_scope::work_group}); + break; + } return {expression_statement(std::move(st))}; } diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 498d2a31..f91cb9ad 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -378,7 +378,11 @@ void dump_ir_pass::operator()(subview_inst const &s) { } void dump_ir_pass::operator()(store_inst const &e) { - *os_ << "store "; + *os_ << "store"; + if (e.flag() != store_flag::regular) { + *os_ << '.' << to_string(e.flag()); + } + *os_ << ' '; dump_val(e.val()); *os_ << ", "; dump_val(e.operand()); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index be8e953e..a1768e01 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -80,12 +80,7 @@ auto linalg_generator::operator()(ger_inst &g) -> inst { auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, g.loc()); auto alpha_ab = mixed_precision_arithmetic(bb, arithmetic::mul, &g.alpha(), ab, g.loc()); - auto c = bb.add(make_load(&g.C(), {mm, nn}, g.loc())); - auto beta_c = - mixed_precision_arithmetic(bb, arithmetic::mul, &g.beta(), c, g.loc()); - auto alpha_ab_plus_beta_c = mixed_precision_arithmetic( - bb, arithmetic::add, alpha_ab, beta_c, g.loc()); - bb.add(make_store(alpha_ab_plus_beta_c, &g.C(), {mm, nn}, g.loc())); + blas_update(bb, g.atomic(), alpha_ab, &g.beta(), &g.C(), {mm, nn}, g.loc()); }); }); }); diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index a876dec7..37f81fd7 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -2,6 +2,17 @@ ; SPDX-License-Identifier: BSD-3-Clause ; RUN: %tinytc-oc < %s | filecheck %s +func @atomic_store(%A: memref) { + %f0 = constant 0.0 -> f64 + %i0 = constant 0 -> index + store.atomic %f0, %A[%i0] : memref + store.atomic_add %f0, %A[%i0] : memref +; CHECK-LABEL: void atomic_store({{.*}} +; CHECK: atomic_store_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); +; CHECK: atomic_fetch_add_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); +} + + func @axpby_atomic_store(%alpha: f64, %A: memref, %B: memref) { %zero = constant 0.0 -> f64 axpby.n.atomic %alpha, %A, %zero, %B : f64, memref, f64, memref @@ -16,15 +27,6 @@ func @axpby_atomic_add(%alpha: f32, %A: memref, %B: memref) { ; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) b, alpha * Ab[(blck1 + m) * 1], memory_order_relaxed, memory_scope_work_group); } -func @axpby_atomic_general(%alpha: f32, %A: memref, %B: memref) { - axpby.n.atomic %alpha, %A, 3.14, %B : f32, memref, f32, memref -; CHECK: float expected = *B; -; CHECK-NEXT: float desired; -; CHECK-NEXT: do { -; CHECK-NEXT: desired = alpha * A[0] + 0x1.91eb851eb851fp+1f * expected; -; CHECK-NEXT: } while (atomic_compare_exchange_strong_explicit((global volatile atomic_float*) B, &expected, desired, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group)); -} - func @gemm_atomic(%A: memref, %B: memref, %C: memref) { %one = constant 1.0 -> f32 gemm.n.n.atomic %one, %A, %B, %one, %C @@ -42,7 +44,8 @@ func @ger_atomic(%A: memref, %B: memref, %C: memref) { } func @hadamard_atomic(%A: memref, %B: memref, %C: memref) { - hadamard.atomic 1.0, %A, %B, 1.0, %C + %one = constant 1.0 -> f32 + hadamard.atomic %one, %A, %B, %one, %C : f32, memref, memref, f32, memref ; CHECK: global float* c = C + (blck + m) * 1; ; CHECK-NEXT: float ab = A[(blck + m) * 1] * B[(blck + m) * 1]; @@ -50,7 +53,8 @@ func @hadamard_atomic(%A: memref, %B: memref, %C: memref } func @sum_atomic(%A: memref, %B: memref) { - sum.n.atomic 1.0, %A, 1.0, %B : f32, memref, f32, memref + %one = constant 1.0 -> f32 + sum.n.atomic %one, %A, %one, %B : f32, memref, f32, memref ; CHECK: float sum = work_group_reduce_add(acc); ; CHECK-NEXT: if (get_sub_group_id() == 0 && get_sub_group_local_id() == 0) { ; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) B, 0x1p+0f * sum, memory_order_relaxed, memory_scope_work_group); @@ -58,7 +62,8 @@ func @sum_atomic(%A: memref, %B: memref) { } func @sum_atomic_matrix(%A: memref, %B: memref) { - sum.n.atomic 1.0, %A, 1.0, %B : f32, memref, f32, memref + %one = constant 1.0 -> f32 + sum.n.atomic %one, %A, %one, %B : f32, memref, f32, memref ; CHECK: global float* b = B + (blck + m) * 1; ; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) b, 0x1p+0f * acc, memory_order_relaxed, memory_scope_work_group); } From 3b34e418f0206395cff89e81b5d841d743ce14d7 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 11 Oct 2024 16:23:33 +0200 Subject: [PATCH 050/297] Continue on linalg lowering Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 22 ++++- src/codegen_tools.hpp | 6 +- src/node/inst_node.hpp | 4 + src/pass/lower_linalg.cpp | 172 ++++++++++++++++++++++++++++++++------ 4 files changed, 178 insertions(+), 26 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 8b1255ba..68ead93f 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -474,6 +474,25 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in }); } +void tile_loop_by_sgs_standard(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_standard const &body) { + auto ctx = compiler_context{sg_id->context(), true}; + auto index_ty = get_scalar(ctx, scalar_type::index); + auto m = bb.add(make_subgroup_local_id(ctx)); + auto m_index = bb.add(make_cast(m, index_ty)); + tile_loop_by_sgs_new( + bb, loop_trip_count, sgs, num_tiles, sg_id, + [&m_index, &body](region_builder &bb, value block, bool is_remainder, value trip_count) { + auto mm = bb.add(make_arith(arithmetic::add, block, m_index)); + if (is_remainder) { + auto cond = bb.add(make_cmp(cmp_condition::lt, m_index, trip_count)); + bb.if_condition(cond, [&](region_builder &bb) { body(bb, mm); }); + } else { + body(bb, mm); + } + }); +} + void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { @@ -552,8 +571,9 @@ auto get_atomic_store_flag(value beta) -> std::optional { } return std::nullopt; } -void blas_update(region_builder &bb, bool atomic, value alpha_ab, value beta, value C, +void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value beta, value C, array_view index_list, location const &loc) { + auto alpha_ab = mixed_precision_arithmetic(bb, arithmetic::mul, alpha, ab, loc); if (atomic) { auto flag = get_atomic_store_flag(beta); if (!flag) { diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 8122bd88..a49db00f 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -129,11 +129,15 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, // tools for tinytc lowering using sgs_loop_body_builder_new = std::function; +using sgs_loop_body_builder_standard = std::function; using uniform_loop_body_builder_new = std::function; void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, value sg_id, sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs_standard(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder_standard const &body); + void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); @@ -141,7 +145,7 @@ auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value location const &loc) -> value; auto get_atomic_store_flag(value beta) -> std::optional; -void blas_update(region_builder &bb, bool atomic, value alpha_ab, value beta, value C, +void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value beta, value C, array_view index_list, location const &loc); } // namespace tinytc diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 732ad859..8ddb7fef 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -324,9 +324,13 @@ class blas_a2_inst : public standard_inst<4, 0> { inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } + inline auto alpha() -> tinytc_value & { return op(op_alpha); } inline auto alpha() const -> tinytc_value const & { return op(op_alpha); } + inline auto A() -> tinytc_value & { return op(op_A); } inline auto A() const -> tinytc_value const & { return op(op_A); } + inline auto beta() -> tinytc_value & { return op(op_beta); } inline auto beta() const -> tinytc_value const & { return op(op_beta); } + inline auto B() -> tinytc_value & { return op(op_B); } inline auto B() const -> tinytc_value const & { return op(op_B); } protected: diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index a1768e01..2b171ce3 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -19,6 +19,7 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include #include @@ -30,7 +31,10 @@ class linalg_generator { linalg_generator(local_tiling tiling, core_config core_cfg) : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} auto operator()(inst_node &) -> inst { return inst{}; } - auto operator()(ger_inst &g) -> inst; + auto operator()(axpby_inst &in) -> inst; + auto operator()(ger_inst &in) -> inst; + auto operator()(hadamard_inst &in) -> inst; + auto operator()(sum_inst &in) -> inst; private: auto get_memref_type(value_node const &v) const -> const memref_data_type *; @@ -47,47 +51,167 @@ auto linalg_generator::get_memref_type(value_node const &v) const -> const memre return t; } -auto linalg_generator::operator()(ger_inst &g) -> inst { - auto parallel = make_parallel(g.loc()); +auto linalg_generator::operator()(axpby_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; - auto ctx = compiler_context{g.alpha().context(), true}; + auto ctx = compiler_context{in.alpha().context(), true}; + auto index_ty = get_scalar(ctx, scalar_type::index); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + + auto bt = get_memref_type(in.B()); + + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + + auto const inner_loop = [&](region_builder &bb, value Ab, value Bb, value trip_count, + int num_tiles, value sgid) { + tile_loop_by_sgs_standard(bb, trip_count, core_cfg_.subgroup_size, num_tiles, sgid, + [&](region_builder &bb, value mm) { + auto a = bb.add(make_load(Ab, {mm}, in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), Bb, + {mm}, in.loc()); + }); + }; + + if (bt->dim() == 0) { + auto m = bb.add(make_subgroup_local_id(ctx, in.loc())); + auto c0 = bb.add(make_constant(0, i32_ty)); + auto cond0 = bb.add(make_cmp(cmp_condition::eq, sgid, c0)); + auto cond1 = bb.add(make_cmp(cmp_condition::eq, m, c0)); + auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1)); + bb.if_condition(cond, [&](region_builder &bb) { + auto a = bb.add(make_load(&in.A(), {}, in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); + }); + } else if (bt->dim() == 1) { + auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); + inner_loop(bb, &in.A(), &in.B(), c_shape0, tiling_.m_tiles() * tiling_.n_tiles(), sgid); + } else if (bt->dim() == 2) { + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); + + auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); + auto c_shape1 = bb.add(make_size(&in.B(), 1, in.loc())); + tile_loop_uniformly_new( + bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value block, value trip_count) { + auto zero = bb.add(make_constant(0, index_ty)); + bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { + auto nn = bb.add(make_arith(arithmetic::add, block, n, in.loc())); + auto static_offset_list = std::array{dynamic, 0}; + auto static_size_list = std::array{dynamic, 0}; + auto Bb = bb.add(make_subview(&in.B(), static_offset_list, static_size_list, + {nn}, {c_shape0}, in.loc())); + if (in.tA() == transpose::T) { + std::swap(static_offset_list[0], static_offset_list[1]); + std::swap(static_size_list[0], static_size_list[1]); + } + auto Ab = bb.add(make_subview(&in.A(), static_offset_list, static_size_list, + {nn}, {c_shape0}, in.loc())); + inner_loop(bb, Ab, Bb, c_shape0, tiling_.m_tiles(), sg_m); + }); + }); + } + + return parallel; +} + +auto linalg_generator::operator()(ger_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto ctx = compiler_context{in.alpha().context(), true}; auto i32_ty = get_scalar(ctx, scalar_type::i32); auto index_ty = get_scalar(ctx, scalar_type::index); - auto sgid = bb.add(make_subgroup_id(ctx, g.loc())); - auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, g.loc())); - auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, g.loc())); - auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, g.loc())); - auto m = bb.add(make_subgroup_local_id(ctx, g.loc())); - auto m_index = bb.add(make_cast(m, index_ty, g.loc())); + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); - auto c_shape0 = bb.add(make_size(&g.C(), 0, g.loc())); - auto c_shape1 = bb.add(make_size(&g.C(), 1, g.loc())); + auto c_shape0 = bb.add(make_size(&in.C(), 0, in.loc())); + auto c_shape1 = bb.add(make_size(&in.C(), 1, in.loc())); tile_loop_uniformly_new( bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, [&](region_builder &bb, value block, value trip_count) { auto zero = bb.add(make_constant(0, index_ty)); bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { - auto nn = bb.add(make_arith(arithmetic::add, block, n, g.loc())); - auto b = bb.add(make_load(&g.B(), {nn}, g.loc())); - tile_loop_by_sgs_new( - bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, - [&](region_builder &bb, value block, bool, value) { - auto mm = bb.add(make_arith(arithmetic::add, block, m_index, g.loc())); - auto a = bb.add(make_load(&g.A(), {mm}, g.loc())); - auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, g.loc()); - auto alpha_ab = mixed_precision_arithmetic(bb, arithmetic::mul, &g.alpha(), - ab, g.loc()); - blas_update(bb, g.atomic(), alpha_ab, &g.beta(), &g.C(), {mm, nn}, g.loc()); - }); + auto nn = bb.add(make_arith(arithmetic::add, block, n, in.loc())); + auto b = bb.add(make_load(&in.B(), {nn}, in.loc())); + tile_loop_by_sgs_standard(bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), + sg_m, [&](region_builder &bb, value mm) { + auto a = bb.add(make_load(&in.A(), {mm}, in.loc())); + auto ab = mixed_precision_arithmetic( + bb, arithmetic::mul, a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, + &in.beta(), &in.C(), {mm, nn}, in.loc()); + }); }); }); return parallel; } +auto linalg_generator::operator()(hadamard_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto ctx = compiler_context{in.alpha().context(), true}; + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + + auto c_shape0 = bb.add(make_size(&in.C(), 0, in.loc())); + tile_loop_by_sgs_standard( + bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, + [&](region_builder &bb, value mm) { + auto a = bb.add(make_load(&in.A(), {mm}, in.loc())); + auto b = bb.add(make_load(&in.B(), {mm}, in.loc())); + auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), {mm}, in.loc()); + }); + + return parallel; +} + +auto linalg_generator::operator()(sum_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto ctx = compiler_context{in.alpha().context(), true}; + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto bt = get_memref_type(in.B()); + + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + + if (bt->dim() == 0) { + // @todo + } else if (bt->dim() == 1) { + auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); + auto c_trip_count = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + tile_loop_by_sgs_standard( + bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, + [&](region_builder &bb, value mm) { + auto zero = bb.add(make_constant(0, index_ty)); + // @todo need for loop that yields values + bb.for_loop(zero, c_trip_count, index_ty, [&](region_builder &bb, value n) { + auto index_list = std::array{mm, n}; + if (in.tA() == transpose::T) { + std::swap(index_list[0], index_list[1]); + } + auto a = bb.add(make_load(&in.A(), index_list, in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {mm}, + in.loc()); + }); + }); + } + return parallel; +} + lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { if (info_ == nullptr) { throw std::invalid_argument("info must not be nullptr"); From 9c00a242c1340895e176c1e21f17fa0941e9f265 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 11:40:28 +0200 Subject: [PATCH 051/297] Fix L0 tests Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 14 +++++++ docs/api/builder_capi.yaml | 2 + docs/api/builder_cxxapi.rst | 14 +++++++ docs/api/builder_cxxapi.yaml | 2 + docs/api/core_capi.rst | 14 +++++++ docs/api/core_capi.yaml | 2 + docs/api/ze/cxxapi.rst | 32 +++++++-------- docs/api/ze/cxxapi.yaml | 8 ++-- examples/simple_ze/main.c | 37 +++++------------ include/tinytc/tinytc.h | 52 ++++++++++++++++++++++++ include/tinytc/tinytc.hpp | 63 ++++++++++++++++++++++++++--- include/tinytc/tinytc_ze.h | 35 ++++++---------- include/tinytc/tinytc_ze.hpp | 40 +++++++----------- src/binary.cpp | 25 ++++++++---- src/binary.hpp | 9 ++++- src/compiler.cpp | 14 +++---- src/inst.cpp | 67 +++++++++++++++++++++++++++++++ src/node/program_node.hpp | 3 +- src/prog.cpp | 3 +- src/recipe/small_gemm_batched.cpp | 13 +++--- src/recipe/tall_and_skinny.cpp | 12 +++--- src/source.cpp | 15 +++++++ src/source.hpp | 12 +++--- src/ze/error.hpp | 5 ++- src/ze/kernel.cpp | 56 +++++++++++++------------- src/ze/recipe_handler.cpp | 10 ++--- src/ze/recipe_handler.hpp | 3 +- 27 files changed, 384 insertions(+), 178 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 55d62075..ee78cef0 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -352,6 +352,10 @@ Instruction * :ref:`tinytc_constant_inst_create_int` + * :ref:`tinytc_constant_inst_create_one` + + * :ref:`tinytc_constant_inst_create_zero` + * :ref:`tinytc_expand_inst_create` * :ref:`tinytc_for_inst_create` @@ -450,6 +454,16 @@ tinytc_constant_inst_create_int .. doxygenfunction:: tinytc_constant_inst_create_int +tinytc_constant_inst_create_one +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_one + +tinytc_constant_inst_create_zero +................................ + +.. doxygenfunction:: tinytc_constant_inst_create_zero + tinytc_expand_inst_create ......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index fd76216f..1e25b4e9 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -62,6 +62,8 @@ Builder C-API: - tinytc_constant_inst_create_complex - tinytc_constant_inst_create_float - tinytc_constant_inst_create_int + - tinytc_constant_inst_create_one + - tinytc_constant_inst_create_zero - tinytc_expand_inst_create - tinytc_for_inst_create - tinytc_foreach_inst_create diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index e53a21e0..ba128a65 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -290,6 +290,10 @@ Instruction * :ref:`make_constant(std::int64_t,data_type,location const&)` + * :ref:`make_constant_one` + + * :ref:`make_constant_zero` + * :ref:`make_expand` * :ref:`make_for` @@ -391,6 +395,16 @@ make_constant(std::int64_t,data_type,location const&) .. doxygenfunction:: tinytc::make_constant(std::int64_t,data_type,location const&) +make_constant_one +................. + +.. doxygenfunction:: tinytc::make_constant_one + +make_constant_zero +.................. + +.. doxygenfunction:: tinytc::make_constant_zero + make_expand ........... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index c3a9f077..c06aaa38 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -55,6 +55,8 @@ Builder C++-API: - tinytc::make_constant(double,data_type,location const&) - tinytc::make_constant(std::int32_t,data_type,location const&) - tinytc::make_constant(std::int64_t,data_type,location const&) + - tinytc::make_constant_one + - tinytc::make_constant_zero - tinytc::make_expand - tinytc::make_for - tinytc::make_foreach diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 093bbbda..8579e308 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -203,6 +203,8 @@ Binary * :ref:`tinytc_binary_create` + * :ref:`tinytc_binary_get_compiler_context` + * :ref:`tinytc_binary_get_core_features` * :ref:`tinytc_binary_get_raw` @@ -219,6 +221,11 @@ tinytc_binary_create .. doxygenfunction:: tinytc_binary_create +tinytc_binary_get_compiler_context +.................................. + +.. doxygenfunction:: tinytc_binary_get_compiler_context + tinytc_binary_get_core_features ............................... @@ -607,6 +614,8 @@ Source * :ref:`tinytc_source_get_code` + * :ref:`tinytc_source_get_compiler_context` + * :ref:`tinytc_source_get_core_features` * :ref:`tinytc_source_get_location` @@ -625,6 +634,11 @@ tinytc_source_get_code .. doxygenfunction:: tinytc_source_get_code +tinytc_source_get_compiler_context +.................................. + +.. doxygenfunction:: tinytc_source_get_compiler_context + tinytc_source_get_core_features ............................... diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 4a4d2594..c57673bc 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -33,6 +33,7 @@ Core C-API: Binary: function: - tinytc_binary_create + - tinytc_binary_get_compiler_context - tinytc_binary_get_core_features - tinytc_binary_get_raw - tinytc_binary_release @@ -97,6 +98,7 @@ Core C-API: Source: function: - tinytc_source_get_code + - tinytc_source_get_compiler_context - tinytc_source_get_core_features - tinytc_source_get_location - tinytc_source_get_extensions diff --git a/docs/api/ze/cxxapi.rst b/docs/api/ze/cxxapi.rst index d11b47b5..eedae3b3 100644 --- a/docs/api/ze/cxxapi.rst +++ b/docs/api/ze/cxxapi.rst @@ -55,11 +55,11 @@ Kernel * :ref:`make_kernel(ze_module_handle_t,char const \\*)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context)` + * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context)` + * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&)` Kernel Functions ---------------- @@ -84,27 +84,27 @@ make_kernel(ze_module_handle_t,char const \*) .. doxygenfunction:: tinytc::make_kernel(ze_module_handle_t,char const *) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) -....................................................................................... +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) +........................................................................ -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) -.......................................................................................................... +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) +........................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) -....................................................................................... +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) +........................................................................ -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) Recipe ====== * Functions - * :ref:`make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context)` + * :ref:`make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&)` * Classes @@ -113,10 +113,10 @@ Recipe Recipe Functions ---------------- -make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) -........................................................................................ +make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) +......................................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) Recipe Classes -------------- diff --git a/docs/api/ze/cxxapi.yaml b/docs/api/ze/cxxapi.yaml index 4308b3ed..ad480350 100644 --- a/docs/api/ze/cxxapi.yaml +++ b/docs/api/ze/cxxapi.yaml @@ -14,11 +14,11 @@ C++-API: - tinytc::get_group_count - tinytc::get_group_size(ze_kernel_handle_t) - tinytc::make_kernel(ze_module_handle_t,char const *) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) Recipe: function: - - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) + - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) class: - tinytc::level_zero_recipe_handler diff --git a/examples/simple_ze/main.c b/examples/simple_ze/main.c index d4a6a713..0ece8057 100644 --- a/examples/simple_ze/main.c +++ b/examples/simple_ze/main.c @@ -41,7 +41,6 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, ze_command_list_handle_t list) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; void *A = NULL, *B = NULL, *C = NULL; @@ -50,11 +49,10 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, CHECK(tinytc_ze_core_info_create(&info, device)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, tinytc_transpose_N, tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + M, M * K, K, K * N, M, M * N, NULL)); + CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -110,14 +108,6 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); return status; @@ -129,7 +119,6 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de int32_t *host = NULL; void *A = NULL, *B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; ze_module_handle_t module = NULL; ze_kernel_handle_t kernel = NULL; @@ -157,16 +146,16 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 -> i32\n" + " %c1 = constant 1 -> i32\n" + " axpby.n %c1, %a, %c0, %b\n" " : i32, memref, i32, memref\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u)); CHECK(tinytc_ze_kernel_create(&kernel, module, "copy")); ZE_CHECK(zeKernelSetArgumentValue(kernel, 0, sizeof(A), &A)); @@ -200,14 +189,6 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de zeModuleDestroy(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { zeMemFree(context, B); diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 469e497b..05033691 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -271,6 +271,32 @@ TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *ins tinytc_data_type_t ty, const tinytc_location_t *loc); +/** + * @brief Creates the multiplicative identity constant (i.e. "1") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Creates the additive identity constant (i.e. "0") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + /** * @brief Create alloca instruction * @@ -1320,6 +1346,18 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src TINYTC_EXPORT tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, char const **code); +/** + * @brief Get context object from source object + * + * @param src [in] source object + * @param ctx [out] pointer to context object; reference count is increased so the user needs to + * call tinytc_compiler_context_release to clean up + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_source_get_compiler_context(const_tinytc_source_t src, + tinytc_compiler_context_t *ctx); + /** * @brief Get source location * @@ -1360,6 +1398,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t * @brief Create binary * * @param bin [out] pointer to binary object + * @param ctx [in] compiler context * @param format [in] Bundle format (SPIR-V or Native) * @param data_size [in] Size of data in bytes * @param data [in][range(0, data_size)] Binary data; data is copied @@ -1369,10 +1408,23 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, + tinytc_compiler_context_t ctx, tinytc_bundle_format_t format, size_t data_size, uint8_t const *data, tinytc_core_feature_flags_t core_features); +/** + * @brief Get context object from binary object + * + * @param bin [in] binary object + * @param ctx [out] pointer to context object; reference count is increased so the user needs to + * call tinytc_compiler_context_release to clean up + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx); + /** * @brief Get raw binary data * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index d14ed73c..8c8b4bac 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -920,6 +920,34 @@ inline inst make_constant(std::int64_t value, data_type ty, location const &loc return inst(instr); } +/** + * @brief Make multiplicative identity constant ("1") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant_one(data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_one(&instr, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make additive identity constant ("0") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant_zero(data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_zero(&instr, ty, &loc), loc); + return inst(instr); +} + /** * @brief Make alloca instruction * @@ -1252,7 +1280,7 @@ inline inst make_subview(value a, array_view static_offset_list, if (offset_len > std::numeric_limits::max()) { throw std::out_of_range("dynamic offset list too long"); } - auto size_len = offset_list.size(); + auto size_len = size_list.size(); if (size_len > std::numeric_limits::max()) { throw std::out_of_range("dynamic size list too long"); } @@ -1894,6 +1922,17 @@ class source : public shared_handle { return std::string_view(code, length); } + /** + * @brief Get compiler context + * + * @return Compiler context + */ + inline auto get_compiler_context() const -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_source_get_compiler_context(obj_, &ctx)); + return compiler_context{ctx, true}; + } + /** * @brief Get location * @@ -1945,19 +1984,29 @@ class binary : public shared_handle { * * @return Raw data */ - inline auto get_raw() -> raw { + inline auto get_raw() const -> raw { raw r; tinytc_bundle_format_t f; CHECK_STATUS(tinytc_binary_get_raw(obj_, &f, &r.data_size, &r.data)); r.format = bundle_format{std::underlying_type_t(f)}; return r; } + /** + * @brief Get compiler context + * + * @return Compiler context + */ + inline auto get_compiler_context() const -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_binary_get_compiler_context(obj_, &ctx)); + return compiler_context{ctx, true}; + } /** * @brief Get core features * * @return Core features */ - inline auto get_core_features() -> tinytc_core_feature_flags_t { + inline auto get_core_features() const -> tinytc_core_feature_flags_t { tinytc_core_feature_flags_t cf; CHECK_STATUS(tinytc_binary_get_core_features(obj_, &cf)); return cf; @@ -1967,6 +2016,7 @@ class binary : public shared_handle { /** * @brief Make binary * + * @param ctx Compiler context * @param format Bundle format (SPIR-V or Native) * @param data_size Size of data in bytes * @param data Binary data; data is copied @@ -1975,11 +2025,12 @@ class binary : public shared_handle { * * @return Binary */ -inline auto make_binary(bundle_format format, std::size_t data_size, std::uint8_t const *data, +inline auto make_binary(compiler_context const &ctx, bundle_format format, std::size_t data_size, + std::uint8_t const *data, tinytc_core_feature_flags_t core_features) -> binary { tinytc_binary_t bin; - CHECK_STATUS(tinytc_binary_create(&bin, static_cast(format), data_size, - data, core_features)); + CHECK_STATUS(tinytc_binary_create(&bin, ctx.get(), static_cast(format), + data_size, data, core_features)); return binary{bin}; } diff --git a/include/tinytc/tinytc_ze.h b/include/tinytc/tinytc_ze.h index d5b4d45a..2539076a 100644 --- a/include/tinytc/tinytc_ze.h +++ b/include/tinytc/tinytc_ze.h @@ -64,14 +64,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_core_info_create(tinytc_core_info_t *inf * @param src [in] source text * @param ip_version [in] IP version (pass tinytc_intel_gpu_architecture_t here) * @param format [in] binary format (SPIR-V or native) - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary( - tinytc_binary_t *bin, const_tinytc_source_t src, uint32_t ip_version, - tinytc_bundle_format_t format, tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, + const_tinytc_source_t src, + uint32_t ip_version, + tinytc_bundle_format_t format); /** * @brief Compile OpenCL-C source to device binary @@ -80,14 +79,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary( * @param context [in] context handle * @param device [in] device handle * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( - ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_source_t src, tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t +tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *bundle, ze_context_handle_t context, + ze_device_handle_t device, const_tinytc_source_t src); /** * @brief Compile tensor program @@ -98,15 +95,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - tinytc_prog_t prg, tinytc_core_feature_flags_t core_features, - tinytc_source_context_t source_ctx); + tinytc_prog_t prg, tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -115,14 +109,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_binary( - ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_binary_t bin, tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t +tinytc_ze_kernel_bundle_create_with_binary(ze_module_handle_t *bundle, ze_context_handle_t context, + ze_device_handle_t device, const_tinytc_binary_t bin); /** * @brief Create a kernel and set group size @@ -169,16 +161,13 @@ TINYTC_EXPORT ze_group_count_t tinytc_ze_get_group_count(int64_t howmany); * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_recipe_handler_create(tinytc_recipe_handler_t *handler, ze_context_handle_t context, ze_device_handle_t device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_ze.hpp b/include/tinytc/tinytc_ze.hpp index 57a2eac2..a3bc4b6f 100644 --- a/include/tinytc/tinytc_ze.hpp +++ b/include/tinytc/tinytc_ze.hpp @@ -64,15 +64,14 @@ inline auto make_core_info(ze_device_handle_t device) -> core_info { * @param src Source object * @param ip_version IP version (pass tinytc_intel_gpu_architecture_t here) * @param format Bundle format (SPIR-V or Native) - * @param ctx Source context for improved error reporting * * @return Binary */ -inline auto compile_to_binary(source const &src, std::uint32_t ip_version, bundle_format format, - source_context ctx = {}) -> binary { +inline auto compile_to_binary(source const &src, std::uint32_t ip_version, + bundle_format format) -> binary { tinytc_binary_t bin; - CHECK_STATUS(tinytc_ze_source_compile_to_binary( - &bin, src.get(), ip_version, static_cast(format), ctx.get())); + CHECK_STATUS(tinytc_ze_source_compile_to_binary(&bin, src.get(), ip_version, + static_cast(format))); return binary{bin}; } @@ -91,16 +90,13 @@ template <> struct unique_handle_traits { * @param context Context * @param device Device * @param src Source - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - source const &src, source_context source_ctx = {}) - -> unique_handle { + source const &src) -> unique_handle { ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_source(&obj, context, device, src.get())); return unique_handle{obj}; } @@ -112,17 +108,15 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) + tinytc_core_feature_flags_t core_features = 0) -> unique_handle { ze_module_handle_t obj; CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + core_features)); return unique_handle{obj}; } @@ -132,16 +126,13 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - binary const &bin, source_context source_ctx = {}) - -> unique_handle { + binary const &bin) -> unique_handle { ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin.get())); return unique_handle{obj}; } @@ -153,8 +144,8 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * * @return Level Zero kernel (unique handle) */ -inline auto make_kernel(ze_module_handle_t mod, char const *name) - -> unique_handle { +inline auto make_kernel(ze_module_handle_t mod, + char const *name) -> unique_handle { ze_kernel_handle_t obj; CHECK_STATUS(tinytc_ze_kernel_create(&obj, mod, name)); return unique_handle{obj}; @@ -218,16 +209,13 @@ class level_zero_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return Level Zero recipe handler */ inline auto make_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, - recipe const &rec, source_context source_ctx = {}) - -> level_zero_recipe_handler { + recipe const &rec) -> level_zero_recipe_handler { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_ze_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); + CHECK_STATUS(tinytc_ze_recipe_handler_create(&handler, context, device, rec.get())); return level_zero_recipe_handler{handler}; } diff --git a/src/binary.cpp b/src/binary.cpp index df09ca92..dea1ba8b 100644 --- a/src/binary.cpp +++ b/src/binary.cpp @@ -13,20 +13,23 @@ using namespace tinytc; -tinytc_binary::tinytc_binary(std::vector data, bundle_format format, - tinytc_core_feature_flags_t core_features) - : data_(std::move(data)), format_(format), core_features_(core_features) {} +tinytc_binary::tinytc_binary(compiler_context ctx, std::vector data, + bundle_format format, tinytc_core_feature_flags_t core_features) + : ctx_(std::move(ctx)), data_(std::move(data)), format_(format), core_features_(core_features) { +} extern "C" { -tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_bundle_format_t format, - size_t data_size, uint8_t const *data, +tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_compiler_context_t ctx, + tinytc_bundle_format_t format, size_t data_size, + uint8_t const *data, tinytc_core_feature_flags_t core_features) { - if (bin == nullptr || data == nullptr) { + if (bin == nullptr || ctx == nullptr || data == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *bin = std::make_unique(std::vector(data, data + data_size), + *bin = std::make_unique(compiler_context{ctx, true}, + std::vector(data, data + data_size), enum_cast(format), core_features) .release(); }); @@ -43,6 +46,14 @@ tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, tinytc_bundle_f return tinytc_status_success; } +tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx) { + if (bin == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = bin->share_context().release(); }); +} + tinytc_status_t tinytc_binary_get_core_features(const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features) { if (bin == nullptr || core_features == nullptr) { diff --git a/src/binary.hpp b/src/binary.hpp index 3607c0db..2b1950f3 100644 --- a/src/binary.hpp +++ b/src/binary.hpp @@ -5,6 +5,7 @@ #define BINARY_20240308_HPP #include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -20,14 +21,17 @@ struct tinytc_binary : tinytc::reference_counted { /** * @brief Create binary * + * @param ctx Compiler context * @param data Binary data * @param format Binary format (SPIR-V or native device binary) * @param metadata_map Dictionary kernel name -> kernel metadata * @param core_features Required core features */ - tinytc_binary(std::vector data, tinytc::bundle_format format, - tinytc_core_feature_flags_t core_features); + tinytc_binary(tinytc::compiler_context ctx, std::vector data, + tinytc::bundle_format format, tinytc_core_feature_flags_t core_features); + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } //! Get raw data inline auto data() const noexcept -> std::uint8_t const * { return data_.data(); } //! Get size of raw data @@ -40,6 +44,7 @@ struct tinytc_binary : tinytc::reference_counted { } private: + tinytc::compiler_context ctx_; std::vector data_; tinytc::bundle_format format_; tinytc_core_feature_flags_t core_features_; diff --git a/src/compiler.cpp b/src/compiler.cpp index 2197a615..1c8ab4ef 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -57,7 +57,7 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr #define FUNCTION_PASS(NAME, CREATE_PASS, ...) \ if (strcmp(NAME, pass_name) == 0) { \ auto pass = CREATE_PASS; \ - optflag_setter{pass, prg->get_context()}(__VA_ARGS__); \ + optflag_setter{pass, prg->context()}(__VA_ARGS__); \ return run_function_pass(std::move(pass), *prg); \ } #define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) \ @@ -69,7 +69,7 @@ tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t pr #undef FUNCTION_PASS_WITH_INFO throw status::unknown_pass_name; }, - prg->get_context()); + prg->context()); } tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *const **names) { @@ -96,12 +96,12 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ } return exception_to_status_code( [&] { - const auto ctx = prg->get_context(); + auto ctx = prg->share_context(); const auto opt_level = ctx->opt_level(); // passes auto cpp = constant_propagation_pass{}; - optflag_setter{cpp, ctx}(tinytc::optflag::unsafe_fp_math); + optflag_setter{cpp, ctx.get()}(tinytc::optflag::unsafe_fp_math); run_function_pass(check_ir_pass{}, *prg); @@ -135,10 +135,10 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ clir::generate_opencl(oss, std::move(ast)); - *src = std::make_unique<::tinytc_source>(oss.str(), prg->loc(), std::move(ext), - info->core_features()) + *src = std::make_unique<::tinytc_source>(std::move(ctx), oss.str(), prg->loc(), + std::move(ext), info->core_features()) .release(); }, - prg->get_context()); + prg->context()); } } diff --git a/src/inst.cpp b/src/inst.cpp index 53a7904e..9ce8cc53 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -8,6 +8,7 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" +#include "support/casting.hpp" #include "support/util.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" @@ -198,6 +199,72 @@ tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t va [&] { *instr = std::make_unique(value, ty, get_optional(loc)).release(); }); } +tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + const auto *st = dyn_cast(ty); + if (st == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + switch (st->ty()) { + case scalar_type::i1: + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + *instr = + std::make_unique(std::int64_t{1}, ty, get_optional(loc)).release(); + break; + case scalar_type::f32: + case scalar_type::f64: + *instr = std::make_unique(double{1}, ty, get_optional(loc)).release(); + break; + case scalar_type::c32: + case scalar_type::c64: + *instr = std::make_unique(std::complex{1}, ty, get_optional(loc)) + .release(); + break; + } + }); +} + +tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + const auto *st = dyn_cast(ty); + if (st == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + switch (st->ty()) { + case scalar_type::i1: + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + *instr = + std::make_unique(std::int64_t{0}, ty, get_optional(loc)).release(); + break; + case scalar_type::f32: + case scalar_type::f64: + *instr = std::make_unique(double{0}, ty, get_optional(loc)).release(); + break; + case scalar_type::c32: + case scalar_type::c64: + *instr = std::make_unique(std::complex{0}, ty, get_optional(loc)) + .release(); + break; + } + }); +} + tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 7015d5c0..c283691e 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -21,7 +21,8 @@ struct tinytc_prog final : tinytc::reference_counted { tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc = {}); - inline auto get_context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } inline auto loc() const noexcept -> tinytc_location const & { return loc_; } inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } diff --git a/src/prog.cpp b/src/prog.cpp index 6fd170e3..e8a7ddee 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -74,8 +74,7 @@ tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, if (prg == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - return tinytc::exception_to_status_code( - [&] { *ctx = tinytc::compiler_context{prg->get_context(), true}.release(); }); + return exception_to_status_code([&] { *ctx = prg->share_context().release(); }); } tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 8c969ee1..7931fda6 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -18,7 +18,6 @@ #include #include #include -#include namespace tinytc { @@ -106,17 +105,17 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto const static_offsets = std::vector{0, 0, dynamic}; - auto const A_static_sizes = std::vector{M, K, 0}; - auto const B_static_sizes = std::vector{K, N, 0}; - auto const C_static_sizes = std::vector{M, N, 0}; + auto const static_offsets = std::array{0, 0, dynamic}; + auto const A_static_sizes = std::array{M, K, 0}; + auto const B_static_sizes = std::array{K, N, 0}; + auto const C_static_sizes = std::array{M, N, 0}; auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, array_view{gid}, {}, my_loc())); auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, array_view{gid}, {}, my_loc())); - auto c = bb.add(make_subview(params[3], static_offsets, C_static_sizes, + auto c = bb.add(make_subview(params[4], static_offsets, C_static_sizes, array_view{gid}, {}, my_loc())); - auto beta = is_beta_nonzero ? params[4] : bb.add(make_constant(0.0, ty_, my_loc())); + auto beta = is_beta_nonzero ? params[3] : bb.add(make_constant_zero(ty_, my_loc())); bb.add(make_gemm(tA_, tB_, false, params[0], std::move(a), std::move(b), beta, std::move(c), my_loc())); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index e394bd79..13a9aa27 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -109,14 +109,14 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); - auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant(0.0, ty_, my_loc())); + auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant_zero(ty_, my_loc())); - auto const static_offsets = std::vector{dynamic, 0}; + auto const static_offsets = std::array{dynamic, 0}; auto const offsets = array_view{m}; auto const static_gemm = [&](region_builder &bb) { - auto const A_static_sizes = std::vector{M_block_size, K}; - auto const C_static_sizes = std::vector{M_block_size, N}; + auto const A_static_sizes = std::array{M_block_size, K}; + auto const C_static_sizes = std::array{M_block_size, N}; auto a = bb.add( make_subview(A, static_offsets, A_static_sizes, offsets, {}, my_loc())); auto c = bb.add( @@ -125,8 +125,8 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( my_loc())); }; auto const dynamic_gemm = [&](region_builder &bb, value dyn_block_size) { - auto const A_static_sizes = std::vector{dynamic, K}; - auto const C_static_sizes = std::vector{dynamic, N}; + auto const A_static_sizes = std::array{dynamic, K}; + auto const C_static_sizes = std::array{dynamic, N}; auto const sizes = array_view{dyn_block_size}; auto a = bb.add( make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); diff --git a/src/source.cpp b/src/source.cpp index 18038a0d..f4f9fca6 100644 --- a/src/source.cpp +++ b/src/source.cpp @@ -9,6 +9,13 @@ using namespace tinytc; +tinytc_source::tinytc_source(compiler_context ctx, std::string code, + tinytc_location const &code_loc, + std::vector required_extensions, + tinytc_core_feature_flags_t core_features) + : ctx_{std::move(ctx)}, code_(std::move(code)), code_loc_(code_loc), + required_extensions_(std::move(required_extensions)), core_features_(core_features) {} + extern "C" { tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, char const **code) { @@ -28,6 +35,14 @@ tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, tinytc_loc return exception_to_status_code([&] { *loc = src->code_loc(); }); } +tinytc_status_t tinytc_source_get_compiler_context(const_tinytc_source_t src, + tinytc_compiler_context_t *ctx) { + if (src == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = src->share_context().release(); }); +} + tinytc_status_t tinytc_source_get_core_features(const_tinytc_source_t src, tinytc_core_feature_flags_t *core_features) { if (src == nullptr || core_features == nullptr) { diff --git a/src/source.hpp b/src/source.hpp index d07017e7..ade28c25 100644 --- a/src/source.hpp +++ b/src/source.hpp @@ -5,6 +5,7 @@ #define SOURCE_20240412_HPP #include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include @@ -14,14 +15,14 @@ struct tinytc_source : tinytc::reference_counted { public: - inline tinytc_source(std::string code, tinytc_location const &code_loc, - std::vector required_extensions, - tinytc_core_feature_flags_t core_features) - : code_(std::move(code)), code_loc_(code_loc), - required_extensions_(std::move(required_extensions)), core_features_(core_features) {} + tinytc_source(tinytc::compiler_context ctx, std::string code, tinytc_location const &code_loc, + std::vector required_extensions, + tinytc_core_feature_flags_t core_features); inline auto code() const -> char const * { return code_.c_str(); } inline auto code_loc() const -> tinytc_location const & { return code_loc_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } inline auto size() const -> std::size_t { return code_.size(); } inline auto const &required_extensions() const { return required_extensions_; } inline auto core_features() const noexcept -> tinytc_core_feature_flags_t { @@ -29,6 +30,7 @@ struct tinytc_source : tinytc::reference_counted { } private: + tinytc::compiler_context ctx_; std::string code_; tinytc_location code_loc_; std::vector required_extensions_; diff --git a/src/ze/error.hpp b/src/ze/error.hpp index f6f581c6..62c8cbf7 100644 --- a/src/ze/error.hpp +++ b/src/ze/error.hpp @@ -5,6 +5,7 @@ #define ZE_ERROR_20240419_HPP #include "opencl_cc.hpp" +#include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include @@ -13,7 +14,7 @@ namespace tinytc { template auto exception_to_status_code_ze(F &&f, - tinytc_source_context_t context = nullptr) -> tinytc_status_t { + tinytc_compiler_context_t context = nullptr) -> tinytc_status_t { try { f(); } catch (status const &st) { @@ -23,7 +24,7 @@ auto exception_to_status_code_ze(F &&f, } catch (opencl_c_compilation_error const &e) { if (context) { auto const loc = location{}; - tinytc_source_context_report_error(context, &loc, e.what(), true); + tinytc_compiler_context_report_error(context, &loc, e.what()); } return tinytc_status_compilation_error; } catch (std::bad_alloc const &e) { diff --git a/src/ze/kernel.cpp b/src/ze/kernel.cpp index 23bd099c..55063eb1 100644 --- a/src/ze/kernel.cpp +++ b/src/ze/kernel.cpp @@ -24,8 +24,7 @@ extern "C" { tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, const_tinytc_source_t src, uint32_t ip_version, - tinytc_bundle_format_t format, - tinytc_source_context_t source_ctx) { + tinytc_bundle_format_t format) { if (bin == nullptr || src == nullptr) { return tinytc_status_invalid_arguments; @@ -33,12 +32,16 @@ tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, const_t size_t code_size = 0; char const *code = nullptr; + tinytc_compiler_context_t ctx = nullptr; tinytc_core_feature_flags_t core_features = 0; std::uint32_t extensions_size = 0; char const *const *extensions = nullptr; + TINYTC_CHECK_STATUS(tinytc_source_get_code(src, &code_size, &code)); TINYTC_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); TINYTC_CHECK_STATUS(tinytc_source_get_extensions(src, &extensions_size, &extensions)); + TINYTC_CHECK_STATUS(tinytc_source_get_compiler_context(src, &ctx)); + auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope return exception_to_status_code_ze( [&] { @@ -51,17 +54,16 @@ tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, const_t auto bin_data = compile_opencl_c(code_size, code, fmt, ip_version, compiler_options.size(), compiler_options.data(), extensions_size, extensions); - CHECK_STATUS( - tinytc_binary_create(bin, format, bin_data.size(), bin_data.data(), core_features)); + CHECK_STATUS(tinytc_binary_create(bin, ctx_.get(), format, bin_data.size(), + bin_data.data(), core_features)); }, - source_ctx); + ctx_.get()); } tinytc_status_t tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_source_t src, - tinytc_source_context_t source_ctx) { + const_tinytc_source_t src) { if (bundle == nullptr || src == nullptr) { return tinytc_status_invalid_arguments; } @@ -77,11 +79,11 @@ tinytc_status_t tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *b // Get binary tinytc_binary_t bin = nullptr; - TINYTC_CHECK_STATUS(tinytc_ze_source_compile_to_binary( - &bin, src, dev_ip_ver.ipVersion, tinytc_bundle_format_native, source_ctx)); + TINYTC_CHECK_STATUS(tinytc_ze_source_compile_to_binary(&bin, src, dev_ip_ver.ipVersion, + tinytc_bundle_format_native)); tinytc_status_t status = - tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin, source_ctx); + tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin); tinytc_binary_release(bin); return status; } @@ -89,8 +91,7 @@ tinytc_status_t tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *b tinytc_status_t tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, - tinytc_source_context_t source_ctx) { + tinytc_core_feature_flags_t core_features) { if (bundle == nullptr || prg == nullptr) { return tinytc_status_invalid_arguments; } @@ -106,12 +107,10 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info, source_ctx); - status != tinytc_status_success) { + if (status = tinytc_prog_compile_to_opencl(&src, prg, info); status != tinytc_status_success) { goto err; } - if (status = - tinytc_ze_kernel_bundle_create_with_source(bundle, context, device, src, source_ctx); + if (status = tinytc_ze_kernel_bundle_create_with_source(bundle, context, device, src); status != tinytc_status_success) { goto err; } @@ -125,8 +124,7 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte tinytc_status_t tinytc_ze_kernel_bundle_create_with_binary(ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx) { + const_tinytc_binary_t bin) { if (bin == nullptr) { return tinytc_status_invalid_arguments; } @@ -147,22 +145,24 @@ tinytc_status_t tinytc_ze_kernel_bundle_create_with_binary(ze_module_handle_t *b uint32_t core_features; TINYTC_CHECK_STATUS(tinytc_binary_get_core_features(bin, &core_features)); + tinytc_compiler_context_t ctx = nullptr; + TINYTC_CHECK_STATUS(tinytc_binary_get_compiler_context(bin, &ctx)); + auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ goes out of scope + if (core_features & static_cast(tinytc_core_feature_flag_large_register_file)) { module_desc.pBuildFlags = large_register_file_compiler_option_ze; } ze_module_build_log_handle_t build_log; ze_result_t status = zeModuleCreate(context, device, &module_desc, bundle, &build_log); if (status != ZE_RESULT_SUCCESS) { - if (source_ctx) { - std::string log; - std::size_t log_size; - zeModuleBuildLogGetString(build_log, &log_size, nullptr); - log.resize(log_size); - zeModuleBuildLogGetString(build_log, &log_size, log.data()); - - tinytc_location_t loc = {}; - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); - } + std::string log; + std::size_t log_size; + zeModuleBuildLogGetString(build_log, &log_size, nullptr); + log.resize(log_size); + zeModuleBuildLogGetString(build_log, &log_size, log.data()); + + tinytc_location_t loc = {}; + tinytc_compiler_context_report_error(ctx_.get(), &loc, log.c_str()); zeModuleBuildLogDestroy(build_log); TINYTC_ZE_CHECK_STATUS(status); } else { diff --git a/src/ze/recipe_handler.cpp b/src/ze/recipe_handler.cpp index 2d0cd641..9bddeb88 100644 --- a/src/ze/recipe_handler.cpp +++ b/src/ze/recipe_handler.cpp @@ -17,10 +17,10 @@ namespace tinytc { ze_recipe_handler::ze_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, - recipe rec, source_context source_ctx) + recipe rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source(), std::move(source_ctx)); + module_ = make_kernel_bundle(context, device, get_recipe().get_source()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); @@ -58,14 +58,12 @@ extern "C" { tinytc_status_t tinytc_ze_recipe_handler_create(tinytc_recipe_handler_t *handler, ze_context_handle_t context, - ze_device_handle_t device, tinytc_recipe_t rec, - tinytc_source_context_t source_ctx) { + ze_device_handle_t device, tinytc_recipe_t rec) { if (handler == nullptr || rec == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code_ze([&] { - *handler = std::make_unique(context, device, recipe(rec, true), - source_context(source_ctx, true)) + *handler = std::make_unique(context, device, recipe{rec, true}) .release(); }); } diff --git a/src/ze/recipe_handler.hpp b/src/ze/recipe_handler.hpp index 14bd0f48..4c2ca202 100644 --- a/src/ze/recipe_handler.hpp +++ b/src/ze/recipe_handler.hpp @@ -17,8 +17,7 @@ namespace tinytc { struct ze_recipe_handler : ::tinytc_recipe_handler { public: - ze_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, recipe rec, - source_context source_ctx); + ze_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, recipe rec); void active_kernel(int kernel_num) override; void arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) override; From 5039dc0b18de451c741c98dede6f2436e44c09d0 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 11:51:05 +0200 Subject: [PATCH 052/297] Fix CL tests Signed-off-by: Carsten Uphoff --- docs/api/cl/cxxapi.rst | 32 ++++++++++++++-------------- docs/api/cl/cxxapi.yaml | 8 +++---- examples/simple_cl/main.c | 37 ++++++++------------------------ include/tinytc/tinytc_cl.h | 27 +++++++++--------------- include/tinytc/tinytc_cl.hpp | 33 ++++++++++++----------------- src/cl/kernel.cpp | 41 +++++++++++++++++++++--------------- src/cl/recipe_handler.cpp | 11 ++++------ src/cl/recipe_handler.hpp | 3 +-- 8 files changed, 81 insertions(+), 111 deletions(-) diff --git a/docs/api/cl/cxxapi.rst b/docs/api/cl/cxxapi.rst index 2245e883..0bfa8e83 100644 --- a/docs/api/cl/cxxapi.rst +++ b/docs/api/cl/cxxapi.rst @@ -53,11 +53,11 @@ Kernel * :ref:`make_kernel(cl_program,char const\\*)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context)` + * :ref:`make_kernel_bundle(cl_context,cl_device_id,binary const&)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,source const&,source_context)` + * :ref:`make_kernel_bundle(cl_context,cl_device_id,source const&)` Kernel Functions ---------------- @@ -77,27 +77,27 @@ make_kernel(cl_program,char const\*) .. doxygenfunction:: tinytc::make_kernel(cl_program,char const*) -make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) -........................................................................ +make_kernel_bundle(cl_context,cl_device_id,binary const&) +......................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&) -make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) -........................................................................................... +make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) +............................................................................ -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) -........................................................................ +make_kernel_bundle(cl_context,cl_device_id,source const&) +......................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&) Recipe ====== * Functions - * :ref:`make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context)` + * :ref:`make_recipe_handler(cl_context,cl_device_id,recipe const&)` * Classes @@ -110,10 +110,10 @@ Recipe Recipe Functions ---------------- -make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) -......................................................................... +make_recipe_handler(cl_context,cl_device_id,recipe const&) +.......................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&) Recipe Classes -------------- diff --git a/docs/api/cl/cxxapi.yaml b/docs/api/cl/cxxapi.yaml index a61f974d..fc1e9d6e 100644 --- a/docs/api/cl/cxxapi.yaml +++ b/docs/api/cl/cxxapi.yaml @@ -13,12 +13,12 @@ C++-API: - tinytc::get_global_size(std::int64_t,std::array const &) - tinytc::get_group_size(cl_kernel) - tinytc::make_kernel(cl_program,char const*) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&) Recipe: function: - - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) + - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&) class: - tinytc::opencl_recipe_handler struct: diff --git a/examples/simple_cl/main.c b/examples/simple_cl/main.c index 0ae5c520..e6efb1d2 100644 --- a/examples/simple_cl/main.c +++ b/examples/simple_cl/main.c @@ -39,7 +39,6 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue queue) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; cl_mem A = NULL, B = NULL, C = NULL; @@ -49,11 +48,10 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q CHECK(tinytc_cl_core_info_create(&info, device)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, tinytc_transpose_N, tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + M, M * K, K, K * N, M, M * N, NULL)); + CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -114,14 +112,6 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); return status; @@ -132,7 +122,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman int32_t *host = NULL; cl_mem A = NULL, B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; cl_program module = NULL; cl_kernel kernel = NULL; @@ -159,16 +148,16 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 -> i32\n" + " %c1 = constant 1 -> i32\n" + " axpby.n %c1, %a, %c0, %b\n" " : i32, memref, i32, memref\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u)); kernel = clCreateKernel(module, "copy", &err); CL_CHECK(err); @@ -211,14 +200,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman clReleaseProgram(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { clReleaseMemObject(B); diff --git a/include/tinytc/tinytc_cl.h b/include/tinytc/tinytc_cl.h index 148958b2..191e0b40 100644 --- a/include/tinytc/tinytc_cl.h +++ b/include/tinytc/tinytc_cl.h @@ -64,14 +64,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *inf * @param context [in] context handle * @param device [in] device handle * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( - cl_program *bundle, cl_context context, cl_device_id device, const_tinytc_source_t src, - tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, + cl_context context, + cl_device_id device, + const_tinytc_source_t src); /** * @brief Compile tensor program @@ -82,14 +81,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx); + tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -98,14 +95,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary( - cl_program *bundle, cl_context context, cl_device_id device, const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, + cl_context context, + cl_device_id device, + const_tinytc_binary_t bin); /** * @brief Get work group size for kernel @@ -139,16 +135,13 @@ TINYTC_EXPORT void tinytc_cl_get_global_size(int64_t howmany, const size_t *loca * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_cl.hpp b/include/tinytc/tinytc_cl.hpp index 93227d29..2165c4fd 100644 --- a/include/tinytc/tinytc_cl.hpp +++ b/include/tinytc/tinytc_cl.hpp @@ -83,15 +83,13 @@ template <> struct shared_handle_traits { * @param context Context * @param device Device * @param src Source - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, source const &src, - source_context source_ctx = {}) -> shared_handle { +inline auto make_kernel_bundle(cl_context context, cl_device_id device, + source const &src) -> shared_handle { cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_source(&obj, context, device, src.get())); return shared_handle{obj}; } @@ -103,16 +101,15 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, source c * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) -> shared_handle { +inline auto +make_kernel_bundle(cl_context context, cl_device_id device, prog prg, + tinytc_core_feature_flags_t core_features = 0) -> shared_handle { cl_program obj; CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + core_features)); return shared_handle{obj}; } @@ -122,15 +119,13 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, binary const &bin, - source_context source_ctx = {}) -> shared_handle { +inline auto make_kernel_bundle(cl_context context, cl_device_id device, + binary const &bin) -> shared_handle { cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin.get())); return shared_handle{obj}; } @@ -242,15 +237,13 @@ class opencl_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return OpenCL recipe handler */ -inline auto make_recipe_handler(cl_context context, cl_device_id device, recipe const &rec, - source_context source_ctx = {}) -> opencl_recipe_handler { +inline auto make_recipe_handler(cl_context context, cl_device_id device, + recipe const &rec) -> opencl_recipe_handler { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_cl_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); + CHECK_STATUS(tinytc_cl_recipe_handler_create(&handler, context, device, rec.get())); return opencl_recipe_handler{handler}; } diff --git a/src/cl/kernel.cpp b/src/cl/kernel.cpp index 416d7a35..741dd780 100644 --- a/src/cl/kernel.cpp +++ b/src/cl/kernel.cpp @@ -3,6 +3,7 @@ #include "../compiler_options.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.h" #include "tinytc/types.h" @@ -16,12 +17,13 @@ #include #include +using tinytc::compiler_context; + extern "C" { tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, cl_context context, cl_device_id device, - const_tinytc_source_t src, - tinytc_source_context_t source_ctx) { + const_tinytc_source_t src) { if (bundle == nullptr || src == nullptr) { return tinytc_status_invalid_arguments; } @@ -29,8 +31,11 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, c size_t length = 0; char const *code = nullptr; tinytc_core_feature_flags_t core_features = 0; - TINYTC_CL_CHECK_STATUS(tinytc_source_get_code(src, &length, &code)); - TINYTC_CL_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); + tinytc_compiler_context_t ctx = nullptr; + TINYTC_CHECK_STATUS(tinytc_source_get_code(src, &length, &code)); + TINYTC_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); + TINYTC_CHECK_STATUS(tinytc_source_get_compiler_context(src, &ctx)); + auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope cl_int err; cl_program p = clCreateProgramWithSource(context, 1, &code, &length, &err); @@ -46,7 +51,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, c auto options_str = std::move(options).str(); if (err = clBuildProgram(p, 1, &device, options_str.c_str(), nullptr, nullptr); err != CL_SUCCESS) { - if (source_ctx) { + if (ctx_.get()) { std::string log; std::size_t log_size; clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); @@ -55,7 +60,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, c tinytc_location_t loc = {}; tinytc_source_get_location(src, &loc); - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); + tinytc_compiler_context_report_error(ctx_.get(), &loc, log.c_str()); } clReleaseProgram(p); TINYTC_CL_CHECK_STATUS(err); @@ -64,9 +69,10 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, c return tinytc_status_success; } -tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( - cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx) { +tinytc_status_t +tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context context, + cl_device_id device, tinytc_prog_t prg, + tinytc_core_feature_flags_t core_features) { if (bundle == nullptr || prg == nullptr) { return tinytc_status_invalid_arguments; } @@ -82,12 +88,10 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info, source_ctx); - status != tinytc_status_success) { + if (status = tinytc_prog_compile_to_opencl(&src, prg, info); status != tinytc_status_success) { goto err; } - if (status = - tinytc_cl_kernel_bundle_create_with_source(bundle, context, device, src, source_ctx); + if (status = tinytc_cl_kernel_bundle_create_with_source(bundle, context, device, src); status != tinytc_status_success) { goto err; } @@ -100,8 +104,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, cl_context context, cl_device_id device, - const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx) { + const_tinytc_binary_t bin) { if (bin == nullptr) { return tinytc_status_invalid_arguments; } @@ -122,12 +125,16 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c tinytc_core_feature_flags_t core_features; TINYTC_CHECK_STATUS(tinytc_binary_get_core_features(bin, &core_features)); + tinytc_compiler_context_t ctx = nullptr; + TINYTC_CHECK_STATUS(tinytc_binary_get_compiler_context(bin, &ctx)); + auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope + char const *options = ""; if (core_features & tinytc_core_feature_flag_large_register_file) { options = tinytc::large_register_file_compiler_option_cl; } if (err = clBuildProgram(p, 1, &device, options, nullptr, nullptr); err != CL_SUCCESS) { - if (source_ctx) { + if (ctx_.get()) { std::string log; std::size_t log_size; clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); @@ -135,7 +142,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); tinytc_location_t loc = {}; - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); + tinytc_compiler_context_report_error(ctx_.get(), &loc, log.c_str()); } clReleaseProgram(p); TINYTC_CL_CHECK_STATUS(err); diff --git a/src/cl/recipe_handler.cpp b/src/cl/recipe_handler.cpp index e9f562f6..59761897 100644 --- a/src/cl/recipe_handler.cpp +++ b/src/cl/recipe_handler.cpp @@ -17,11 +17,10 @@ namespace tinytc { -cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx) +cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source(), std::move(source_ctx)); + module_ = make_kernel_bundle(context, device, get_recipe().get_source()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); @@ -69,14 +68,12 @@ extern "C" { tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t rec, - tinytc_source_context_t source_ctx) { + tinytc_recipe_t rec) { if (handler == nullptr || rec == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code_cl([&] { - *handler = std::make_unique(context, device, recipe(rec, true), - source_context(source_ctx, true)) + *handler = std::make_unique(context, device, recipe(rec, true)) .release(); }); } diff --git a/src/cl/recipe_handler.hpp b/src/cl/recipe_handler.hpp index 6c4ebea7..c2b03479 100644 --- a/src/cl/recipe_handler.hpp +++ b/src/cl/recipe_handler.hpp @@ -19,8 +19,7 @@ namespace tinytc { struct cl_recipe_handler : ::tinytc_recipe_handler { public: - cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx); + cl_recipe_handler(cl_context context, cl_device_id device, recipe rec); void active_kernel(int kernel_num) override; void arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) override; From 2c5321ca0802645251c5bc4d6bb3150796d38a09 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 12:08:21 +0200 Subject: [PATCH 053/297] Fix rt tests Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc_sycl.hpp | 35 ++++++++++------------- src/recipe/small_gemm_batched.cpp | 30 ++++++++++---------- src/sycl/kernel.cpp | 46 +++++++++++++------------------ src/sycl/recipe_handler.cpp | 19 +++++-------- src/sycl/recipe_handler.hpp | 3 +- 5 files changed, 57 insertions(+), 76 deletions(-) diff --git a/include/tinytc/tinytc_sycl.hpp b/include/tinytc/tinytc_sycl.hpp index 52bc2b80..9b53af7c 100644 --- a/include/tinytc/tinytc_sycl.hpp +++ b/include/tinytc/tinytc_sycl.hpp @@ -47,13 +47,12 @@ TINYTC_EXPORT auto make_core_info(sycl::device const &dev) -> core_info; * @param ctx Context * @param dev Device * @param src Source - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - source const &src, source_context source_ctx = {}) - -> sycl::kernel_bundle; +TINYTC_EXPORT auto +make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, + source const &src) -> sycl::kernel_bundle; /** * @brief Make SYCL kernel bundle from tinytc program @@ -63,13 +62,11 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) + tinytc_core_feature_flags_t core_features = 0) -> sycl::kernel_bundle; /** @@ -78,13 +75,12 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param ctx Context * @param dev Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - binary const &bin, source_context source_ctx = {}) - -> sycl::kernel_bundle; +TINYTC_EXPORT auto +make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, + binary const &bin) -> sycl::kernel_bundle; /** * @brief Make SYCL kernel @@ -114,8 +110,8 @@ TINYTC_EXPORT auto get_group_size(sycl::kernel const &krnl) -> sycl::range<3u>; * * @return Global size */ -TINYTC_EXPORT auto get_global_size(std::int64_t howmany, sycl::range<3u> const &local_size) - -> sycl::range<3u>; +TINYTC_EXPORT auto get_global_size(std::int64_t howmany, + sycl::range<3u> const &local_size) -> sycl::range<3u>; /** * @brief Get SYCL nd_range @@ -125,8 +121,8 @@ TINYTC_EXPORT auto get_global_size(std::int64_t howmany, sycl::range<3u> const & * * @return ND range */ -TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, std::int64_t howmany) - -> sycl::nd_range<3u>; +TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, + std::int64_t howmany) -> sycl::nd_range<3u>; //////////////////////////// ////////// Recipe ////////// @@ -179,24 +175,21 @@ class TINYTC_EXPORT sycl_recipe_handler : public recipe_handler { * @param ctx Context * @param dev Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ TINYTC_EXPORT auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, - recipe const &rec, source_context source_ctx = {}) - -> sycl_recipe_handler; + recipe const &rec) -> sycl_recipe_handler; /** * @brief Make recipe handler * * @param q Queue * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ -TINYTC_EXPORT auto make_recipe_handler(sycl::queue const &q, recipe const &rec, - source_context source_ctx = {}) -> sycl_recipe_handler; +TINYTC_EXPORT auto make_recipe_handler(sycl::queue const &q, + recipe const &rec) -> sycl_recipe_handler; } // namespace tinytc diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 7931fda6..5d71bab3 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -72,13 +72,14 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( ++l.end.column; return l; }; - - auto const selA = [&](std::int64_t N1, std::int64_t N2) { - return tA == tinytc_transpose_T ? N2 : N1; - }; - auto const selB = [&](std::int64_t N1, std::int64_t N2) { - return tB == tinytc_transpose_T ? N2 : N1; + auto const make_static_sizes = [](transpose t, int64_t A, std::int64_t B) { + auto s = std::array{A, B, 0}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + return s; }; + return exception_to_status_code( [&] { auto const ty_ = get_scalar(ctx_, enum_cast(ty)); @@ -86,10 +87,15 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto const tB_ = enum_cast(tB); auto const kernel = [&](char const *name, bool is_beta_nonzero) { - auto A_ty = get_memref(ty_, {selA(M, K), selA(K, M), dynamic}, {1, ldA, strideA}, - address_space::global, my_loc()); - auto B_ty = get_memref(ty_, {selB(K, N), selB(N, K), dynamic}, {1, ldB, strideB}, - address_space::global, my_loc()); + auto const static_offsets = std::array{0, 0, dynamic}; + auto const A_static_sizes = make_static_sizes(tA_, M, K); + auto const B_static_sizes = make_static_sizes(tB_, K, N); + auto const C_static_sizes = make_static_sizes(transpose::N, M, N); + + auto A_ty = get_memref(ty_, {A_static_sizes[0], A_static_sizes[1], dynamic}, + {1, ldA, strideA}, address_space::global, my_loc()); + auto B_ty = get_memref(ty_, {B_static_sizes[0], B_static_sizes[1], dynamic}, + {1, ldB, strideB}, address_space::global, my_loc()); auto C_ty = get_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, address_space::global, my_loc()); auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, my_loc()); @@ -105,10 +111,6 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto const static_offsets = std::array{0, 0, dynamic}; - auto const A_static_sizes = std::array{M, K, 0}; - auto const B_static_sizes = std::array{K, N, 0}; - auto const C_static_sizes = std::array{M, N, 0}; auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, array_view{gid}, {}, my_loc())); auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, diff --git a/src/sycl/kernel.cpp b/src/sycl/kernel.cpp index ec2a5caa..995e341c 100644 --- a/src/sycl/kernel.cpp +++ b/src/sycl/kernel.cpp @@ -18,34 +18,28 @@ namespace tinytc { template struct kernel_bundle_dispatcher; template <> struct kernel_bundle_dispatcher { - template - auto operator()(context const &ctx, device const &dev, T const &obj, - source_context source_ctx) { + template auto operator()(context const &ctx, device const &dev, T const &obj) { auto native_context = get_native(ctx); auto native_device = get_native(dev); - auto native_mod = - make_kernel_bundle(native_context, native_device, obj, std::move(source_ctx)); + auto native_mod = make_kernel_bundle(native_context, native_device, obj); return make_kernel_bundle( {native_mod.release(), ext::oneapi::level_zero::ownership::transfer}, ctx); } auto operator()(context const &ctx, device const &dev, prog prg, - tinytc_core_feature_flags_t core_features, source_context source_ctx) { + tinytc_core_feature_flags_t core_features) { auto native_context = get_native(ctx); auto native_device = get_native(dev); - auto native_mod = make_kernel_bundle(native_context, native_device, std::move(prg), - core_features, std::move(source_ctx)); + auto native_mod = + make_kernel_bundle(native_context, native_device, std::move(prg), core_features); return make_kernel_bundle( {native_mod.release(), ext::oneapi::level_zero::ownership::transfer}, ctx); } }; template <> struct kernel_bundle_dispatcher { - template - auto operator()(context const &ctx, device const &dev, T const &obj, - source_context source_ctx) { + template auto operator()(context const &ctx, device const &dev, T const &obj) { auto native_context = get_native(ctx); auto native_device = get_native(dev); - auto native_mod = - make_kernel_bundle(native_context, native_device, obj, std::move(source_ctx)); + auto native_mod = make_kernel_bundle(native_context, native_device, obj); auto bundle = make_kernel_bundle(native_mod.get(), ctx); CL_CHECK_STATUS(clReleaseDevice(native_device)); @@ -53,11 +47,11 @@ template <> struct kernel_bundle_dispatcher { return bundle; } auto operator()(context const &ctx, device const &dev, prog prg, - tinytc_core_feature_flags_t core_features, source_context source_ctx) { + tinytc_core_feature_flags_t core_features) { auto native_context = get_native(ctx); auto native_device = get_native(dev); - auto native_mod = make_kernel_bundle(native_context, native_device, std::move(prg), - core_features, std::move(source_ctx)); + auto native_mod = + make_kernel_bundle(native_context, native_device, std::move(prg), core_features); auto bundle = make_kernel_bundle(native_mod.get(), ctx); CL_CHECK_STATUS(clReleaseDevice(native_device)); @@ -66,21 +60,19 @@ template <> struct kernel_bundle_dispatcher { } }; -auto make_kernel_bundle(context const &ctx, device const &dev, source const &src, - source_context source_ctx) -> kernel_bundle { - return dispatch(dev.get_backend(), ctx, dev, src, - std::move(source_ctx)); +auto make_kernel_bundle(context const &ctx, device const &dev, + source const &src) -> kernel_bundle { + return dispatch(dev.get_backend(), ctx, dev, src); } auto make_kernel_bundle(context const &ctx, device const &dev, prog prg, - tinytc_core_feature_flags_t core_features, - source_context source_ctx) -> kernel_bundle { + tinytc_core_feature_flags_t core_features) + -> kernel_bundle { return dispatch(dev.get_backend(), ctx, dev, std::move(prg), - core_features, std::move(source_ctx)); + core_features); } -auto make_kernel_bundle(context const &ctx, device const &dev, binary const &bin, - source_context source_ctx) -> kernel_bundle { - return dispatch(dev.get_backend(), ctx, dev, bin, - std::move(source_ctx)); +auto make_kernel_bundle(context const &ctx, device const &dev, + binary const &bin) -> kernel_bundle { + return dispatch(dev.get_backend(), ctx, dev, bin); } template struct kernel_dispatcher; diff --git a/src/sycl/recipe_handler.cpp b/src/sycl/recipe_handler.cpp index 799542e3..e005b72e 100644 --- a/src/sycl/recipe_handler.cpp +++ b/src/sycl/recipe_handler.cpp @@ -24,11 +24,9 @@ template <> struct arg_handler_dispatcher { }; sycl_recipe_handler_impl::sycl_recipe_handler_impl(sycl::context const &context, - sycl::device const &device, recipe rec, - source_context source_ctx) + sycl::device const &device, recipe rec) : ::tinytc_recipe_handler(std::move(rec)), - module_( - make_kernel_bundle(context, device, get_recipe().get_source(), std::move(source_ctx))) { + module_(make_kernel_bundle(context, device, get_recipe().get_source())) { auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); @@ -68,19 +66,16 @@ auto sycl_recipe_handler_impl::local_size() const -> sycl::range<3u> const & { return local_size_[active_kernel_]; } -auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, recipe const &rec, - source_context source_ctx) -> sycl_recipe_handler { +auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, + recipe const &rec) -> sycl_recipe_handler { tinytc_recipe_handler_t handler = - std::make_unique(ctx, dev, rec, std::move(source_ctx)).release(); + std::make_unique(ctx, dev, rec).release(); return sycl_recipe_handler{handler}; } -auto make_recipe_handler(sycl::queue const &q, recipe const &rec, - source_context source_ctx) -> sycl_recipe_handler { +auto make_recipe_handler(sycl::queue const &q, recipe const &rec) -> sycl_recipe_handler { tinytc_recipe_handler_t handler = - std::make_unique(q.get_context(), q.get_device(), rec, - std::move(source_ctx)) - .release(); + std::make_unique(q.get_context(), q.get_device(), rec).release(); return sycl_recipe_handler{handler}; } diff --git a/src/sycl/recipe_handler.hpp b/src/sycl/recipe_handler.hpp index 3a8fd9dd..ca6ea182 100644 --- a/src/sycl/recipe_handler.hpp +++ b/src/sycl/recipe_handler.hpp @@ -16,8 +16,7 @@ namespace tinytc { struct sycl_recipe_handler_impl : ::tinytc_recipe_handler { public: - sycl_recipe_handler_impl(sycl::context const &context, sycl::device const &device, recipe rec, - source_context source_ctx); + sycl_recipe_handler_impl(sycl::context const &context, sycl::device const &device, recipe rec); void active_kernel(int kernel_num) override; void arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) override; From 9a420808675bb7a3a5a3a1cbdad55137834ae533 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 15:13:17 +0200 Subject: [PATCH 054/297] Fix tinytc-bench Signed-off-by: Carsten Uphoff --- docs/api/sycl/cxxapi.rst | 40 +++---- docs/api/sycl/cxxapi.yaml | 10 +- examples/benchmark/CMakeLists.txt | 4 +- examples/benchmark/args.cpp | 97 --------------- examples/benchmark/args.hpp | 37 ------ examples/benchmark/main.cpp | 188 ++++++++++++++++++++++-------- examples/jit/main.cpp | 7 +- src/pass/dump_ir.cpp | 20 +++- tools/argparser/argparser.cpp | 7 +- tools/argparser/argparser.hpp | 35 ++++-- 10 files changed, 209 insertions(+), 236 deletions(-) delete mode 100644 examples/benchmark/args.cpp delete mode 100644 examples/benchmark/args.hpp diff --git a/docs/api/sycl/cxxapi.rst b/docs/api/sycl/cxxapi.rst index 82417e6c..2db5996f 100644 --- a/docs/api/sycl/cxxapi.rst +++ b/docs/api/sycl/cxxapi.rst @@ -40,11 +40,11 @@ Kernel * :ref:`make_kernel(sycl::kernel_bundle\ const &,char const \\*)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context)` + * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context)` + * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,source const &)` Kernel Functions ---------------- @@ -69,29 +69,29 @@ make_kernel(sycl::kernel_bundle const &,char con .. doxygenfunction:: tinytc::make_kernel(sycl::kernel_bundle const &,char const *) -make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) -............................................................................................ +make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) +............................................................................. -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) -make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) -.............................................................................................................. +make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) +............................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) -............................................................................................ +make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) +............................................................................. -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) Recipe ====== * Functions - * :ref:`make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context)` + * :ref:`make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &)` - * :ref:`make_recipe_handler(sycl::queue const&,recipe const&,source_context)` + * :ref:`make_recipe_handler(sycl::queue const&,recipe const&)` * Classes @@ -100,15 +100,15 @@ Recipe Recipe Functions ---------------- -make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) -............................................................................................. +make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) +.............................................................................. -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) -make_recipe_handler(sycl::queue const&,recipe const&,source_context) -.................................................................... +make_recipe_handler(sycl::queue const&,recipe const&) +..................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,recipe const&) Recipe Classes -------------- diff --git a/docs/api/sycl/cxxapi.yaml b/docs/api/sycl/cxxapi.yaml index 2d3416f6..2d2a7a05 100644 --- a/docs/api/sycl/cxxapi.yaml +++ b/docs/api/sycl/cxxapi.yaml @@ -11,12 +11,12 @@ C++-API: - tinytc::get_global_size(std::int64_t,sycl::range<3u> const &) - tinytc::get_group_size(sycl::kernel const &) - tinytc::make_kernel(sycl::kernel_bundle const &,char const *) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) Recipe: function: - - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) - - tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) + - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) + - tinytc::make_recipe_handler(sycl::queue const&,recipe const&) class: - tinytc::sycl_recipe_handler diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index d3f351a4..81cc0224 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tinytc-bench main.cpp args.cpp) +add_executable(tinytc-bench main.cpp) add_sycl_to_target(TARGET tinytc-bench SOURCES main.cpp) -target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl) +target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tinytc-bench) diff --git a/examples/benchmark/args.cpp b/examples/benchmark/args.cpp deleted file mode 100644 index cdd1ca6e..00000000 --- a/examples/benchmark/args.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.internal_repetitions = 1; - a.ty = tinytc::scalar_type::f32; - a.transA = tinytc::transpose::N; - a.transB = tinytc::transpose::N; - a.beta = 0.0; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "--trans-a") == 0) { - a.transA = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "--trans-b") == 0) { - a.transB = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "-a") == 0 || std::strcmp(argv[i], "--atomic") == 0) { - a.atomic = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-i") == 0 || - std::strcmp(argv[i], "--internal-reps") == 0) { - a.internal_repetitions = atoi(argv[++i]); - } else if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd' || strcmp(argv[i], "f64") == 0) { - a.ty = tinytc::scalar_type::f64; - } else if (argv[i][0] == 's' || strcmp(argv[i], "f32") == 0) { - a.ty = tinytc::scalar_type::f32; - } else if (strcmp(argv[i], "c64") == 0) { - a.ty = tinytc::scalar_type::c64; - } else if (strcmp(argv[i], "c32") == 0) { - a.ty = tinytc::scalar_type::c32; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tinytcbench test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 64x64x64) - -optional arguments: - -h, --help Show help and quit - -i, --internal-reps Number of GEMM repetitions inside kernel (default: 1) - -p, --precision Precision (single = s or f32, double = d or f64, complex = c32, long complex = c64) - --trans-a Transpose A matrix - --trans-b Transpose B matrix - -v, --verify Verify optimized implementation - -a, --atomic Update C atomically -)HELP"; -} diff --git a/examples/benchmark/args.hpp b/examples/benchmark/args.hpp deleted file mode 100644 index 636c2fb3..00000000 --- a/examples/benchmark/args.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20230417_HPP -#define ARGS_20230417_HPP - -#include "tinytc/types.hpp" - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - int internal_repetitions; - tinytc::scalar_type ty; - bool help; - tinytc::transpose transA; - tinytc::transpose transB; - double beta; - bool verify; - bool atomic; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20230417_HPP diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 2e86dfe8..825f0408 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -1,16 +1,17 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" - +#include #include #include #include #include +#include #include #include #include +#include #include #include #include @@ -22,6 +23,24 @@ using namespace sycl; using namespace tinytc; +struct test_case { + std::int64_t m; + std::int64_t n; + std::int64_t k; +}; + +struct args { + bool atomic = false; + bool dump = false; + int internal_repetitions = 1; + bool trans_a = false; + bool trans_b = false; + scalar_type ty = scalar_type::f32; + bool update = false; + bool verify = false; + std::vector tc; +}; + template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -39,10 +58,10 @@ template double bench(F f, int nrepeat = 10) { auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose tB, bool atomic, std::int64_t M, std::int64_t N, std::int64_t K, std::array A_stride, - std::array B_stride, double beta, + std::array B_stride, bool update, std::array C_stride, - std::int32_t repetitions, queue q) -> source { - auto ctx = make_source_context(); + std::int32_t repetitions, bool dump, queue q) -> source { + auto ctx = make_compiler_context(); char const *file_name = std::source_location::current().file_name(); auto const source_id = ctx.add_source(file_name, ""); @@ -55,52 +74,62 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t ++l.end.column; return l; }; + auto const make_type = [](data_type element_ty, transpose t, int64_t A, std::int64_t B, + std::array const &stride, location const &loc) { + auto s = std::array{A, B}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + auto mr = get_memref(element_ty, s, stride, address_space::global, loc); + return get_group(mr, 0, loc); + }; + + auto kernel = [&](compiler_context const &ctx) { + auto index_ty = get_scalar(ctx, scalar_type::index); + auto element_ty = get_scalar(ctx, ty); + auto A_ty = make_type(element_ty, tA, M, K, A_stride, my_loc()); + auto B_ty = make_type(element_ty, tB, K, N, B_stride, my_loc()); + auto C_ty = make_type(element_ty, transpose::N, M, N, C_stride, my_loc()); + auto f = make_func("gemm", {A_ty, B_ty, C_ty}, my_loc()); + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); - auto kernel = [&](function_builder &fb) { - auto A = fb.argument( - make_group(make_memref( - ty, {M, K}, std::vector(A_stride.begin(), A_stride.end()), my_loc())), - "A", my_loc()); - auto B = fb.argument( - make_group(make_memref( - ty, {K, N}, std::vector(B_stride.begin(), B_stride.end()), my_loc())), - "B", my_loc()); - auto C = fb.argument( - make_group(make_memref( - ty, {M, N}, std::vector(C_stride.begin(), C_stride.end()), my_loc())), - "C", my_loc()); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto a = bb.add(make_load(A, {gid}, my_loc())); - auto b = bb.add(make_load(B, {gid}, my_loc())); - auto c = bb.add(make_load(C, {gid}, my_loc())); - bb.for_loop( - scalar_type::index, make_index(0, my_loc()), make_index(repetitions, my_loc()), - [&](region_builder &bb, value const &) { - bb.add(make_gemm(tA, tB, atomic, make_imm(1.0, ty, my_loc()), a, b, - make_imm(beta, ty, my_loc()), c, my_loc())); - }, - "r", my_loc()); + auto bb = region_builder{fn_body}; + auto gid = bb.add(make_group_id(ctx, my_loc())); + auto from = bb.add(make_constant_zero(index_ty, my_loc())); + auto to = bb.add(make_constant(repetitions, index_ty, my_loc())); + auto calpha = bb.add(make_constant_one(element_ty, my_loc())); + auto cbeta = bb.add(update ? make_constant_one(element_ty, my_loc()) + : make_constant_zero(element_ty, my_loc())); + auto a = bb.add(make_load(params[0], {gid}, my_loc())); + auto b = bb.add(make_load(params[1], {gid}, my_loc())); + auto c = bb.add(make_load(params[2], {gid}, my_loc())); + bb.for_loop( + from, to, index_ty, + [&](region_builder &bb, value const &) { + bb.add(make_gemm(tA, tB, atomic, calpha, a, b, cbeta, c, my_loc())); }, my_loc()); + + return f; }; try { - auto pb = program_builder{}; - pb.create("gemm", kernel, my_loc()); + auto p = make_prog(ctx, my_loc()); + p.add_function(kernel(ctx)); + if (dump) { + p.dump(); + } auto info = make_core_info(q.get_device()); info.set_core_features(tinytc_core_feature_flag_large_register_file); - return compile_to_opencl(pb.get_product(my_loc()), info, ctx); + return compile_to_opencl(std::move(p), info); } catch (builder_error const &e) { ctx.report_error(e.loc(), e.what()); - std::cerr << "Error (" << static_cast(e.code()) << "): " << std::endl - << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(e.code()) << "): " << std::endl; } catch (status const &st) { - std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl - << "Error log:" << std::endl - << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; } return source{nullptr}; } @@ -152,9 +181,9 @@ template void test(queue q, args &a) { auto howmany = total_reals / max_reals; if (a.verify && a.internal_repetitions == 1) { + const bool trans_a = a.trans_a; + const bool trans_b = a.trans_b; q.submit([&](auto &h) { - bool transa = a.transA == transpose::T; - bool transb = a.transB == transpose::T; h.parallel_for(range{howmany, 32}, [=](id<2> it) { auto batch = it[0]; auto m = it[1]; @@ -165,8 +194,8 @@ template void test(queue q, args &a) { for (std::int64_t n = 0; n < c.n; ++n) { auto c_acc = T(0.0); for (std::int64_t k = 0; k < c.k; ++k) { - c_acc += a[transa ? k + mb * c.k : mb + k * c.m] * - b[transb ? n + k * c.n : k + n * c.k]; + c_acc += a[trans_a ? k + mb * c.k : mb + k * c.m] * + b[trans_b ? n + k * c.n : k + n * c.k]; } c_ref[mb + n * c.m] = c_acc; } @@ -188,10 +217,10 @@ template void test(queue q, args &a) { constexpr auto element_ty = to_scalar_type_v; try { auto src = gemm_kernel_with_inner_repetition( - element_ty, a.transA, a.transB, a.atomic, c.m, c.n, c.k, - {1, a.transA == transpose::T ? c.k : c.m}, - {1, a.transB == transpose::T ? c.n : c.k}, a.beta, {1, c.m}, a.internal_repetitions, - q); + element_ty, a.trans_a ? transpose::T : transpose::N, + a.trans_b ? transpose::T : transpose::N, a.atomic, c.m, c.n, c.k, + {1, a.trans_a ? c.k : c.m}, {1, a.trans_b ? c.n : c.k}, a.update, {1, c.m}, + a.internal_repetitions, a.dump, q); if (src) { auto bundle = make_kernel_bundle(q.get_context(), q.get_device(), src); auto kernel = make_kernel(bundle, "gemm"); @@ -255,15 +284,72 @@ template void test(queue q, args &a) { int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto const convert_data_type = [](char const *str, scalar_type &val) -> cmd::parser_status { + if (std::strcmp(str, "f32") == 0) { + val = scalar_type::f32; + } else if (std::strcmp(str, "f64") == 0) { + val = scalar_type::f64; + } else if (std::strcmp(str, "c32") == 0) { + val = scalar_type::c32; + } else if (std::strcmp(str, "c64") == 0) { + val = scalar_type::c64; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }; + auto const convert_test_case = [](char const *str, test_case &tc) { + auto const parse = [](std::int64_t *v, char const *str, char **end, char sep) { + *v = strtol(str, end, 10); + if (*v == 0 || **end != sep) { + throw cmd::parser_status::invalid_argument; + } + if (errno == ERANGE) { + throw cmd::parser_status::argument_out_of_range; + } + }; + char *end = nullptr; + try { + parse(&tc.m, str, &end, 'x'); + parse(&tc.n, end + 1, &end, 'x'); + parse(&tc.k, end + 1, &end, 0); + } catch (cmd::parser_status st) { + return st; + } + return cmd::parser_status::success; + }; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); + parser.set_short_opt('a', &a.atomic, "Update C atomically"); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (f32, f64, c32, c64)") + .converter(convert_data_type); + parser + .set_short_opt('i', &a.internal_repetitions, + "Number of GEMM repetitions inside kernel (default: 1)") + .validator([](std::int32_t rep) { return 0 <= rep; }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("transpose-a", &a.trans_a, "Transpose A matrix"); + parser.set_long_opt("transpose-b", &a.trans_b, "Transpose B matrix"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 64x64x64)") + .converter(convert_test_case) + .validator([](test_case const &tc) { return tc.m > 0 && tc.n > 0 && tc.k > 0; }); + + parser.parse(argc, argv); } catch (std::runtime_error const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tinytc-bench", ""); + return !help ? -1 : 0; } auto q = queue{}; diff --git a/examples/jit/main.cpp b/examples/jit/main.cpp index 14aee5a3..01444c4e 100644 --- a/examples/jit/main.cpp +++ b/examples/jit/main.cpp @@ -15,18 +15,15 @@ int main(int argc, char **argv) { return -1; } - auto ctx = source_context{}; try { - ctx = make_source_context(); auto info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - auto prog = parse_file(argv[1], ctx); + auto prog = parse_file(argv[1]); if (!prog) { return -1; } - compile_to_opencl(std::move(prog), info, ctx); + compile_to_opencl(std::move(prog), info); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; - std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; return 1; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index f91cb9ad..a122779f 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -63,6 +63,10 @@ void dump_ir_pass::dump_val(value_node const &v) { /* Inst nodes */ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { + if (g.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; dump_val(g.alpha()); *os_ << ", "; dump_val(g.A()); @@ -81,6 +85,10 @@ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { } void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { + if (g.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; dump_val(g.alpha()); *os_ << ", "; dump_val(g.A()); @@ -110,7 +118,7 @@ void dump_ir_pass::operator()(alloca_inst const &a) { void dump_ir_pass::operator()(axpby_inst const &a) { *os_ << "axpby"; - *os_ << "." << to_string(a.tA()) << " "; + *os_ << "." << to_string(a.tA()); dump_blas_a2(static_cast(a)); } @@ -248,18 +256,18 @@ void dump_ir_pass::operator()(lifetime_stop_inst const &l) { void dump_ir_pass::operator()(gemm_inst const &g) { *os_ << "gemm"; *os_ << "." << to_string(g.tA()); - *os_ << "." << to_string(g.tB()) << " "; + *os_ << "." << to_string(g.tB()); dump_blas_a3(static_cast(g)); } void dump_ir_pass::operator()(gemv_inst const &g) { *os_ << "gemv"; - *os_ << "." << to_string(g.tA()) << " "; + *os_ << "." << to_string(g.tA()); dump_blas_a3(static_cast(g)); } void dump_ir_pass::operator()(ger_inst const &g) { - *os_ << "ger "; + *os_ << "ger"; dump_blas_a3(static_cast(g)); } @@ -294,7 +302,7 @@ void dump_ir_pass::operator()(foreach_inst const &p) { } void dump_ir_pass::operator()(hadamard_inst const &g) { - *os_ << "hadamard "; + *os_ << "hadamard"; dump_blas_a3(static_cast(g)); } @@ -395,7 +403,7 @@ void dump_ir_pass::operator()(store_inst const &e) { void dump_ir_pass::operator()(sum_inst const &a) { *os_ << "sum"; - *os_ << "." << to_string(a.tA()) << " "; + *os_ << "." << to_string(a.tA()); dump_blas_a2(static_cast(a)); } diff --git a/tools/argparser/argparser.cpp b/tools/argparser/argparser.cpp index f0b10b16..6fbc26fa 100644 --- a/tools/argparser/argparser.cpp +++ b/tools/argparser/argparser.cpp @@ -33,6 +33,8 @@ auto to_string(parser_status status) -> char const * { return "Non-default convertible type need converter functional"; case parser_status::invalid_argument: return "Invalid argument"; + case parser_status::validator_failed: + return "Value fails to comply with validation rules"; case parser_status::argument_out_of_range: return "Argument is out of range"; case parser_status::required_must_not_follow_optional: @@ -124,7 +126,10 @@ void arg_parser::parse(int argc, char **argv) { throw arg_parser_error(argc, argv, pos, subpos, parser_status::unknown_positional_arg); } auto &arg = positional_[positional_arg_index]; - arg.par->set(argv[pos]); + auto status = arg.par->set(argv[pos]); + if (status != parser_status::success) { + throw arg_parser_error(argc, argv, pos, subpos, status); + } if (!arg.par->does_store_multiple()) { ++positional_arg_index; } diff --git a/tools/argparser/argparser.hpp b/tools/argparser/argparser.hpp index cd6aa56c..9a1afc53 100644 --- a/tools/argparser/argparser.hpp +++ b/tools/argparser/argparser.hpp @@ -33,6 +33,7 @@ enum class parser_status { flag_does_not_take_argument, converter_functional_missing, invalid_argument, + validator_failed, argument_out_of_range, required_must_not_follow_optional, positional_must_not_follow_multiarg, @@ -97,8 +98,8 @@ template class par_model : public par_concept { } else { status = parser_status::required_argument_missing; } - if (validator_ && !validator_(*ptr_)) { - status = parser_status::invalid_argument; + if (validator_ && status == parser_status::success && !validator_(*ptr_)) { + status = parser_status::validator_failed; } return status; } @@ -106,8 +107,14 @@ template class par_model : public par_concept { auto is_argument_required() const -> bool override { return !default_argument_.has_value(); } auto does_store_multiple() const -> bool override { return false; } - template auto converter(F &&fun) { converter_ = std::forward(fun); } - template auto validator(F &&fun) { validator_ = std::forward(fun); } + template auto converter(F &&fun) -> par_model & { + converter_ = std::forward(fun); + return *this; + } + template auto validator(F &&fun) -> par_model & { + validator_ = std::forward(fun); + return *this; + } protected: T *ptr_; @@ -186,17 +193,21 @@ class arg_parser { template auto add_positional_arg(char const *opt, T *ptr, char const *help = nullptr, - bool required = false) { - add_positional_arg( - positional_arg{opt, help, - std::make_unique>( - ptr, required ? std::nullopt : std::make_optional(*ptr))}); + bool required = false) -> par_model & { + auto model = + std::make_unique>(ptr, required ? std::nullopt : std::make_optional(*ptr)); + auto model_ptr = model.get(); + add_positional_arg(positional_arg{opt, help, std::move(model)}); + return *model_ptr; } template - auto add_positional_arg(char const *opt, std::vector *ptr, char const *help = nullptr) { - add_positional_arg( - {opt, help, std::make_unique>>(ptr, std::make_optional(T{}))}); + auto add_positional_arg(char const *opt, std::vector *ptr, + char const *help = nullptr) -> par_model & { + auto model = std::make_unique>>(ptr, std::make_optional(T{})); + auto model_ptr = model.get(); + add_positional_arg({opt, help, std::move(model)}); + return *model_ptr; } void parse(int argc, char **argv); From 34290d87ec3b790f47133ace25f3a28302d39e7a Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 16:08:40 +0200 Subject: [PATCH 055/297] Fix tall and skinny example Signed-off-by: Carsten Uphoff --- docs/api/core_cxxapi.rst | 23 ++++--- docs/api/core_cxxapi.yaml | 5 +- examples/benchmark/main.cpp | 51 ++------------ examples/gemm_common.hpp | 61 +++++++++++++++++ examples/tall_and_skinny/CMakeLists.txt | 4 +- examples/tall_and_skinny/args.cpp | 84 ----------------------- examples/tall_and_skinny/args.hpp | 33 --------- examples/tall_and_skinny/main.cpp | 89 +++++++++++++++++-------- include/tinytc/tinytc.hpp | 27 ++++++-- src/recipe/tall_and_skinny.cpp | 4 +- 10 files changed, 173 insertions(+), 208 deletions(-) create mode 100644 examples/gemm_common.hpp delete mode 100644 examples/tall_and_skinny/args.cpp delete mode 100644 examples/tall_and_skinny/args.hpp diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index a2681cc6..72513bc9 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -332,7 +332,7 @@ Recipe * :ref:`auto_mem_type` - * :ref:`auto_mem_type\\>\>` + * :ref:`auto_mem_type\\>\>` * :ref:`mem` @@ -340,7 +340,9 @@ Recipe * :ref:`auto_mem_type_v` - * :ref:`usm_pointer_type` + * :ref:`is_supported_scalar_type` + + * :ref:`is_usm_pointer_type` Recipe Enumerations ------------------- @@ -399,10 +401,10 @@ auto_mem_type .. doxygenstruct:: tinytc::auto_mem_type -auto_mem_type>> -....................................................... +auto_mem_type>> +.......................................................... -.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > +.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > mem ... @@ -417,10 +419,15 @@ auto_mem_type_v .. doxygenvariable:: tinytc::auto_mem_type_v -usm_pointer_type -................ +is_supported_scalar_type +........................ + +.. doxygenvariable:: tinytc::is_supported_scalar_type + +is_usm_pointer_type +................... -.. doxygenvariable:: tinytc::usm_pointer_type +.. doxygenvariable:: tinytc::is_usm_pointer_type Source ====== diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index fbca9723..64461f1b 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -65,11 +65,12 @@ Core C++-API: - tinytc::tall_and_skinny struct: - tinytc::auto_mem_type - - tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > + - tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > - tinytc::mem variable: - tinytc::auto_mem_type_v - - tinytc::usm_pointer_type + - tinytc::is_supported_scalar_type + - tinytc::is_usm_pointer_type Source: class: - tinytc::source diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 825f0408..2e03e103 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "../gemm_common.hpp" + #include #include #include @@ -23,12 +25,6 @@ using namespace sycl; using namespace tinytc; -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - struct args { bool atomic = false; bool dump = false; @@ -38,7 +34,7 @@ struct args { scalar_type ty = scalar_type::f32; bool update = false; bool verify = false; - std::vector tc; + std::vector tc; }; template double bench(F f, int nrepeat = 10) { @@ -286,47 +282,12 @@ int main(int argc, char **argv) { auto a = args{}; bool help = false; - auto const convert_data_type = [](char const *str, scalar_type &val) -> cmd::parser_status { - if (std::strcmp(str, "f32") == 0) { - val = scalar_type::f32; - } else if (std::strcmp(str, "f64") == 0) { - val = scalar_type::f64; - } else if (std::strcmp(str, "c32") == 0) { - val = scalar_type::c32; - } else if (std::strcmp(str, "c64") == 0) { - val = scalar_type::c64; - } else { - return cmd::parser_status::invalid_argument; - } - return cmd::parser_status::success; - }; - auto const convert_test_case = [](char const *str, test_case &tc) { - auto const parse = [](std::int64_t *v, char const *str, char **end, char sep) { - *v = strtol(str, end, 10); - if (*v == 0 || **end != sep) { - throw cmd::parser_status::invalid_argument; - } - if (errno == ERANGE) { - throw cmd::parser_status::argument_out_of_range; - } - }; - char *end = nullptr; - try { - parse(&tc.m, str, &end, 'x'); - parse(&tc.n, end + 1, &end, 'x'); - parse(&tc.k, end + 1, &end, 0); - } catch (cmd::parser_status st) { - return st; - } - return cmd::parser_status::success; - }; - auto parser = cmd::arg_parser{}; try { parser.set_short_opt('a', &a.atomic, "Update C atomically"); parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); parser.set_short_opt('f', &a.ty, "Data type (f32, f64, c32, c64)") - .converter(convert_data_type); + .converter(examples::convert_data_type); parser .set_short_opt('i', &a.internal_repetitions, "Number of GEMM repetitions inside kernel (default: 1)") @@ -339,8 +300,8 @@ int main(int argc, char **argv) { parser.set_long_opt("transpose-a", &a.trans_a, "Transpose A matrix"); parser.set_long_opt("transpose-b", &a.trans_b, "Transpose B matrix"); parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 64x64x64)") - .converter(convert_test_case) - .validator([](test_case const &tc) { return tc.m > 0 && tc.n > 0 && tc.k > 0; }); + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); parser.parse(argc, argv); } catch (std::runtime_error const &e) { diff --git a/examples/gemm_common.hpp b/examples/gemm_common.hpp new file mode 100644 index 00000000..b0efcdaf --- /dev/null +++ b/examples/gemm_common.hpp @@ -0,0 +1,61 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_COMMON_20241014_HPP +#define GEMM_COMMON_20241014_HPP + +#include "argparser.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc::examples { + +struct test_case { + std::int64_t m; + std::int64_t n; + std::int64_t k; +}; + +inline auto convert_data_type(char const *str, scalar_type &val) -> cmd::parser_status { + if (std::strcmp(str, "f32") == 0) { + val = scalar_type::f32; + } else if (std::strcmp(str, "f64") == 0) { + val = scalar_type::f64; + } else if (std::strcmp(str, "c32") == 0) { + val = scalar_type::c32; + } else if (std::strcmp(str, "c64") == 0) { + val = scalar_type::c64; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; +}; +inline auto convert_test_case(char const *str, test_case &tc) -> cmd::parser_status { + auto const parse = [](std::int64_t *v, char const *str, char **end, char sep) { + *v = strtol(str, end, 10); + if (*v == 0 || **end != sep) { + throw cmd::parser_status::invalid_argument; + } + if (errno == ERANGE) { + throw cmd::parser_status::argument_out_of_range; + } + }; + char *end = nullptr; + try { + parse(&tc.m, str, &end, 'x'); + parse(&tc.n, end + 1, &end, 'x'); + parse(&tc.k, end + 1, &end, 0); + } catch (cmd::parser_status st) { + return st; + } + return cmd::parser_status::success; +} +inline auto validate_test_case(test_case const &tc) -> bool { + return tc.m > 0 && tc.n > 0 && tc.k > 0; +}; + +} // namespace tinytc::examples + +#endif // GEMM_COMMON_20241014_HPP diff --git a/examples/tall_and_skinny/CMakeLists.txt b/examples/tall_and_skinny/CMakeLists.txt index f65f8efb..9defb5c6 100644 --- a/examples/tall_and_skinny/CMakeLists.txt +++ b/examples/tall_and_skinny/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tall_and_skinny main.cpp args.cpp) +add_executable(tall_and_skinny main.cpp) add_sycl_to_target(TARGET tall_and_skinny SOURCES main.cpp) -target_link_libraries(tall_and_skinny PRIVATE tinytc tinytc_sycl) +target_link_libraries(tall_and_skinny PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tall_and_skinny) diff --git a/examples/tall_and_skinny/args.cpp b/examples/tall_and_skinny/args.cpp deleted file mode 100644 index 6aa1268a..00000000 --- a/examples/tall_and_skinny/args.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.beta = 0.0; - a.specialize_M = false; - a.specialize_ld = false; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "--specialize-M") == 0) { - a.specialize_M = true; - } else if (std::strcmp(argv[i], "--specialize-ld") == 0) { - a.specialize_ld = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tall_and_skinny test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 300000x64x64) - -optional arguments: - -h, --help Show help and quit - -p, --precision Precision (single = s, double = d) - -v, --verify Verify optimized implementation - --specialize-M Specialize M instead of using dynamic value - --specialize-ld Specialize ldA, ldB, ldC instead of using dynamic value -)HELP"; -} diff --git a/examples/tall_and_skinny/args.hpp b/examples/tall_and_skinny/args.hpp deleted file mode 100644 index 701710fe..00000000 --- a/examples/tall_and_skinny/args.hpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20240215_HPP -#define ARGS_20240215_HPP - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - bool double_precision; - bool help; - bool verify; - double beta; - bool specialize_M; - bool specialize_ld; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20240215_HPP diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 8538cf0e..ff3913c3 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -1,14 +1,16 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "../gemm_common.hpp" +#include #include #include #include #include #include +#include #include #include #include @@ -20,6 +22,16 @@ using namespace sycl; using namespace tinytc; +struct args { + bool dump = false; + bool specialize_M = false; + bool specialize_ld = false; + scalar_type ty = scalar_type::f32; + bool update = false; + bool verify = false; + std::vector tc; +}; + template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -67,7 +79,7 @@ template void test(queue q, args &a) { std::size_t num_err = 0; for (std::int64_t i = 0; i < M * N; ++i) { auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { + if (err > 10.0 * std::numeric_limits::epsilon()) { if (num_err < 10) { std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] << std::endl; @@ -80,13 +92,12 @@ template void test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { + auto beta = a.update ? T{1} : T{0}; if (a.verify) { q.memset(C, 0, c.m * c.n * sizeof(T)).wait(); q.memset(C_ref, 0, c.m * c.n * sizeof(T)).wait(); q.submit([&](auto &h) { - auto beta = a.beta; h.parallel_for(range{static_cast(c.n), static_cast(c.m)}, [=](id<2> it) { auto m = it[1]; @@ -100,9 +111,7 @@ template void test(queue q, args &a) { }).wait(); } - auto source_ctx = source_context{}; try { - source_ctx = make_source_context(); auto info = make_core_info(q.get_device()); info.set_core_features(tinytc_core_feature_flag_large_register_file); @@ -113,31 +122,29 @@ template void test(queue q, args &a) { ldB = c.k; ldC = c.m; } - auto tas = make_recipe_handler( - q, - make_tall_and_skinny_specialized(info, to_scalar_type_v, M, c.n, c.k, ldA, ldB, - ldC, 0, source_ctx), - source_ctx); + auto r = make_tall_and_skinny_specialized(info, a.ty, M, c.n, c.k, ldA, ldB, ldC, 0); + if (a.dump) { + r.get_prog().dump(); + } + auto tas = make_recipe_handler(q, r); - tall_and_skinny::set_args(tas, c.m, T(1.0), A, c.m, B, c.k, T(a.beta), C, c.m); + tall_and_skinny::set_args(tas, c.m, T{1}, A, c.m, B, c.k, beta, C, c.m); tas.submit(q).wait(); if (a.verify) { check(c.m, c.n); } double min_exec_time_ns = bench([&]() { tas.submit(q).wait(); }); - auto bw_C_factor = a.beta != 0.0 ? 2 : 1; + auto bw_C_factor = a.update ? 2 : 1; auto bw = sizeof(T) * (c.m * c.n * bw_C_factor + c.m * c.k + c.k * c.n) / min_exec_time_ns; auto gflops = 2 * c.m * c.n * c.k / min_exec_time_ns; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << a.beta - << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops << std::endl; + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," + << a.update << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops + << std::endl; } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << tinytc::error_string(st) << std::endl; - if (source_ctx.get_error_log()[0] != '\0') { - std::cerr << "Error log: " << std::endl << source_ctx.get_error_log() << std::endl; - } } catch (std::exception const &e) { std::cerr << "Error: " << e.what() << std::endl; } @@ -151,25 +158,55 @@ template void test(queue q, args &a) { int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (f32, f64, c32, c64)") + .converter(examples::convert_data_type); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("specialize-m", &a.specialize_M, + "Specialize M instead of using dynamic value"); + parser.set_long_opt("specialize-ld", &a.specialize_ld, + "Specialize ldA, ldB, ldC instead of using dynamic value"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 300000x64x64)") + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); + + parser.parse(argc, argv); } catch (std::runtime_error const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tall_and_skinny", ""); + return !help ? -1 : 0; } auto q = queue{}; - std::cout << "precision,m,n,k,beta,time,bandwidth,gflops" << std::endl; + std::cout << "precision,m,n,k,update,time,bandwidth,gflops" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { + switch (a.ty) { + case scalar_type::f32: test(std::move(q), a); + break; + case scalar_type::f64: + test(std::move(q), a); + break; + case scalar_type::c32: + test>(std::move(q), a); + break; + case scalar_type::c64: + test>(std::move(q), a); + break; + default: + return -1; } } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 8c8b4bac..4f12c6b3 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -2081,16 +2081,31 @@ inline auto compile_to_opencl(prog prg, core_info const &info) -> source { template struct auto_mem_type; /** - * @brief True if T is either pointer to a fundamental type or a pointer to a pointer to a - * fundamental type + * @brief Check whether T maps to a scalar data type * * @tparam T type */ template -constexpr bool usm_pointer_type = +constexpr bool is_supported_scalar_type = std::is_same_v || // i8 + std::is_same_v || // i16 + std::is_same_v || // i32 + std::is_same_v || // i64 + std::is_same_v || // f32 + std::is_same_v || // f64 + std::is_same_v> || // c32 + std::is_same_v>; // c64 + +/** + * @brief True if T is either pointer to a support scalar type or a pointer to a pointer to a + * supported scalar type + * + * @tparam T type + */ +template +constexpr bool is_usm_pointer_type = std::is_pointer_v && - (std::is_fundamental_v> || - std::is_fundamental_v>>); + (is_supported_scalar_type> || + is_supported_scalar_type>>); /** * @brief Specialize auto_mem_type for pointer to non-class types @@ -2100,7 +2115,7 @@ constexpr bool usm_pointer_type = * * @tparam T memory object type */ -template struct auto_mem_type>> { +template struct auto_mem_type>> { constexpr static mem_type value = mem_type::usm_pointer; ///< Pointer maps to USM pointer type }; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 13a9aa27..dc796f45 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -112,7 +112,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant_zero(ty_, my_loc())); auto const static_offsets = std::array{dynamic, 0}; - auto const offsets = array_view{m}; + auto const offsets = array_view(m); auto const static_gemm = [&](region_builder &bb) { auto const A_static_sizes = std::array{M_block_size, K}; @@ -127,7 +127,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const dynamic_gemm = [&](region_builder &bb, value dyn_block_size) { auto const A_static_sizes = std::array{dynamic, K}; auto const C_static_sizes = std::array{dynamic, N}; - auto const sizes = array_view{dyn_block_size}; + auto const sizes = array_view(dyn_block_size); auto a = bb.add( make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); auto c = bb.add( From ceb62359d6656288e874b2f755353cdfb8fe1722 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 18:05:48 +0200 Subject: [PATCH 056/297] Fix compilation of matrix chain example Signed-off-by: Carsten Uphoff --- examples/benchmark/main.cpp | 2 +- examples/matrix_chain/CMakeLists.txt | 2 +- examples/matrix_chain/main.cpp | 85 ++++++++------ examples/matrix_chain/matrix_batch.hpp | 32 +++--- examples/matrix_chain/test_ader.cpp | 138 ++++++++++++++--------- examples/matrix_chain/test_ader.hpp | 4 +- examples/matrix_chain/test_multi.cpp | 9 +- examples/matrix_chain/test_multi.hpp | 2 +- examples/matrix_chain/test_volume.cpp | 149 ++++++++++++------------- examples/matrix_chain/test_volume.hpp | 4 +- examples/tall_and_skinny/main.cpp | 2 +- include/tinytc/tinytc.hpp | 5 +- tools/argparser/argparser.hpp | 4 +- tools/offline_compiler/main.cpp | 2 +- tools/opt/main.cpp | 2 +- 15 files changed, 248 insertions(+), 194 deletions(-) diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 2e03e103..18ab4aba 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -304,7 +304,7 @@ int main(int argc, char **argv) { .validator(examples::validate_test_case); parser.parse(argc, argv); - } catch (std::runtime_error const &e) { + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } diff --git a/examples/matrix_chain/CMakeLists.txt b/examples/matrix_chain/CMakeLists.txt index 7fa1358c..9f5e9eb0 100644 --- a/examples/matrix_chain/CMakeLists.txt +++ b/examples/matrix_chain/CMakeLists.txt @@ -14,5 +14,5 @@ set(SOURCES add_executable(matrix_chain ${SOURCES}) add_sycl_to_target(TARGET matrix_chain SOURCES ${SOURCES}) -target_link_libraries(matrix_chain PRIVATE tinytc tinytc_sycl) +target_link_libraries(matrix_chain PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(matrix_chain) diff --git a/examples/matrix_chain/main.cpp b/examples/matrix_chain/main.cpp index e8062846..d78498e5 100644 --- a/examples/matrix_chain/main.cpp +++ b/examples/matrix_chain/main.cpp @@ -4,6 +4,7 @@ #include "test.hpp" #include "test_multi.hpp" +#include #include #include @@ -14,8 +15,58 @@ #include using namespace sycl; +using namespace tinytc; int main(int argc, char **argv) { + bool dump = false; + std::int64_t N = 5, P = 9, howmany; + std::size_t alignment = 0; + char precision = 's'; + test_case tc = test_case::volume; + bool help = false; + + auto parser = cmd::arg_parser{}; + try { + parser.set_short_opt('a', &alignment, "Alignment"); + parser.set_short_opt('d', &dump, "Dump IR to stdout"); + parser.set_short_opt('f', &precision, "Data type (s or d)").validator([](char f) { + return f == 's' || f == 'd'; + }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('N', &N, "Polynomial degree").validator([](std::int64_t p) { + return p > 0; + }); + parser.set_short_opt('P', &P, "Number of quantities").validator([](std::int64_t n) { + return n > 0; + }); + parser.set_long_opt("help", &help, "Show help"); + parser.add_positional_arg("test_case", &tc, "Test case (volume or ader)", true) + .converter([](char const *str, test_case &val) -> cmd::parser_status { + if (strcmp(str, "volume") == 0) { + val = test_case::volume; + } else if (strcmp(str, "ader") == 0) { + val = test_case::ader; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }); + parser.add_positional_arg("howmany", &howmany, "Batch size", true) + .validator([](std::int64_t h) { return h > 0; }); + + parser.parse(argc, argv); + } catch (std::exception const &e) { + if (!help) { + std::cerr << e.what() << std::endl; + } + parser.print_help(std::cout, "matrix_chain", ""); + return help ? 0 : -1; + } + if (help) { + parser.print_help(std::cout, "matrix_chain", ""); + return 0; + } + auto devices = platform{}.get_devices(); auto sub_devices = std::vector{}; for (auto &device : devices) { @@ -33,41 +84,9 @@ int main(int argc, char **argv) { q.emplace_back(queue(device)); } - std::int64_t N = 5, P = 9, howmany; - std::size_t alignment = 0; - char precision = 's'; - test_case tc = test_case::volume; - - if (argc < 5) { - std::cerr << "Usage: matrix_chain

[alignment] [s/d]" - << std::endl; - return -1; - } - if (strcmp(argv[1], "volume") == 0) { - tc = test_case::volume; - } else if (strcmp(argv[1], "ader") == 0) { - tc = test_case::ader; - } else { - std::cerr << "Unknown test case " << argv[1] << ". Available are: ader, volume." - << std::endl; - return -1; - } - N = static_cast(std::atol(argv[2])); - P = static_cast(std::atol(argv[3])); - howmany = static_cast(std::atol(argv[4])); - if (argc >= 6) { - alignment = static_cast(std::atol(argv[5])); - } - if (argc >= 7) { - precision = argv[6][0]; - if (precision != 's' && precision != 'd') { - std::cerr << "Precision must be single (s) or double (d)" << std::endl; - return -1; - } - } auto run_test_multi = [&](auto precision) { using T = decltype(precision); - auto t = test_multi(N, P, howmany, alignment, tc, q); + auto t = test_multi(N, P, howmany, alignment, tc, q, dump); if (!t.check()) { std::cerr << "Result mismatch between reference and optimized!" << std::endl; // return; diff --git a/examples/matrix_chain/matrix_batch.hpp b/examples/matrix_chain/matrix_batch.hpp index f882b5b7..721f54e3 100644 --- a/examples/matrix_chain/matrix_batch.hpp +++ b/examples/matrix_chain/matrix_batch.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -16,30 +17,35 @@ template class matrix_batch { public: matrix_batch(std::int64_t nrows, std::int64_t ncols, std::int64_t ld, std::int64_t howmany, sycl::queue q) - : nrows_(nrows), ncols_(ncols), ld_(ld), howmany_(howmany), - data_(ld_ * ncols_ * howmany, std::move(q)) {} + : shape_{nrows, ncols}, ld_{ld}, howmany_{howmany}, + data_(stride() * howmany_, std::move(q)) {} inline T *get() { return data_.get(); } inline T const *get() const { return data_.get(); } - inline std::int64_t nrows() const { return nrows_; } - inline std::int64_t ncols() const { return ncols_; } - inline std::int64_t ld() const { return ld_; } + inline auto shape() const -> tinytc::array_view { return shape_; } + inline std::int64_t nrows() const { return shape_[0]; } + inline std::int64_t ncols() const { return shape_[1]; } inline std::int64_t howmany() const { return howmany_; } - inline std::int64_t stride() const { return ld_ * ncols_; } + inline std::int64_t ld() const { return ld_; } + inline std::int64_t stride() const { return ld_ * ncols(); } inline std::size_t size() const { return data_.size(); } inline void fill(T const &v) { data_.fill(v); } inline void random() { data_.random(); } - inline tinytc::data_type type(bool include_batch_dim = true) { - constexpr auto real_t = tinytc::to_scalar_type_v; - if (include_batch_dim && howmany() > 1) { - return tinytc::make_memref(real_t, {nrows(), ncols(), tinytc::dynamic}, - {1, ld(), stride()}); + inline auto type(tinytc::data_type element_ty) -> tinytc::data_type { + if (howmany_ == 1) { + return tinytc::get_memref(element_ty, {nrows(), ncols()}, {1, ld()}); } - return tinytc::make_memref(real_t, {nrows(), ncols()}, {1, ld()}); + return tinytc::get_memref(element_ty, {nrows(), ncols(), tinytc::dynamic}, + {1, ld(), stride()}); + } + inline auto local_type(tinytc::data_type element_ty) -> tinytc::data_type { + return tinytc::get_memref(element_ty, {nrows(), ncols()}, {1, ld()}, + tinytc::address_space::local); } private: - std::int64_t nrows_, ncols_, ld_, howmany_; + std::array shape_; + std::int64_t ld_, howmany_; device_array data_; }; diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 78192ff9..2ff2ffb8 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -5,8 +5,8 @@ #include #include -#include #include +#include #include using namespace sycl; @@ -14,14 +14,15 @@ using namespace tinytc; template test_ader::test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - queue q) + queue q, bool dump) : N_(N), P_(P), howmany_(howmany), alignment_(alignment), q_(std::move(q)), dev_info_(make_core_info(q_.get_device())), I_ref_(Bd(), P_, Bd_aligned(), howmany_, q_), I_opt_(Bd(), P_, Bd_aligned(), howmany_, q_), tmp_(Bd(), P_, Bd_aligned(N_ - 1), howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), K_(dim, matrix_batch(Bd(), Bd(), Bd_aligned(N_ - 1), 1, q_)), dQ_(make_dQ()), - opt_bundle_(make_optimized_kernel()), opt_kernel_(make_kernel(opt_bundle_, "ader_kernel")) { + opt_bundle_(make_optimized_kernel(dump)), + opt_kernel_(make_kernel(opt_bundle_, "ader_kernel")) { I_ref_.random(); I_opt_.random(); for (auto &a : A_) { @@ -58,63 +59,90 @@ template std::vector> test_ader::make_dQ() { } template -auto test_ader::make_optimized_kernel() -> sycl::kernel_bundle { +auto test_ader::make_optimized_kernel(bool dump) + -> sycl::kernel_bundle { constexpr auto real_t = to_scalar_type_v; - auto opt_kernel = [&](function_builder &fb) { - T dt = 1.01; - T num = T(1.0); - int denom = 1; - std::array A; - std::array K; + auto opt_kernel = [&](compiler_context const &ctx) { + auto element_ty = get_scalar(ctx, real_t); + std::array param_types; + param_types[0] = element_ty; for (std::size_t i = 0; i < dim; ++i) { - A[i] = fb.argument(A_[i].type(), "A"); + param_types[1 + i] = A_[i].type(element_ty); } for (std::size_t i = 0; i < dim; ++i) { - K[i] = fb.argument(K_[i].type(), "K"); + param_types[1 + dim + i] = K_[i].type(element_ty); } - auto Q = fb.argument(dQ_[0].type(), "dQ"); - auto I = fb.argument(I_opt_.type(), "I"); - fb.body([&](region_builder &bb) { - auto const gid = bb.add(make_group_id()); - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = std::vector{make_dynamic(), make_dynamic(), value{}}; - auto dq = bb.add(make_subview(Q, offsets3, size3)); + param_types[1 + 2 * dim + 0] = dQ_[0].type(element_ty); + param_types[1 + 2 * dim + 1] = I_opt_.type(element_ty); + + auto f = make_func("ader_kernel", param_types); + auto fn_body = f.get_body(); + + std::array params; + fn_body.get_parameters(params); + + auto dt = params[0]; + auto A = [¶ms](std::size_t i) -> value & { return params[1 + i]; }; + auto K = [¶ms](std::size_t i) -> value & { return params[1 + dim + i]; }; + auto Q = params[1 + 2 * dim + 0]; + auto I = params[1 + 2 * dim + 1]; + for (std::size_t i = 0; i < dim; ++i) { + A(i).set_name((std::ostringstream{} << 'A' << i).str()); + K(i).set_name((std::ostringstream{} << 'K' << i).str()); + } + Q.set_name("Q"); + I.set_name("I"); + + T dt0 = T{1.01}; + T num = T{1}; + int denom = 1; + + auto bb = region_builder{fn_body}; + auto const c0 = bb.add(make_constant_zero(element_ty)); + auto const c1 = bb.add(make_constant_one(element_ty)); + auto const gid = bb.add(make_group_id(ctx)); + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes3 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const offsets3 = array_view(gid); + auto dq = bb.add(make_subview(Q, static_offsets3, static_sizes3(dQ_[0]), offsets3)); + for (std::size_t d = 0; d < dim; ++d) { + A(d) = bb.add(make_subview(A(d), static_offsets3, static_sizes3(A_[d]), offsets3)); + } + auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3)); + bb.add(make_axpby(transpose::N, false, c1, dq, c1, i)); + + auto const static_offsets2 = std::array{0, 0}; + for (std::int64_t n = 1; n <= N_; ++n) { + num *= dt0; + denom *= n + 1; + auto bn = Bd_aligned(N_ - n); + auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); + auto dq_nextv = bb.add(make_subview(dq_next, static_offsets2, {bn, P_})); + auto tmp = bb.add( + make_alloca(get_memref(element_ty, {bn, P_}, {1, bn}, address_space::local))); for (std::size_t d = 0; d < dim; ++d) { - A[d] = bb.add(make_subview(A[d], offsets3, size3)); + auto Kv = bb.add(make_subview(K(d), static_offsets2, {bn, Bd(N_ - n + 1)})); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, Kv, dq, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, tmp, A(d), d > 0 ? c1 : c0, + dq_nextv)); } - auto i = bb.add(make_subview(I, offsets3, size3)); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq, make_imm(T(1.0)), i)); - - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - for (std::int64_t n = 1; n <= N_; ++n) { - num *= dt; - denom *= n + 1; - auto bn = Bd_aligned(N_ - n); - auto dq_next = bb.add(make_alloca(dQ_[n].type(false))); - auto dq_nextv = - bb.add(make_subview(dq_next, offsets2, {make_index(bn), make_index(P_)})); - auto tmp = bb.add(make_alloca(make_memref(real_t, {bn, P_}, {1, bn}))); - for (std::size_t d = 0; d < dim; ++d) { - auto Kv = bb.add( - make_subview(K[d], offsets2, {make_index(bn), make_index(Bd(N_ - n + 1))})); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), Kv, dq, - make_imm(T(0.0)), tmp)); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), tmp, A[d], - make_imm(T(d > 0 ? 1.0 : 0.0)), dq_nextv)); - } - auto iv = - bb.add(make_subview(i, offsets2, {make_index(Bd(N_ - n)), make_index(P_)})); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq_next, - make_imm(T(1.0)), iv)); - dq = dq_next; - } - }); - }; - auto pb = program_builder{}; - pb.create("ader_kernel", opt_kernel); + auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_})); + auto cfactor = bb.add(make_constant(num / denom, element_ty)); + bb.add(make_axpby(transpose::N, false, cfactor, dq_next, c1, iv)); + dq = dq_next; + } - return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + return f; + }; + auto ctx = make_compiler_context(); + auto p = make_prog(ctx); + p.add_function(opt_kernel(ctx)); + if (dump) { + p.dump(); + } + return make_kernel_bundle(q_.get_context(), q_.get_device(), compile_to_opencl(p, dev_info_)); } template @@ -157,10 +185,12 @@ template std::vector test_ader::reference() { } template std::vector test_ader::optimized() { + T dt = 1.01; auto exe_range = get_execution_range(opt_kernel_, howmany_); return {q_.submit([&](handler &h) { - h.set_args(A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, K_[0].get(), - K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), howmany_); + h.set_args(dt, A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, + K_[0].get(), K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), + howmany_); h.parallel_for(exe_range, opt_kernel_); })}; } diff --git a/examples/matrix_chain/test_ader.hpp b/examples/matrix_chain/test_ader.hpp index 0195c975..84100f89 100644 --- a/examples/matrix_chain/test_ader.hpp +++ b/examples/matrix_chain/test_ader.hpp @@ -20,7 +20,7 @@ template class test_ader : public test { public: test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_ader() = default; test_ader(test_ader const &other) = delete; test_ader(test_ader &&other) = default; @@ -41,7 +41,7 @@ template class test_ader : public test { inline std::int64_t Bd_aligned() { return aligned(Bd(N_), alignment_); } inline std::int64_t Bd_aligned(std::int64_t N) { return aligned(Bd(N), alignment_); } std::vector> make_dQ(); - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; sycl::event taylor_sum(matrix_batch &I, matrix_batch &dQ, T factor, std::vector const &dep_events = {}); diff --git a/examples/matrix_chain/test_multi.cpp b/examples/matrix_chain/test_multi.cpp index c0d30308..34fbc50f 100644 --- a/examples/matrix_chain/test_multi.cpp +++ b/examples/matrix_chain/test_multi.cpp @@ -25,14 +25,17 @@ template double bench(F f, int nrepeat = 10) { template test_multi::test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, test_case tc, std::vector const &q) { + std::size_t alignment, test_case tc, std::vector const &q, + bool dump) { for (auto &qu : q) { switch (tc) { case test_case::ader: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; case test_case::volume: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; default: break; diff --git a/examples/matrix_chain/test_multi.hpp b/examples/matrix_chain/test_multi.hpp index 1bd4beac..b7b3027f 100644 --- a/examples/matrix_chain/test_multi.hpp +++ b/examples/matrix_chain/test_multi.hpp @@ -17,7 +17,7 @@ enum class test_case { volume, ader }; template class test_multi { public: test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - test_case tc, std::vector<::sycl::queue> const &q); + test_case tc, std::vector<::sycl::queue> const &q, bool dump = false); void reference(); void optimized(); diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index 479e5e03..af463662 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include using namespace sycl; @@ -15,14 +15,15 @@ using namespace tinytc; template test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, queue q) + std::size_t alignment, queue q, bool dump) : B3_(num_basis(N, dim)), B2_(num_basis(N - 1, dim)), P_(P), howmany_(howmany), B3_aligned_(aligned(B3_, alignment)), B2_aligned_(aligned(B2_, alignment)), q_(std::move(q)), dev_info_(make_core_info(q_.get_device())), Q_ref_(B3_, P_, B3_aligned_, howmany_, q_), Q_opt_(B3_, P_, B3_aligned_, howmany_, q_), I_(B3_, P_, B3_aligned_, howmany_, q_), tmp_(B3_, P_, B2_aligned_, howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), - K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), opt_bundle_(make_optimized_kernel()), + K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), + opt_bundle_(make_optimized_kernel(dump)), opt_kernel_(make_kernel(opt_bundle_, "volume_kernel")) { Q_ref_.random(); Q_opt_.random(); @@ -46,84 +47,78 @@ test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany } template -auto test_volume::make_optimized_kernel() +auto test_volume::make_optimized_kernel(bool dump) -> sycl::kernel_bundle { constexpr auto real_t = to_scalar_type_v; - /** - * With B3_ = 56, B3_aligned_ = 64, B2_ = 35, B2_aligned_ = 48, P_ = 9 - * - * func chain(A0: batch, distance<81>>, - * A1: batch, distance<81>>, - * A2: batch, distance<81>>, - * K0: memref, - * K1: memref, - * K2: memref, - * Q: batch, distance<576>>, - * I: batch, distance<576>>) { - * a0 = get_work_item A0 - * a1 = get_work_item A1 - * a2 = get_work_item A2 - * q = get_work_item Q - * i = get_work_item i - * tmp = alloca matrix; - * K0v = submatrix K0[0:64,0:35] - * K1v = submatrix K1[0:64,0:35] - * K2v = submatrix K2[0:64,0:35] - * qv = submatrix Q[0:64,0:9] - * iv = submatrix I[0:48,0:9] - * tmpv = submatrix tmp[0:48,0:9] - * matmul(iv, a0, tmpv, 1.0, 0.0); - * matmul(K0v, tmp, qv, 1.0, 0.0); - * matmul(iv, a1, tmpv, 1.0, 0.0); - * matmul(K1v, tmp, qv, 1.0, 0.0); - * matmul(iv, a2, tmpv, 1.0, 0.0); - * matmul(K2v, tmp, qv, 1.0, 0.0); - * } - */ // Optimized kernel - auto opt_kernel = [&](function_builder &fb) { - auto A0 = fb.argument(A_[0].type(), "A0"); - auto A1 = fb.argument(A_[1].type(), "A1"); - auto A2 = fb.argument(A_[2].type(), "A2"); - auto K0 = fb.argument(K_[0].type(), "K0"); - auto K1 = fb.argument(K_[1].type(), "K1"); - auto K2 = fb.argument(K_[2].type(), "K2"); - auto Q = fb.argument(Q_opt_.type(), "Q"); - auto I = fb.argument(I_.type(), "I"); - fb.body([&](region_builder &bb) { - auto gid = bb.add(make_group_id()); - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = std::vector{make_dynamic(), make_dynamic(), value{}}; - auto const sizeK2 = std::vector{make_index(B3_aligned_), make_index(B2_)}; - auto tmp = bb.add(make_alloca(make_memref(real_t, {B2_, P_}, {1, B2_aligned_}))); - auto a0 = bb.add(make_subview(A0, offsets3, size3)); - auto a1 = bb.add(make_subview(A1, offsets3, size3)); - auto a2 = bb.add(make_subview(A2, offsets3, size3)); - auto K0v = bb.add(make_subview(K0, offsets2, sizeK2)); - auto K1v = bb.add(make_subview(K1, offsets2, sizeK2)); - auto K2v = bb.add(make_subview(K2, offsets2, sizeK2)); - auto qv = bb.add( - make_subview(Q, offsets3, {make_index(B3_aligned_), make_index(P_), value{}})); - auto iv = bb.add( - make_subview(I, offsets3, {make_index(B2_aligned_), make_index(P_), value{}})); - auto tmpv = - bb.add(make_subview(tmp, offsets2, {make_index(B2_aligned_), make_index(P_)})); - auto const s0 = make_imm(T(0.0)); - auto const s1 = make_imm(T(1.0)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a0, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K0v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a1, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K1v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a2, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K2v, tmp, s1, qv)); - }); - }; - auto pb = program_builder{}; - pb.create("volume_kernel", opt_kernel); + auto opt_kernel = [&](compiler_context const &ctx) { + auto element_ty = get_scalar(ctx, real_t); + std::array param_types; + for (std::size_t i = 0; i < dim; ++i) { + param_types[i] = A_[i].type(element_ty); + } + for (std::size_t i = 0; i < dim; ++i) { + param_types[dim + i] = K_[i].type(element_ty); + } + param_types[2 * dim + 0] = Q_opt_.type(element_ty); + param_types[2 * dim + 1] = I_.type(element_ty); + + auto f = make_func("volume_kernel", param_types); + auto fn_body = f.get_body(); + + std::array params; + fn_body.get_parameters(params); + + auto A = [¶ms](std::size_t i) -> value & { return params[i]; }; + auto K = [¶ms](std::size_t i) -> value & { return params[dim + i]; }; + auto Q = params[2 * dim + 0]; + auto I = params[2 * dim + 1]; - return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + for (std::size_t i = 0; i < dim; ++i) { + A(i).set_name((std::ostringstream{} << 'A' << i).str()); + K(i).set_name((std::ostringstream{} << 'K' << i).str()); + } + Q.set_name("Q"); + I.set_name("I"); + + auto bb = region_builder{fn_body}; + auto gid = bb.add(make_group_id(ctx)); + auto const static_offsets2 = std::array{0, 0}; + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes3 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const offsets3 = array_view(gid); + auto const sizeK2 = std::array{B3_aligned_, B2_}; + auto tmp = bb.add( + make_alloca(get_memref(element_ty, {B2_aligned_, P_}, {}, address_space::local))); + auto a0 = bb.add(make_subview(A(0), static_offsets3, static_sizes3(A_[0]), offsets3)); + auto a1 = bb.add(make_subview(A(1), static_offsets3, static_sizes3(A_[1]), offsets3)); + auto a2 = bb.add(make_subview(A(2), static_offsets3, static_sizes3(A_[2]), offsets3)); + auto k0 = bb.add(make_subview(K(0), static_offsets2, sizeK2)); + auto k1 = bb.add(make_subview(K(1), static_offsets2, sizeK2)); + auto k2 = bb.add(make_subview(K(2), static_offsets2, sizeK2)); + auto qv = bb.add(make_subview(Q, static_offsets3, {B3_aligned_, P_, 0}, offsets3)); + auto iv = bb.add(make_subview(I, static_offsets3, {B2_aligned_, P_, 0}, offsets3)); + auto tmpv = bb.add(make_subview(tmp, static_offsets2, {B2_, P_})); + auto const c0 = bb.add(make_constant_zero(element_ty)); + auto const c1 = bb.add(make_constant_one(element_ty)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a0, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k0, tmpv, c1, qv)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a1, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k1, tmpv, c1, qv)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a2, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k2, tmpv, c1, qv)); + + return f; + }; + auto ctx = make_compiler_context(); + auto p = make_prog(ctx); + p.add_function(opt_kernel(ctx)); + if (dump) { + p.dump(); + } + return make_kernel_bundle(q_.get_context(), q_.get_device(), compile_to_opencl(p, dev_info_)); } template std::vector test_volume::reference() { diff --git a/examples/matrix_chain/test_volume.hpp b/examples/matrix_chain/test_volume.hpp index ebd64f55..ad163e44 100644 --- a/examples/matrix_chain/test_volume.hpp +++ b/examples/matrix_chain/test_volume.hpp @@ -19,7 +19,7 @@ template class test_volume : public test { public: test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_volume() = default; test_volume(test_volume const &other) = delete; test_volume(test_volume &&other) = default; @@ -35,7 +35,7 @@ template class test_volume : public test { private: constexpr static std::size_t dim = 3; - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; std::int64_t B3_, B2_, P_, howmany_, B3_aligned_, B2_aligned_; sycl::queue q_; diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index ff3913c3..30be5031 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -179,7 +179,7 @@ int main(int argc, char **argv) { .validator(examples::validate_test_case); parser.parse(argc, argv); - } catch (std::runtime_error const &e) { + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 4f12c6b3..162f0d9b 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1265,8 +1265,9 @@ inline inst make_subgroup_size(compiler_context const &ctx, location const &loc * @return Instruction */ inline inst make_subview(value a, array_view static_offset_list, - array_view static_size_list, array_view offset_list, - array_view size_list, location const &loc = {}) { + array_view static_size_list, + array_view offset_list = {}, array_view size_list = {}, + location const &loc = {}) { tinytc_inst_t instr; if (static_offset_list.size() != static_size_list.size()) { throw std::invalid_argument( diff --git a/tools/argparser/argparser.hpp b/tools/argparser/argparser.hpp index 9a1afc53..e655cda5 100644 --- a/tools/argparser/argparser.hpp +++ b/tools/argparser/argparser.hpp @@ -45,8 +45,8 @@ template struct default_converter; template struct default_converter { auto operator()(char const *str, T &val) const -> parser_status { long v = strtol(str, nullptr, 0); - if (errno == ERANGE || v < std::numeric_limits::min() || - v > std::numeric_limits::max()) { + if (errno == ERANGE || static_cast(v) < std::numeric_limits::min() || + static_cast(v) > std::numeric_limits::max()) { return parser_status::argument_out_of_range; } val = v; diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index edb798be..cacc816f 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -48,7 +48,7 @@ int main(int argc, char **argv) { } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return -1; - } catch (std::runtime_error const &e) { + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index eb86ce67..dc6157c5 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -50,7 +50,7 @@ int main(int argc, char **argv) { } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return -1; - } catch (std::runtime_error const &e) { + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } From 60a6388b511c0abbf0a969ee7f925fb4421eaa15 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 14 Oct 2024 18:43:19 +0200 Subject: [PATCH 057/297] Everything compiles again Signed-off-by: Carsten Uphoff --- docs/manual/builder.rst | 117 +++++++++++++++++++++++--------------- examples/builder/main.c | 76 +++++++++++++++---------- examples/builder/main.cpp | 36 +++++++----- include/tinytc/tinytc.hpp | 9 +-- test/visitor.cpp | 23 ++++---- 5 files changed, 156 insertions(+), 105 deletions(-) diff --git a/docs/manual/builder.rst b/docs/manual/builder.rst index f05a67f4..1ab7d59c 100644 --- a/docs/manual/builder.rst +++ b/docs/manual/builder.rst @@ -21,7 +21,9 @@ Consider the following simple copy kernel .. code-block:: func @copy(%A: memref<${type}x${M}x${N}>, %B: memref<${type}x${M}x${M}>) { - axpby.n 1.0, %A, 0.0, %B : ${type}, memref<${type}x${M}x${N}>, ${type}, memref<${type}x${M}x${N}> + %c0 = constant 0.0 -> ${type} + %c1 = constant 1.0 -> ${type} + axpby.n %c1, %A, %c0, %B : ${type}, memref<${type}x${M}x${N}>, ${type}, memref<${type}x${M}x${N}> } In the following example we build the above code programmatically and replace the place-holders (${.}) @@ -33,63 +35,88 @@ by actual values: .. code:: C - tinytc_scalar_type_t type = ...; + tinytc_scalar_type_t sty = ...; int64_t M = ...; int64_t N = ...; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + uint32_t num_results; + uint32_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_data_type_t element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_scalar_type_get(&element_ty, ctx, sty); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(©_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(©_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, element_ty, 2, shape, 0, NULL, tinytc_address_space_global, NULL); - tinytc_region_t copy_body; - tinytc_region_create(©_body, 1, ©_inst, NULL); - tinytc_inst_release(copy_inst); + // Create function + tinytc_data_type_t param_types[2] = {ty, ty}; + tinytc_func_create(©_fun, sizeof(copy_fun_name), copy_fun_name, 2, param_types, NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_func_t copy_fun; - tinytc_function_create(©_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Get body + tinytc_func_get_body(copy_fun, ©_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); - tinytc_prog_t program; - tinytc_program_create(&program, 1, ©_fun, NULL); - tinytc_func_release(copy_fun); + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_add_instruction(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_add_instruction(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, tinytc_transpose_N, 0, alpha, params[0], beta, params[1], NULL); + tinytc_region_add_instruction(copy_body, tmp); + + // Dump program + tinytc_prog_dump(program); + + // Clean-up + tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); .. tab:: C++ .. code:: C++ - scalar_type type = ...; + scalar_type sty = ...; int64_t M = ...; int64_t N = ...; - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); + auto ctx = make_compiler_context(); + auto element_ty = get_scalar(ctx, sty); + auto ty = get_memref(element_ty, {M, N}); + + auto f = make_func("copy", {ty, ty}); + + auto body = f.get_body(); + std::array params; + body.get_parameters(params); + + auto bb = region_builder{body}; + auto alpha = bb.add(make_constant_one(element_ty)); + auto beta = bb.add(make_constant_zero(element_ty)); + bb.add(make_axpby(transpose::N, false, alpha, params[0], beta, params[1])); + + auto p = make_prog(ctx); + p.add_function(std::move(f)); + + p.dump(); diff --git a/examples/builder/main.c b/examples/builder/main.c index 49ad78d4..8877c87f 100644 --- a/examples/builder/main.c +++ b/examples/builder/main.c @@ -6,48 +6,62 @@ #include int main(void) { - tinytc_scalar_type_t type = tinytc_scalar_type_f32; + tinytc_scalar_type_t sty = tinytc_scalar_type_f32; int64_t M = 64; int64_t N = 32; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + uint32_t num_results; + uint32_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_data_type_t element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_scalar_type_get(&element_ty, ctx, sty); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(©_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(©_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, element_ty, 2, shape, 0, NULL, tinytc_address_space_global, NULL); - tinytc_region_t copy_body; - tinytc_region_create(©_body, 1, ©_inst, NULL); - tinytc_inst_release(copy_inst); + // Create function + tinytc_data_type_t param_types[2] = {ty, ty}; + tinytc_func_create(©_fun, sizeof(copy_fun_name) - 1, copy_fun_name, 2, param_types, NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_func_t copy_fun; - tinytc_function_create(©_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Get body + tinytc_func_get_body(copy_fun, ©_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); - tinytc_prog_t program; - tinytc_program_create(&program, 1, ©_fun, NULL); - tinytc_func_release(copy_fun); + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_add_instruction(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_add_instruction(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, tinytc_transpose_N, 0, alpha, params[0], beta, params[1], NULL); + tinytc_region_add_instruction(copy_body, tmp); + // Dump program tinytc_prog_dump(program); + // Clean-up tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); return 0; } diff --git a/examples/builder/main.cpp b/examples/builder/main.cpp index f506aeeb..3aeba1ea 100644 --- a/examples/builder/main.cpp +++ b/examples/builder/main.cpp @@ -4,31 +4,37 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include using namespace tinytc; int main() { - scalar_type type = scalar_type::f32; + scalar_type sty = scalar_type::f32; int64_t M = 64; int64_t N = 32; try { - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); - - program.dump(); + auto ctx = make_compiler_context(); + auto element_ty = get_scalar(ctx, sty); + auto ty = get_memref(element_ty, {M, N}); + + auto f = make_func("copy", {ty, ty}); + + auto body = f.get_body(); + std::array params; + body.get_parameters(params); + + auto bb = region_builder{body}; + auto alpha = bb.add(make_constant_one(element_ty)); + auto beta = bb.add(make_constant_zero(element_ty)); + bb.add(make_axpby(transpose::N, false, alpha, params[0], beta, params[1])); + + auto p = make_prog(ctx); + p.add_function(std::move(f)); + + p.dump(); } catch (builder_error const &e) { std::cerr << "Error " << static_cast(e.code()) << std::endl; } catch (status const &st) { diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 162f0d9b..77fd6bbc 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -2098,15 +2098,16 @@ constexpr bool is_supported_scalar_type = std::is_same_v || /** * @brief True if T is either pointer to a support scalar type or a pointer to a pointer to a - * supported scalar type + * supported scalar type; void* is fine, too * * @tparam T type */ template constexpr bool is_usm_pointer_type = - std::is_pointer_v && - (is_supported_scalar_type> || - is_supported_scalar_type>>); + std::is_same_v || + (std::is_pointer_v && + (is_supported_scalar_type> || + is_supported_scalar_type>>)); /** * @brief Specialize auto_mem_type for pointer to non-class types diff --git a/test/visitor.cpp b/test/visitor.cpp index 5f1ed893..00858fde 100644 --- a/test/visitor.cpp +++ b/test/visitor.cpp @@ -4,22 +4,25 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include -#include using namespace tinytc; TEST_CASE("is_equal") { - CHECK(is_equal(*make_scalar(scalar_type::f32), *make_scalar(scalar_type::f32))); - CHECK(!is_equal(*make_scalar(scalar_type::i32), *make_scalar(scalar_type::i16))); - auto a = make_memref(scalar_type::f32, {1, 2}); - auto b = make_memref(scalar_type::f32, {2, 3}); - auto c = make_memref(scalar_type::f64, {1, 2}); + auto ctx = make_compiler_context(); + auto f32_ty = get_scalar(ctx, scalar_type::f32); + auto f64_ty = get_scalar(ctx, scalar_type::f64); + CHECK(is_equal(*f32_ty, *f32_ty)); + CHECK(!is_equal(*get_scalar(ctx, scalar_type::i32), *get_scalar(ctx, scalar_type::i16))); + auto a = get_memref(f32_ty, {1, 2}); + auto b = get_memref(f32_ty, {2, 3}); + auto c = get_memref(f64_ty, {1, 2}); CHECK(is_equal(*a, *a)); CHECK(!is_equal(*a, *b)); CHECK(!is_equal(*a, *c)); - CHECK(is_equal(*make_group(a), *make_group(a))); - CHECK(!is_equal(*make_group(a), *make_group(b))); - CHECK(!is_equal(*make_group(a), *make_group(c))); - CHECK(!is_equal(*make_group(a), *a)); + CHECK(is_equal(*get_group(a), *get_group(a))); + CHECK(!is_equal(*get_group(a), *get_group(b))); + CHECK(!is_equal(*get_group(a), *get_group(c))); + CHECK(!is_equal(*get_group(a), *a)); } From 4da4769742f12e3d42ce1b566f7b5bb74848129d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 15 Oct 2024 08:51:55 +0200 Subject: [PATCH 058/297] Make factor calculation dynamic in test ader Signed-off-by: Carsten Uphoff --- examples/matrix_chain/test_ader.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 2ff2ffb8..19e14d42 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -82,6 +82,7 @@ auto test_ader::make_optimized_kernel(bool dump) fn_body.get_parameters(params); auto dt = params[0]; + dt.set_name("dt"); auto A = [¶ms](std::size_t i) -> value & { return params[1 + i]; }; auto K = [¶ms](std::size_t i) -> value & { return params[1 + dim + i]; }; auto Q = params[1 + 2 * dim + 0]; @@ -93,10 +94,6 @@ auto test_ader::make_optimized_kernel(bool dump) Q.set_name("Q"); I.set_name("I"); - T dt0 = T{1.01}; - T num = T{1}; - int denom = 1; - auto bb = region_builder{fn_body}; auto const c0 = bb.add(make_constant_zero(element_ty)); auto const c1 = bb.add(make_constant_one(element_ty)); @@ -113,10 +110,14 @@ auto test_ader::make_optimized_kernel(bool dump) auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3)); bb.add(make_axpby(transpose::N, false, c1, dq, c1, i)); + int denom = 1; + auto cnum = c1; auto const static_offsets2 = std::array{0, 0}; for (std::int64_t n = 1; n <= N_; ++n) { - num *= dt0; + cnum = bb.add(make_arith(arithmetic::mul, cnum, dt)); denom *= n + 1; + auto cdenom = bb.add(make_constant(static_cast(denom), element_ty)); + auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom)); auto bn = Bd_aligned(N_ - n); auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); auto dq_nextv = bb.add(make_subview(dq_next, static_offsets2, {bn, P_})); @@ -129,7 +130,6 @@ auto test_ader::make_optimized_kernel(bool dump) dq_nextv)); } auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_})); - auto cfactor = bb.add(make_constant(num / denom, element_ty)); bb.add(make_axpby(transpose::N, false, cfactor, dq_next, c1, iv)); dq = dq_next; } From defbeca0802d7ec289eecfe7c23ffd75dc93e3cc Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 15 Oct 2024 12:13:01 +0200 Subject: [PATCH 059/297] Add support for for-instruction returning values Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 95 ++++++++++++++++++++++---------- include/tinytc/tinytc.h | 10 +++- include/tinytc/tinytc.hpp | 55 ++++++++++++++++-- include/tinytc/types.h | 11 ++-- include/tinytc/types.hpp | 1 + src/error.cpp | 2 + src/inst.cpp | 14 +++-- src/node/inst_node.cpp | 28 ++++++++-- src/node/inst_node.hpp | 27 ++++++--- src/node/region_node.cpp | 5 ++ src/node/region_node.hpp | 2 + src/parser/lexer.re | 5 +- src/parser/parser_impl.yy | 43 ++++++++++++++- src/pass/convert_to_opencl.cpp | 26 ++++++--- src/pass/dump_ir.cpp | 40 +++++++++++--- test/codegen/for.ir | 30 +++++++++- test/opt/constant-propagation.ir | 17 ++++++ 17 files changed, 334 insertions(+), 77 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 29fcd893..e800ca09 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -352,10 +352,10 @@ Overview A foreach loop that executes the loop's range [from; to) without any sequence guarantee. The region of a foreach is a *spmd region*. -The loop's range [from; to) is given by the first integer value and second integer value, -and the trip count is stored in the local identifier. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. +The trip count is stored in the first local identifier and is accessible within the loop body. +The loop's range [from; to) is given by the first and the second local identifier after the equals sign. +The integer type of the loop variable and the loop bounds is given after the colon and +the default integer type is ``index``. GEMM .... @@ -805,6 +805,63 @@ The product of the expand shape must be the same as the mode size. If the product of the expand shape is only known at runtime, then it is undefined behaviour if the dynamic product does not match the mode size. +For +... + +.. code:: abnf + + multi-value-instruction = "for" local-identifier "=" + local-identifier "," local-identifier ["," local-identifier] + ["init" "(" init-value-list ")" "->" "(" scalar-type-list ")" ] + [":" integer-type] region + init-value-list = init-value *("," init-value) + init-value = local-identifier "=" local-identifier + scalar-type-list = scalar-type *("," scalar-type) + +Overview +~~~~~~~~ + +A for loop. +Instructions in the for loop execute sequentially and its region is a *mixed region*. + +Arguments +~~~~~~~~~ + +The trip count is stored in the first local identifier and is accessible within the loop body. +The loop's range [from; to) is given by the first and the second local identifier after the equals sign, +and a step size may be given with the third local identifier after the equals sign. +The step size defaults to 1 if omitted. +The integer type of the loop variable and the loop bounds is given after the colon and +the default integer type is ``index``. + +Values that are given in the init-value-list may be carried from one iteration to the next. +The local identifier gives the name of the loop-carried value as it is accessible in the loop body. +The local identifier given on the right-hand side of the init-value expression determines +the initial value of the loop-carried value, and its type must coincide with the scalar-type-list. +When loop-carried values are present, the loop's last instruction must be a yield instruction that +updates the loop-carried values for the next iteration. +The number and types of the yielded values must correspond the scalar-type-list. + +Returns +~~~~~~~ + +The final value of the loop-carried values are returned by the for instruction. + + +Example: + + .. code:: + + %from = constant 2 -> i32 + %to = constant 6 -> i32 + %f0 = constant 0 -> i64 + %f1 = constant 1 -> i64 + %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn = arith.add %fn_2, %fn_1 : i64 + yield %fn_1, %fn : i64, i64 + } + ; %fn_1 contains the fourth Fibonacci number and %fn the fifth Fibonacci number + Fuse .... @@ -894,9 +951,8 @@ If .. code:: abnf - multi-value-instruction = "if" local-identifier ["->" "(" scalar-type-list ")"] - region ["else" region] - type-list = scalar-type *("," scalar-type) + multi-value-instruction =/ "if" local-identifier ["->" "(" scalar-type-list ")"] + region ["else" region] Overview ~~~~~~~~ @@ -906,8 +962,8 @@ Both regions are *mixed regions*. The condition must be of bool type. -Arguments -~~~~~~~~~ +Returns +~~~~~~~ The if instruction may return multiple values, where the number of values and the value types are given by the scalar-type-list. @@ -971,27 +1027,6 @@ Overview Returns the number of subgroups the work-group is divided in; i32 integer. -For -... - -.. code:: abnf - - instruction =/ "for" local-identifier "=" local-identifier "," local-identifier - ["," local-identifier] [":" integer-type] region - -Overview -~~~~~~~~ - -A for loop. -Instructions in the for loop execute sequentially and its region is a *mixed region*. - -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -A step size can be given with the third integer constant. -The step size defaults to 1 if omitted. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. - Size .... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 05033691..34c598d2 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -691,7 +691,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @brief Create for loop * * @code - * for %loop_var = %from, %to, %step : loop_var_type { } + * for %loop_var = %from, %to, %step init(initial_value_list) : loop_var_type { } * ; loop_var_type == type(%from) * ; loop_var_type == type(%to) * ; loop_var_type == type(%step) @@ -701,6 +701,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @param from [in] loop begion * @param to [in] loop bound * @param step [in][optional] loop step; can be nullptr + * @param init_list_size [in] length of init_value_list and return_type_list + * @param init_value_list [in][range(0, init_list_size)] array of initial values; can be + * nullptr if init_value_list is 0 + * @param return_type_list [in][range(0, init_list_size)] return type array; can be nullptr + * if return_type_list_size is 0 * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * @@ -708,6 +713,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt */ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + uint32_t init_list_size, + const tinytc_value_t *init_value_list, + const tinytc_data_type_t *return_type_list, tinytc_data_type_t loop_var_type, const tinytc_location_t *loc); diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 77fd6bbc..8aabaa1e 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1346,16 +1346,29 @@ inline inst make_sum(transpose tA, bool atomic, value alpha, value A, value beta * * @param from Loop variable start * @param to Loop variable bound - * @param step Loop variable step + * @param step Loop variable step; can be {} + * @param initial_value_list Array of initial values; can be {} + * @param return_type_list Array of returned types; can be {} * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ -inline inst make_for(value from, value to, value step, data_type loop_var_type, +inline inst make_for(value from, value to, value step, array_view initial_value_list, + array_view return_type_list, data_type loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, loop_var_type, &loc), loc); + if (initial_value_list.size() != return_type_list.size()) { + throw builder_error(status::ir_init_return_mismatch, loc); + } + auto len = return_type_list.size(); + if (len > std::numeric_limits::max()) { + throw std::out_of_range("return type list too long"); + } + const tinytc_value_t *il = reinterpret_cast(initial_value_list.data()); + CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, len, il, + return_type_list.data(), loop_var_type, &loc), + loc); return inst(instr); } @@ -1622,7 +1635,7 @@ class region_builder { template void for_loop(value from, value to, value step, data_type loop_var_ty, F &&f, location const &loc = {}) { - auto fi = ::tinytc::make_for(from, to, step, loop_var_ty, loc); + auto fi = ::tinytc::make_for(from, to, step, {}, {}, loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); auto loop_var = value{}; @@ -1634,6 +1647,40 @@ class region_builder { auto bb = region_builder{reg}; f(bb, loop_var); } + /** + * @brief Build for-loop with functor f(region_builder&, array_view) -> void + * + * The loop trip count is the first value in the array_view. + * The following values are the loop-carried values. + * + * @tparam F Functor type + * @param from Loop variable start + * @param to Loop variable bound + * @param step Loop variable step + * @param initial_value_list Array of initial values; can be {} + * @param return_type_list Array of returned types; can be {} + * @param loop_var_ty Type of loop variable + * @param f Functor + * @param loc Source code location + */ + template + void for_loop(value from, value to, value step, array_view initial_value_list, + array_view return_type_list, data_type loop_var_ty, F &&f, + location const &loc = {}) { + auto fi = ::tinytc::make_for(from, to, step, initial_value_list, return_type_list, + loop_var_ty, loc); + auto reg = region{}; + fi.get_regions(reg); + auto num_params = reg.get_parameters({}); + auto params = std::vector(num_params); + reg.get_parameters(params); + if (!reg || num_params != 1 + return_type_list.size()) { + throw status::internal_compiler_error; + } + reg_.add_instruction(std::move(fi)); + auto bb = region_builder{reg}; + f(bb, array_view(params)); + } /** * @brief Build foreach-loop with functor f(region_builder&, value) -> void * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 8989798a..8af0704a 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -65,11 +65,12 @@ typedef enum { tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x117, ///< Instruction does not support int type - tinytc_status_ir_i1_unsupported = 0x118, ///< Instruction does not support i1 type - tinytc_status_ir_complex_unsupported = 0x119, ///< Instruction does not support complex type - tinytc_status_ir_forbidden_cast = 0x11a, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x11b, ///< Invalid beta value + tinytc_status_ir_int_unsupported = 0x117, ///< Instruction does not support int type + tinytc_status_ir_i1_unsupported = 0x118, ///< Instruction does not support i1 type + tinytc_status_ir_complex_unsupported = 0x119, ///< Instruction does not support complex type + tinytc_status_ir_forbidden_cast = 0x11a, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x11b, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x11c, ///< Mismatch of init values and returned values // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 4b20478b..73b6d99f 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -80,6 +80,7 @@ enum class status { ir_complex_unsupported = tinytc_status_ir_complex_unsupported, ir_forbidden_cast = tinytc_status_ir_forbidden_cast, ir_invalid_beta = tinytc_status_ir_invalid_beta, + ir_init_return_mismatch = tinytc_status_ir_init_return_mismatch, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index e9a9c6a6..ba9e8df3 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -175,6 +175,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Forbidden cast"; case tinytc_status_ir_invalid_beta: return "beta must be constant and 0 or 1 for atomic linear algebra operations"; + case tinytc_status_ir_init_return_mismatch: + return "The number or types of the initial values does not match the return type list"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/inst.cpp b/src/inst.cpp index 9ce8cc53..53b19cb6 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -511,14 +511,20 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t } tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, tinytc_value_t to, - tinytc_value_t step, tinytc_data_type_t loop_var_type, + tinytc_value_t step, uint32_t init_list_size, + const tinytc_value_t *initial_value_list, + const tinytc_data_type_t *return_type_list, + tinytc_data_type_t loop_var_type, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr) { + if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr || + (init_list_size != 0 && (initial_value_list == nullptr || return_type_list == nullptr))) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(from, to, step, loop_var_type, get_optional(loc)).release(); + *instr = std::make_unique( + from, to, step, array_view{initial_value_list, init_list_size}, + array_view{return_type_list, init_list_size}, loop_var_type, get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index a7297672..7fee5a0a 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -54,14 +54,33 @@ blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinyt } loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, - tinytc_data_type_t loop_var_type, location const &lc) - : standard_inst{tid, step0 ? 3 : 2} { + array_view init_values, + array_view return_types, tinytc_data_type_t loop_var_type, + location const &lc) + : standard_inst{tid, (step0 ? 3 : 2) + static_cast(init_values.size()), + static_cast(return_types.size())} { + if (init_values.size() != return_types.size()) { + throw compilation_error(loc(), status::ir_init_return_mismatch); + } + op(op_from, from0); op(op_to, to0); if (step0) { op(op_step, step0); } - body().set_params(array_view{loop_var_type}, lc); + + body().set_num_params(1 + return_types.size()); + body().set_param(0, loop_var_type, lc); + for (std::size_t i = 0; i < return_types.size(); ++i) { + body().set_param(1 + i, return_types[i], lc); + result(i) = value_node{return_types[i], this, lc}; + } + for (std::size_t i = 0; i < init_values.size(); ++i) { + if (init_values[i]->ty() != return_types[i]) { + throw compilation_error(loc(), status::ir_init_return_mismatch); + } + op(op_init() + i, init_values[i]); + } loc(lc); auto lvt = get_scalar_type(loc(), loop_var()); @@ -495,7 +514,8 @@ ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, foreach_inst::foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, location const &loc) - : loop_inst{IK::foreach_loop, std::move(from), std::move(to), {}, loop_var_type, loc} { + : loop_inst{ + IK::foreach_loop, std::move(from), std::move(to), nullptr, {}, {}, loop_var_type, loc} { child_region(0).kind(region_kind::spmd); } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 8ddb7fef..b286b1c7 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -363,22 +363,34 @@ class blas_a3_inst : public standard_inst<5, 0> { bool atomic_; }; -class loop_inst : public standard_inst<3, 0, 1> { +class loop_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; } enum op_number { op_from = 0, op_to = 1, op_step = 2 }; loop_inst(IK tid, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + array_view init_values, array_view return_types, tinytc_data_type_t loop_var_type, location const &loc = {}); inline auto from() const -> tinytc_value const & { return op(op_from); } inline auto to() const -> tinytc_value const & { return op(op_to); } - inline auto has_step() const -> bool { return get_use(op_step).get() != nullptr; } + inline auto has_step() const -> bool { return op_init() == 3; } inline auto step() const -> tinytc_value const & { return op(op_step); } inline auto body() -> tinytc_region & { return child_region(0); } inline auto body() const -> tinytc_region const & { return child_region(0); } inline auto loop_var() -> tinytc_value & { return body().param(0); } inline auto loop_var() const -> tinytc_value const & { return body().param(0); } + inline auto iter_arg(std::int64_t no) -> tinytc_value & { return body().param(no + 1); } + inline auto iter_arg(std::int64_t no) const -> tinytc_value const & { + return body().param(no + 1); + } + inline auto iter_init(std::int64_t no) -> tinytc_value & { return op(op_init() + no); } + inline auto iter_init(std::int64_t no) const -> tinytc_value const & { + return op(op_init() + no); + } + + private: + inline auto op_init() const -> std::int64_t { return num_operands() - num_results(); } }; class alloca_inst : public standard_inst<0, 1> { @@ -595,13 +607,12 @@ class ger_inst : public blas_a3_inst { class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } - inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, - location const &loc = {}) - : for_inst{std::move(from), std::move(to), {}, loop_var_type, loc} {} inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_data_type_t loop_var_type, location const &loc = {}) - : loop_inst{IK::for_loop, std::move(from), std::move(to), - std::move(step), loop_var_type, loc} {} + array_view init_values, + array_view return_types, tinytc_data_type_t loop_var_type, + location const &loc = {}) + : loop_inst{IK::for_loop, std::move(from), std::move(to), std::move(step), + std::move(init_values), std::move(return_types), loop_var_type, loc} {} }; class foreach_inst : public loop_inst { diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 90421c98..31266172 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -37,3 +37,8 @@ void tinytc_region::set_params(array_view param_types, locat params_[i] = tinytc_value{param_types[i], nullptr, lc}; } } + +void tinytc_region::set_num_params(std::size_t num_params) { params_.resize(num_params); } +void tinytc_region::set_param(std::size_t idx, tinytc_data_type_t param_type, location const &lc) { + params_[idx] = tinytc_value{param_type, nullptr, lc}; +} diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index f142db8a..d07e9b07 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -64,6 +64,8 @@ struct tinytc_region final { } inline auto num_params() const noexcept -> std::int64_t { return params_.size(); } void set_params(tinytc::array_view param_types, tinytc::location const &lc); + void set_num_params(std::size_t num_params); + void set_param(std::size_t idx, tinytc_data_type_t param_type, tinytc::location const &lc); private: static auto inst_list_offset() -> std::size_t { diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 1f570de9..3832c4ce 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -96,10 +96,11 @@ lex: ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } + "init" { adv_loc(); return parser::make_INIT(loc_); } "local" { adv_loc(); return parser::make_LOCAL(loc_); } "global" { adv_loc(); return parser::make_GLOBAL(loc_); } - ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } - ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } + ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } + ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } // constants "true" { adv_loc(); return parser::make_INTEGER_CONSTANT(1, loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 394ab6fb..13d7c99d 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -17,6 +17,7 @@ #include #include #include + #include #include namespace tinytc { @@ -95,6 +96,7 @@ TRANS ".t" ATOMIC ".atomic" ATOMIC_ADD ".atomic_add" + INIT "init" LOCAL "local" GLOBAL "global" LOCAL_ATTR ".local" @@ -175,6 +177,9 @@ %nterm ger_inst %nterm transpose %nterm for_inst +%nterm >, std::vector, std::vector>> optional_loop_carried_values +%nterm >, std::vector>> init_value_list +%nterm , tinytc_value_t>> init_value %nterm optional_step %nterm foreach_inst %nterm hadamard_inst @@ -561,19 +566,27 @@ ger_inst: ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] optional_step for_loop_var_type { + FOR LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] for_loop_var_type { check_type($from, $for_loop_var_type, @from, @for_loop_var_type); check_type($to, $for_loop_var_type, @to, @for_loop_var_type); if ($optional_step) { check_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); } try { + auto &[lcv_id, lcv_init, lcv_type] = $lcv; + if (lcv_init.size() != lcv_type.size()) { + throw parser::syntax_error(@lcv, "Length of init value list must match scalar type list"); + } location loc = @FOR; loc.end = @for_loop_var_type.end; - auto inode = std::make_unique($from, $to, $optional_step, $for_loop_var_type, loc); + auto inode = std::make_unique($from, $to, $optional_step, lcv_init, + lcv_type, $for_loop_var_type, loc); ctx.push_scope(); auto &loop_var = inode->loop_var(); ctx.val($loop_var, loop_var, @loop_var); + for (std::int64_t i = 0; i < inode->num_results(); ++i) { + ctx.val(lcv_id[i], inode->iter_arg(i), @lcv); + } ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { @@ -590,6 +603,31 @@ for_inst: optional_step: %empty { $$ = {}; } | COMMA var { $$ = $var; } +; + +optional_loop_carried_values: + %empty { $$ = {}; } + | INIT LPAREN init_value_list RPAREN RETURNS LPAREN scalar_type_list RPAREN { + $$ = std::make_tuple(std::move($init_value_list.first), std::move($init_value_list.second), + std::move($scalar_type_list)); + } +; + +init_value_list: + init_value { + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); + } + | init_value_list COMMA init_value { + $$ = std::move($1); + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); + } +; + +init_value: + LOCAL_IDENTIFIER EQUALS var { $$ = std::make_pair($LOCAL_IDENTIFIER, $var); } +; foreach_inst: FOREACH LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] for_loop_var_type { @@ -709,6 +747,7 @@ valued_inst: | compare_inst { $$ = std::move($1); } | constant_inst { $$ = std::move($1); } | expand_inst { $$ = std::move($1); } + | for_inst { $$ = std::move($1); } | fuse_inst { $$ = std::move($1); } | group_id_inst { $$ = std::move($1); } | group_size_inst { $$ = std::move($1); } diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index b5b9b813..63d43bee 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -784,19 +784,31 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { return {bb.get_product()}; } -std::vector convert_to_opencl_pass::operator()(for_inst const &p) { +std::vector convert_to_opencl_pass::operator()(for_inst const &in) { auto clinst = std::vector{}; + yielded_vars_.push_back(std::vector{}); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto v = declare(in.iter_arg(i)); + clinst.emplace_back(clir::declaration_assignment(visit(*this, *in.iter_arg(i).ty()), v, + val(in.iter_init(i)))); + yielded_vars_.back().emplace_back(std::move(v)); + } - auto lv = declare(p.loop_var()); - auto lv_ty = visit(*this, *p.loop_var().ty()); - auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(p.from())); - auto condition = lv < val(p.to()); - auto step = p.has_step() ? clir::add_into(lv, val(p.step())) : ++lv; - auto body = run_on_region(p.body()); + auto lv = declare(in.loop_var()); + auto lv_ty = visit(*this, *in.loop_var().ty()); + auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(in.from())); + auto condition = lv < val(in.to()); + auto step = in.has_step() ? clir::add_into(lv, val(in.step())) : ++lv; + auto body = run_on_region(in.body()); clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + clinst.emplace_back(clir::declaration_assignment( + visit(*this, *in.result(i).ty()), declare(in.result(i)), yielded_vars_.back()[i])); + } + yielded_vars_.pop_back(); return clinst; } diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index a122779f..4c1bd425 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -271,21 +271,40 @@ void dump_ir_pass::operator()(ger_inst const &g) { dump_blas_a3(static_cast(g)); } -void dump_ir_pass::operator()(for_inst const &p) { +void dump_ir_pass::operator()(for_inst const &in) { + if (in.num_results() > 0) { + do_with_infix(in.result_begin(), in.result_end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } *os_ << "for "; - dump_val(p.loop_var()); + dump_val(in.loop_var()); *os_ << "="; - dump_val(p.from()); + dump_val(in.from()); *os_ << ","; - dump_val(p.to()); - if (p.has_step()) { + dump_val(in.to()); + if (in.has_step()) { *os_ << ","; - dump_val(p.step()); + dump_val(in.step()); + } + if (in.num_results() > 0) { + *os_ << " init("; + for (std::int64_t i = 0; i < in.num_results(); ++i) { + if (i != 0) { + *os_ << ","; + } + dump_val(in.iter_arg(i)); + *os_ << "="; + dump_val(in.iter_init(i)); + } + *os_ << ") -> ("; + do_with_infix(in.result_begin(), in.result_end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ")"; } *os_ << " : "; - visit(*this, *p.loop_var().ty()); + visit(*this, *in.loop_var().ty()); *os_ << " "; - dump_region(p.body()); + dump_region(in.body()); } void dump_ir_pass::operator()(foreach_inst const &p) { @@ -307,6 +326,11 @@ void dump_ir_pass::operator()(hadamard_inst const &g) { } void dump_ir_pass::operator()(if_inst const &in) { + + if (in.num_results() > 0) { + do_with_infix(in.result_begin(), in.result_end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } *os_ << "if "; dump_val(in.condition()); *os_ << " "; diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 3b6564f6..6f57d11b 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -6,11 +6,37 @@ func @for1() { %lb0 = constant 0 -> index %ub0 = constant 10 -> index for %0 = %lb0,%ub0 { - ; CHECK: for (long x = lb0; x < ub0; ++x) } %lb1 = constant -2 -> i16 %ub1 = constant 2 -> i16 for %1 = %lb1,%ub1 : i16 { - ; CHECK: for (short x = lb1; x < ub1; ++x) } +; CHECK-LABEL: void for1({{.*}} +; CHECK: for (long x = lb0; x < ub0; ++x) +; CHECK: for (short x = lb1; x < ub1; ++x) +} + +func @for2(%fib: memref) { + %from = constant 2 -> i32 + %to = constant 6 -> i32 + %f0 = constant 0 -> i64 + %f1 = constant 1 -> i64 + %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn = arith.add %fn_2, %fn_1 : i64 + yield %fn_1, %fn : i64, i64 + } + store %fn, %fib[] : memref +; CHECK-LABEL: void for2({{.*}} +; CHECK: long f0 = 0ll; +; CHECK-NEXT: long f1 = 1ll; +; CHECK-NEXT: long fn_2 = f0; +; CHECK-NEXT: long fn_1 = f1; +; CHECK-NEXT: for (int n = from; n < to; ++n) { +; CHECK-NEXT: long fn = fn_2 + fn_1; +; CHECK-NEXT: fn_2 = fn_1; +; CHECK-NEXT: fn_1 = fn; +; CHECK-NEXT: } +; CHECK-NEXT: long fn_11 = fn_2; +; CHECK-NEXT: long fn = fn_1; +; CHECK-NEXT: *fib = fn; } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 702135b6..fbd3adc8 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -36,6 +36,23 @@ func @known_loop_bounds() { ; CHECK-NEXT: for %i=%lb,%1 : index { } +func @known_loop_iter_args() { + %c1 = constant 1 -> index + %c5 = constant 5 -> index + %0 = arith.add %c1, %c5 : index + %2 = for %i=%c1,%c5 init(%1=%0) -> (index) { + yield %1 : index + } +; CHECK-LABEL: func @known_loop_iter_args({{.*}} +; CHECK-NEXT: %c1 = constant 1 -> index +; CHECK-NEXT: %c5 = constant 5 -> index +; CHECK-NEXT: %0 = constant 6 -> index +; CHECK-NEXT: %1 = arith.add %c1, %c5 : index +; CHECK-NEXT: %3 = for %i=%c1,%c5 init(%2=%0) -> (index) : index { +; CHECK-NEXT: yield %2 : index +; CHECK-NEXT: } +} + func @known_arith() { %0 = constant 1 -> i64 %1 = constant 2 -> i64 From 7ff1a3505ab05b5210749a0d666872417dea95b0 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 15 Oct 2024 12:48:10 +0200 Subject: [PATCH 060/297] Update linalg lowering Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.hpp | 7 ++++--- src/pass/lower_linalg.cpp | 27 ++++++++++++++++----------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 8aabaa1e..44d5718e 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1664,9 +1664,9 @@ class region_builder { * @param loc Source code location */ template - void for_loop(value from, value to, value step, array_view initial_value_list, + auto for_loop(value from, value to, value step, array_view initial_value_list, array_view return_type_list, data_type loop_var_ty, F &&f, - location const &loc = {}) { + location const &loc = {}) -> std::vector { auto fi = ::tinytc::make_for(from, to, step, initial_value_list, return_type_list, loop_var_ty, loc); auto reg = region{}; @@ -1677,9 +1677,10 @@ class region_builder { if (!reg || num_params != 1 + return_type_list.size()) { throw status::internal_compiler_error; } - reg_.add_instruction(std::move(fi)); + auto results = add_multivalued(std::move(fi)); auto bb = region_builder{reg}; f(bb, array_view(params)); + return results; } /** * @brief Build foreach-loop with functor f(region_builder&, value) -> void diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 2b171ce3..3c4e1326 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -196,17 +196,22 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { tile_loop_by_sgs_standard( bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, [&](region_builder &bb, value mm) { - auto zero = bb.add(make_constant(0, index_ty)); - // @todo need for loop that yields values - bb.for_loop(zero, c_trip_count, index_ty, [&](region_builder &bb, value n) { - auto index_list = std::array{mm, n}; - if (in.tA() == transpose::T) { - std::swap(index_list[0], index_list[1]); - } - auto a = bb.add(make_load(&in.A(), index_list, in.loc())); - blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {mm}, - in.loc()); - }); + auto from = bb.add(make_constant(0, index_ty)); + auto zero = bb.add(make_constant_zero(bt->element_data_ty())); + auto acc = + bb.for_loop(from, c_trip_count, {}, {zero}, {bt->element_data_ty()}, index_ty, + [&](region_builder &bb, array_view args) { + auto index_list = std::array{mm, args[0]}; + if (in.tA() == transpose::T) { + std::swap(index_list[0], index_list[1]); + } + auto a = bb.add(make_load(&in.A(), index_list, in.loc())); + auto sum = mixed_precision_arithmetic(bb, arithmetic::add, + args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); + blas_update(bb, in.atomic(), &in.alpha(), acc[0], &in.beta(), &in.B(), {mm}, + in.loc()); }); } return parallel; From aa66075ebe025d25ff6f978f80d267a30124651b Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 16 Oct 2024 14:28:54 +0200 Subject: [PATCH 061/297] Specify cooperative matrix Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 189 +++++++++++++++++++++++++++++++------- 1 file changed, 156 insertions(+), 33 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index e800ca09..1c24916e 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -16,7 +16,9 @@ The unit of execution described by a function written in the tensor language is called a **kernel**. Kernels are launched in batches, where each instance of the kernel is called a work-group. The kernel has access to its group id that is used to select the work done in the work group. -Each work group consists of a fixed number of work-items that execute concurrently. +Each work group consists of a fixed number of subgroups that execute concurrently. +Subgroups can be further divided into work-items, where the number of work-items per subgroup +is given by the subgroup size. The language distinguishes between *collective*, *SPMD*, and *mixed* instructions. A collective instruction distributes the work among the work-items in an implementation-defined manner. @@ -251,6 +253,21 @@ Dynamic values ('?') may appear in the memref-type and in the offset. These values are stored in the dope vector; the calling convention for groups is implementation-defined. +Cooperative matrix type +----------------------- + +.. code:: abnf + + coopmatrix-type = "coopmatrix<" scalar-type 2*2("x" integer-constant) "," matrix-use ">" + matrix-use = "matrix_a" / "matrix_b" / "matrix_acc" + +The coopmatrix represents a matrix distributed across a subgroup, where each work-item in a subgroup +stores a part of the matrix. +The scalar-type specifies the matrix element type, the first integer-constant the number of rows, +and the second integer-constant the number of columns. +The matrix-use may affect the distribution of the matrix in the subgroup, and the name refers to the +position of the matrix in a matrix multiplication. + Instructions ============ @@ -587,31 +604,33 @@ Arithmetic (binary) ".and" / ".or" / ".xor" - value-instruction =/ "arith" arith-binary-type local-identifier "," local-identifier ":" scalar-type + value-instruction =/ "arith" arith-binary-type local-identifier "," local-identifier + ":" (scalar-type / coopmatrix-type) Overview ~~~~~~~~ -Binary arithmetic operation on scalars. -Both operands, as well as the returned type, have the same scalar type. +Binary arithmetic operation on scalars and cooperative matrices. +Both operands, as well as the returned type, have the same (underlying) scalar type. +Arithmetic on cooperative matrices is done component-wise. The following table shows the operations' description and the types that are allowed for the operation. The backslash "\\" is used to exclude types from the list of allowed types. -==== ============================ ================================================================ -Op Allowed type Description -==== ============================ ================================================================ -.add scalar-type Sum of operands -.sub scalar-type Difference of operands -.mul scalar-type Product of operands -.div scalar-type Quotient of operands -.rem scalar-type \\ complex-type Remainder from the division of operands -.shl integer-type \\ i1 Left shift first operand by second operand -.shr integer-type \\ i1 Arithmetic right shift first operand by second operand -.and integer-type Bitwise and -.or integer-type Bitwise or -.xor integer-type Bitwise xor -==== ============================ ================================================================ +==== ============================= ================================================================ +Op Allowed type Description +==== ============================= ================================================================ +.add scalar-type / coopmatrix-type Sum of operands +.sub scalar-type / coopmatrix-type Difference of operands +.mul scalar-type / coopmatrix-type Product of operands +.div scalar-type / coopmatrix-type Quotient of operands +.rem scalar-type \\ complex-type Remainder from the division of operands +.shl integer-type \\ i1 Left shift first operand by second operand +.shr integer-type \\ i1 Arithmetic right shift first operand by second operand +.and integer-type Bitwise and +.or integer-type Bitwise or +.xor integer-type Bitwise xor +==== ============================= ================================================================ Arithmetic (unary) .................. @@ -619,12 +638,13 @@ Arithmetic (unary) .. code:: abnf arith-unary-type = ".abs" / ".neg" / ".not" / ".conj" / ".im" / ".re" - value-instruction =/ "arith" arith-unary-type local-identifier ":" scalar-type + value-instruction =/ "arith" arith-unary-type local-identifier + ":" (scalar-type / coopmatrix-type) Overview ~~~~~~~~ -Unary arithmetic operation on scalars. +Unary arithmetic operation on scalars and cooperative matrices. For integer and floating point input, the returned value has the same type as the operand. For complex input, the returned value has the underlying floating point type for ".abs", ".im", and ".re", and the returned value has the same type as the operand @@ -632,16 +652,16 @@ for ".neg" and ".conj". The following table shows the operations' description and the types that are allowed for the operation. -===== ============ ============================================================================== +===== ============================= ============================= Op Allowed type Description -===== ============ ============================================================================== -.abs scalar-type Compute absolute value -.neg scalar-type Negation -.not integer-type Bitwise not -.conj complex-type Complex conjugate -.im complex-type Extract imaginary part -.re complex-type Extract real part -===== ============ ============================================================================== +===== ============================= ============================= +.abs scalar-type Compute absolute value +.neg scalar-type / coopmatrix-type Negation +.not integer-type Bitwise not +.conj complex-type Complex conjugate +.im complex-type Extract imaginary part +.re complex-type Extract real part +===== ============================= ============================= Barrier ....... @@ -677,11 +697,12 @@ Cast .. code:: abnf value-instruction =/ "cast" local-identifier ":" scalar-type "->" scalar-type + value-instruction =/ "cast" local-identifier ":" coopmatrix-type "->" coopmatrix-type Overview ~~~~~~~~ -Cast scalar values. +Cast scalar values or cooperative matrices. Casts from complex types to non-complex types are forbidden. The following table summarizes the casts and the mapping to SPIR-V: @@ -732,15 +753,117 @@ Constant .. code:: abnf - value-instruction =/ "constant" constant "->" scalar-type + value-instruction =/ "constant" constant "->" (scalar-type / coopmatrix-type) Overview ~~~~~~~~ Sets the result value to a constant value. -The type of the constant must match the scalar type +The type of the constant must match the (underlying) scalar type (e.g. an integer type requires an integer-constant and a floating type requires a floating-constant). +When the result is a cooperative matrix, all entries are set to the same constant value. + +Cooperative matrix load +....................... + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_load" [".checked"] transpose + local-identifier "[" local-identifier "," local-identifier "]" + ":" memref-type "->" coopmatrix-type + +Overview +~~~~~~~~ + +Load a cooperative matrix from a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is loaded from memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are given by + +.. math:: + + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + i) S_1 + (y + j) S_2] + +When the checked flag is set, memory loads that would be out of bounds are not executed and the corresponding +value in the cooperative matrix are set to 0. + +When the transpose modifier ".t" is given, we have + +.. math:: + + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + j) S_1 + (y + i) S_2] + +Arguments +~~~~~~~~~ + +The first operand must have memref type of dimension 2 with the same underlying scalar type +as the coopmatrix type. +The indices must be of ``index`` type. + +Cooperative matrix mul add +.......................... + + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_mul_add" + local-identifier "," local-identifier "," local-identifier + ":" coopmatrix-type "," coopmatrix-type "," coopmatrix-type + "->" coopmatrix-type + +Overview +~~~~~~~~ + +Matrix mul add returns the value of + +.. math:: + + AB + C, + +where A, B, and C are matrices given by the three operands. + +The operands must have cooperative matrix type, where the first operand has shape :math:`M\times K` +with use "matrix_a", the second operand has shape :math:`K\times N` with use "matrix_b", +and the third operand and the result have shape :math:`M\times N` with use "matrix_acc". + +The underlying scalar types of the operands and the result do not need to match. + +Cooperative matrix store +........................ + +.. code:: abnf + + instruction =/ "cooperative_matrix_store" [".checked"] [store-flag] + local-identifier "," local-identifier "[" local-identifier "," local-identifier "]" + ":" coopmatrix-type "," memref-type + +Overview +~~~~~~~~ + +Store a cooperative matrix value in a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is written to memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are written to + +.. math:: + + \forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2] := A_{ij} + +If the checked flag is set, only memory locations that are in-bounds are written. + +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add flag is set, the coopmatrix is added to the memref atomically. + +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the the real and imaginary separately. + +Arguments +~~~~~~~~~ + +The first operand must have cooperative matrix type with the same underlying scalar type as the memref type. +The indices must be of ``index`` type. + Expand ...... @@ -1221,7 +1344,7 @@ Subgroup local id Overview ~~~~~~~~ -Returns the work-item id within the sub-group; i32 integer from 0 to subgroup_size - 1. +Returns the work-item id within the subgroup; i32 integer from 0 to subgroup_size - 1. Sample code =========== From db369171c6cd13c8a457df484846672b3d8b10b8 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 17 Oct 2024 10:17:05 +0200 Subject: [PATCH 062/297] Introduce cooperative matrix type Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.yaml | 3 + docs/api/builder_cxxapi.yaml | 5 +- docs/manual/tensor-ir.rst | 32 +++-- include/tinytc/tinytc.h | 19 +++ include/tinytc/tinytc.hpp | 31 ++++ include/tinytc/types.h | 56 +++++--- include/tinytc/types.hpp | 10 ++ src/compiler_context_cache.hpp | 1 + src/data_type.cpp | 26 ++++ src/error.cpp | 6 + src/node/data_type_node.cpp | 47 ++++++ src/node/data_type_node.hpp | 43 +++++- src/node/inst_node.cpp | 232 ++++++++++++++++++------------ src/parser/lexer.re | 6 + src/parser/parser_impl.yy | 19 ++- src/pass/constant_propagation.cpp | 27 +++- src/pass/convert_to_opencl.cpp | 170 +++++++++++++++------- src/pass/convert_to_opencl.hpp | 1 + src/pass/dump_ir.cpp | 5 + src/pass/dump_ir.hpp | 1 + 20 files changed, 561 insertions(+), 179 deletions(-) diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 1e25b4e9..c313a952 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -7,6 +7,7 @@ Builder C-API: - tinytc_arithmetic_t - tinytc_arithmetic_unary_t - tinytc_cmp_condition_t + - tinytc_matrix_use_t - tinytc_scalar_type_t - tinytc_store_flag_t - tinytc_transpose_t @@ -17,6 +18,7 @@ Builder C-API: - tinytc_arithmetic_to_string - tinytc_arithmetic_unary_to_string - tinytc_cmp_condition_to_string + - tinytc_matrix_use_to_string - tinytc_scalar_type_size - tinytc_scalar_type_to_string - tinytc_store_flag_to_string @@ -41,6 +43,7 @@ Builder C-API: - const_tinytc_value_t Data Type: function: + - tinytc_coopmatrix_type_get - tinytc_group_type_get - tinytc_memref_type_get - tinytc_scalar_type_get diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index c06aaa38..5e488d39 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -7,6 +7,7 @@ Builder C++-API: - tinytc::arithmetic - tinytc::arithmetic_unary - tinytc::cmp_condition + - tinytc::matrix_use - tinytc::scalar_type - tinytc::store_flag - tinytc::transpose @@ -16,6 +17,7 @@ Builder C++-API: - tinytc::to_string(arithmetic) - tinytc::to_string(arithmetic_unary) - tinytc::to_string(cmp_condition) + - tinytc::to_string(matrix_use) - tinytc::to_string(scalar_type) - tinytc::to_string(store_flag) - tinytc::to_string(transpose) @@ -29,8 +31,9 @@ Builder C++-API: - tinytc::dynamic Data Type: function: - - tinytc::get_memref + - tinytc::get_coopmatrix - tinytc::get_group + - tinytc::get_memref - tinytc::get_scalar struct: - tinytc::to_scalar_type diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 1c24916e..75b221eb 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -611,7 +611,7 @@ Overview ~~~~~~~~ Binary arithmetic operation on scalars and cooperative matrices. -Both operands, as well as the returned type, have the same (underlying) scalar type. +Both operands, as well as the returned type, have the same scalar or component type. Arithmetic on cooperative matrices is done component-wise. The following table shows the operations' description and the types that are allowed for the operation. @@ -646,7 +646,7 @@ Overview Unary arithmetic operation on scalars and cooperative matrices. For integer and floating point input, the returned value has the same type as the operand. -For complex input, the returned value has the underlying floating point type +For complex input, the returned value has the component floating point type for ".abs", ".im", and ".re", and the returned value has the same type as the operand for ".neg" and ".conj". @@ -703,8 +703,10 @@ Overview ~~~~~~~~ Cast scalar values or cooperative matrices. +The shape and the use the coopmatrix types must match. Casts from complex types to non-complex types are forbidden. -The following table summarizes the casts and the mapping to SPIR-V: +The following table summarizes the casts and the mapping to SPIR-V +(the casts are done component-wise for coopmatrix types): ============= ============= ================================================== Operand type Result type SPIR-V Op @@ -759,7 +761,7 @@ Overview ~~~~~~~~ Sets the result value to a constant value. -The type of the constant must match the (underlying) scalar type +The type of the constant must match the scalar or component type (e.g. an integer type requires an integer-constant and a floating type requires a floating-constant). When the result is a cooperative matrix, all entries are set to the same constant value. @@ -797,14 +799,13 @@ When the transpose modifier ".t" is given, we have Arguments ~~~~~~~~~ -The first operand must have memref type of dimension 2 with the same underlying scalar type +The first operand must have memref type of dimension 2 with the same component type as the coopmatrix type. The indices must be of ``index`` type. Cooperative matrix mul add .......................... - .. code:: abnf value-instruction =/ "cooperative_matrix_mul_add" @@ -827,7 +828,22 @@ The operands must have cooperative matrix type, where the first operand has shap with use "matrix_a", the second operand has shape :math:`K\times N` with use "matrix_b", and the third operand and the result have shape :math:`M\times N` with use "matrix_acc". -The underlying scalar types of the operands and the result do not need to match. +The component types of the operands and the result do not need to match. + +Cooperative matrix scale +........................ + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_scale" + local-identifier "," local-identifier + ":" scalar-type "," coopmatrix-type + +Overview +~~~~~~~~ + +Scale a matrix by a scalar. +The scalar type of the scalar and the component type of the matrix must match. Cooperative matrix store ........................ @@ -861,7 +877,7 @@ for the the real and imaginary separately. Arguments ~~~~~~~~~ -The first operand must have cooperative matrix type with the same underlying scalar type as the memref type. +The first operand must have cooperative matrix type with the same component type as the memref type. The indices must be of ``index`` type. Expand diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 34c598d2..97b09b31 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -102,6 +102,23 @@ TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, int64_t offset, const tinytc_location_t *loc); +/** + * @brief Get coopmatrix data type + * + * Note: modifies compiler context + * + * @param dt [out] pointer to the data type object created + * @param scalar_ty [in] component type + * @param rows [in] number of rows + * @param cols [in] number of cols + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_coopmatrix_type_get(tinytc_data_type_t *dt, + tinytc_data_type_t scalar_ty, int64_t rows, + int64_t cols, tinytc_matrix_use_t use, + const tinytc_location_t *loc); //////////////////////////// /////////// Value ////////// @@ -154,6 +171,8 @@ TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); //! Convert cmp condition to string TINYTC_EXPORT char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond); +//! Convert matrix use to string +TINYTC_EXPORT char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u); //! Convert store flag to string TINYTC_EXPORT char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag); //! Convert transpose to string diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 44d5718e..24eded3e 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -605,6 +605,26 @@ inline data_type get_group(data_type memref_ty, std::int64_t offset = 0, locatio return gt; } +/** + * @brief Get a coopmatrix data type + * + * @param scalar_ty Component type + * @param rows Number of rows + * @param cols Number of cols + * @param use Matrix use + * @param loc Source code location + * + * @return Data type + */ +inline data_type get_coopmatrix(data_type scalar_ty, std::int64_t rows, std::int64_t cols, + matrix_use use, location const &loc = {}) { + tinytc_data_type_t ct; + CHECK_STATUS_LOC(tinytc_coopmatrix_type_get(&ct, scalar_ty, rows, cols, + static_cast<::tinytc_matrix_use_t>(use), &loc), + loc); + return ct; +} + //////////////////////////// /////////// Value ////////// //////////////////////////// @@ -684,6 +704,17 @@ inline char const *to_string(cmp_condition cond) { return ::tinytc_cmp_condition_to_string(static_cast<::tinytc_cmp_condition_t>(cond)); } +/** + * @brief Convert matrix use to string + * + * @param u Matrix use + * + * @return C-string + */ +inline char const *to_string(matrix_use u) { + return ::tinytc_matrix_use_to_string(static_cast<::tinytc_matrix_use_t>(u)); +} + /** * @brief Convert store flag to string * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 8af0704a..dc430f25 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -49,28 +49,33 @@ typedef enum { tinytc_status_ir_scalar_mismatch = 0x104, ///< Mismatch of scalar types tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices tinytc_status_ir_expected_scalar = 0x106, ///< Expected a value of scalar type - tinytc_status_ir_expected_memref = 0x107, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x108, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x109, ///< Expected a value of memref or group type - tinytc_status_ir_expected_vector_or_matrix = 0x10a, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x10b, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x10c, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x10d, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x10e, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x10f, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x111, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x112, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x113, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x114, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x115, ///< Expected global address space - tinytc_status_ir_invalid_offset = 0x116, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x117, ///< Instruction does not support int type - tinytc_status_ir_i1_unsupported = 0x118, ///< Instruction does not support i1 type - tinytc_status_ir_complex_unsupported = 0x119, ///< Instruction does not support complex type - tinytc_status_ir_forbidden_cast = 0x11a, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x11b, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x11c, ///< Mismatch of init values and returned values + tinytc_status_ir_expected_coopmatrix = 0x107, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_coopmatrix_or_scalar = + 0x108, ///< Expected a value of coopmatrix or scalar type + tinytc_status_ir_expected_memref = 0x109, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x10a, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x10b, ///< Expected a value of memref or group type + tinytc_status_ir_expected_vector_or_matrix = 0x10c, ///< Expected a vector or marix + tinytc_status_ir_unexpected_yield = 0x10d, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x10e, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x10f, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x110, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x111, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x112, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x113, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x114, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x115, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x116, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x117, ///< Expected global address space + tinytc_status_ir_invalid_offset = 0x118, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x119, ///< Instruction does not support int type + tinytc_status_ir_i1_unsupported = 0x11a, ///< Instruction does not support i1 type + tinytc_status_ir_complex_unsupported = 0x11b, ///< Instruction does not support complex type + tinytc_status_ir_coopmatrix_unsupported = + 0x11c, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x11d, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x11e, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x11f, ///< Mismatch of init values and returned values // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST @@ -290,6 +295,13 @@ typedef enum { tinytc_store_flag_atomic_add = 2 ///< Atomic fetch add } tinytc_store_flag_t; +//! Matrix use +typedef enum { + tinytc_matrix_use_a, ///< matrix_a + tinytc_matrix_use_b, ///< matrix_b + tinytc_matrix_use_acc ///< matrix_acc +} tinytc_matrix_use_t; + //! Core features that may be optionally enabled typedef enum { /** diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 73b6d99f..e9bf42c8 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -59,6 +59,8 @@ enum class status { ir_scalar_mismatch = tinytc_status_ir_scalar_mismatch, ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, ir_expected_scalar = tinytc_status_ir_expected_scalar, + ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, + ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, ir_expected_memref = tinytc_status_ir_expected_memref, ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, @@ -78,6 +80,7 @@ enum class status { ir_int_unsupported = tinytc_status_ir_int_unsupported, ir_i1_unsupported = tinytc_status_ir_i1_unsupported, ir_complex_unsupported = tinytc_status_ir_complex_unsupported, + ir_coopmatrix_unsupported = tinytc_status_ir_coopmatrix_unsupported, ir_forbidden_cast = tinytc_status_ir_forbidden_cast, ir_invalid_beta = tinytc_status_ir_invalid_beta, ir_init_return_mismatch = tinytc_status_ir_init_return_mismatch, @@ -272,6 +275,13 @@ enum class store_flag { atomic_add = tinytc_store_flag_atomic_add ///< Atomic fetch add }; +//! Matrix use +enum class matrix_use { + a = tinytc_matrix_use_a, ///< matrix_a + b = tinytc_matrix_use_b, ///< matrix_b + acc = tinytc_matrix_use_acc ///< matrix_acc +}; + //! @brief Cf. @ref tinytc_core_feature_flag_t enum class core_feature_flag { large_register_file = tinytc_core_feature_flag_large_register_file }; diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp index 679bf182..a6855a0c 100644 --- a/src/compiler_context_cache.hpp +++ b/src/compiler_context_cache.hpp @@ -37,6 +37,7 @@ class compiler_context_cache { std::array, TINYTC_NUMBER_OF_SCALAR_TYPES> scalar_tys; std::unordered_multimap memref_tys; + std::unordered_multimap coopmatrix_tys; std::unordered_map, tinytc_data_type_t> group_tys; }; diff --git a/src/data_type.cpp b/src/data_type.cpp index 2af618c5..3b9058cd 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -16,6 +16,19 @@ using namespace tinytc; extern "C" { + +char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u) { + switch (u) { + case tinytc_matrix_use_a: + return "matrix_a"; + case tinytc_matrix_use_b: + return "matrix_b"; + case tinytc_matrix_use_acc: + return "matrix_acc"; + } + return "unknown"; +} + tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t type) { if (dt == nullptr || ctx == nullptr) { @@ -52,4 +65,17 @@ tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_data_type_t return exception_to_status_code( [&] { *dt = group_data_type::get(memref_ty, offset, get_optional(loc)); }); } + +tinytc_status_t tinytc_coopmatrix_type_get(tinytc_data_type_t *dt, tinytc_data_type_t scalar_ty, + int64_t rows, int64_t cols, tinytc_matrix_use_t use, + const tinytc_location_t *loc) { + if (dt == nullptr || scalar_ty == nullptr) { + return tinytc_status_invalid_arguments; + } + + return exception_to_status_code([&] { + *dt = coopmatrix_data_type::get(scalar_ty, rows, cols, enum_cast(use), + get_optional(loc)); + }); +} } diff --git a/src/error.cpp b/src/error.cpp index ba9e8df3..ea9bfa36 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -132,6 +132,10 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Number of indices must match memref order or must be 1 for group types"; case tinytc_status_ir_expected_scalar: return "Expected scalar type"; + case tinytc_status_ir_expected_coopmatrix: + return "Expected coopmatrix type"; + case tinytc_status_ir_expected_coopmatrix_or_scalar: + return "Expected coopmatrix type or scalar type"; case tinytc_status_ir_expected_memref: return "Expected memref type"; case tinytc_status_ir_expected_memref_or_scalar: @@ -171,6 +175,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "i1 type unsupported by instruction"; case tinytc_status_ir_complex_unsupported: return "complex type unsupported by instruction"; + case tinytc_status_ir_coopmatrix_unsupported: + return "coopmatrix type unsupported by instruction"; case tinytc_status_ir_forbidden_cast: return "Forbidden cast"; case tinytc_status_ir_invalid_beta: diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index 5a196b30..07bd33c0 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -18,6 +18,53 @@ namespace tinytc { +auto coopmatrix_data_type::get(tinytc_data_type_t ty, std::int64_t rows, std::int64_t cols, + matrix_use use, location const &lc) -> tinytc_data_type_t { + auto ctx = ty->context(); + + auto key = coopmatrix_data_type_key(ty, rows, cols, use); + std::uint64_t map_key = key.hash(); + + auto &tys = ctx->cache()->coopmatrix_tys; + auto range = tys.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (key == *dyn_cast(it->second)) { + return it->second; + } + } + auto new_ct = std::unique_ptr( + new coopmatrix_data_type(ctx, key.ty, key.rows, key.cols, key.use, lc)); + return tys.emplace(map_key, new_ct.release())->second; +} + +coopmatrix_data_type::coopmatrix_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, + std::int64_t rows, std::int64_t cols, matrix_use use, + location const &lc) + : data_type_node(DTK::coopmatrix, ctx), ty_(std::move(ty)), rows_(rows), cols_(cols), + use_(use) { + if (!isa(*ty_)) { + throw compilation_error(lc, status::ir_expected_scalar); + } + if (rows_ < 0 || is_dynamic_value(rows_)) { + throw compilation_error(lc, status::ir_invalid_shape); + } + if (cols_ < 0 || is_dynamic_value(cols_)) { + throw compilation_error(lc, status::ir_invalid_shape); + } +} + +auto coopmatrix_data_type::component_ty() const -> scalar_type { + return dyn_cast(ty_)->ty(); +} + +auto coopmatrix_data_type_key::hash() -> std::uint64_t { + return fnv1a_combine(ty, rows, cols, use); +} + +auto coopmatrix_data_type_key::operator==(coopmatrix_data_type const &ct) -> bool { + return ty == ct.ty() && rows == ct.rows() && cols == ct.cols() && use == ct.use(); +} + auto group_data_type::get(tinytc_data_type_t ty, std::int64_t offset, location const &lc) -> tinytc_data_type_t { auto ctx = ty->context(); diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 0e135c5a..617d3f74 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -15,9 +15,10 @@ #include namespace tinytc { -enum class DTK { group, memref, scalar, void_ }; -using data_type_nodes = type_list; +enum class DTK { coopmatrix, group, memref, scalar, void_ }; +using data_type_nodes = + type_list; } // namespace tinytc struct tinytc_data_type { @@ -39,6 +40,42 @@ namespace tinytc { using data_type_node = ::tinytc_data_type; +class coopmatrix_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::coopmatrix; } + static auto get(tinytc_data_type_t ty, std::int64_t rows, std::int64_t cols, matrix_use use, + location const &lc = {}) -> tinytc_data_type_t; + + inline auto ty() const -> tinytc_data_type_t { return ty_; } + auto component_ty() const -> scalar_type; + inline auto rows() const -> std::int64_t { return rows_; } + inline auto cols() const -> std::int64_t { return cols_; } + inline auto use() const -> matrix_use { return use_; } + // Number of components per work-item + inline auto length(std::int32_t subgroup_size) const -> std::int64_t { + const std::int64_t blocks = 1 + (rows_ - 1) / subgroup_size; + return blocks * cols_; + } + + protected: + coopmatrix_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t ty, std::int64_t rows, + std::int64_t cols, matrix_use use, location const &lc = {}); + + private: + tinytc_data_type_t ty_; + std::int64_t rows_, cols_; + matrix_use use_; +}; + +struct coopmatrix_data_type_key { + tinytc_data_type_t ty; + std::int64_t rows, cols; + matrix_use use; + + auto hash() -> std::uint64_t; + auto operator==(coopmatrix_data_type const &ct) -> bool; +}; + class group_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 7fee5a0a..d5174430 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -144,47 +144,68 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b op(op_b, b0); loc(lc); - auto at = get_scalar_type(loc(), a()); - auto bt = get_scalar_type(loc(), b()); + if (isa(*a().ty())) { + if (!isa(*b().ty())) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); + } + bool inst_supports_coopmatrix = false; + switch (operation) { + case arithmetic::add: + case arithmetic::sub: + case arithmetic::mul: + case arithmetic::div: + inst_supports_coopmatrix = true; + break; + default: + break; + } + if (!inst_supports_coopmatrix) { + throw compilation_error(loc(), status::ir_coopmatrix_unsupported); + } + } else { + auto a_ty = get_scalar_type(loc(), a())->ty(); + auto b_ty = get_scalar_type(loc(), b())->ty(); - if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } - bool inst_supports_fp = true; - bool inst_supports_complex = true; - bool inst_supports_i1 = true; - switch (operation) { - case arithmetic::add: - case arithmetic::sub: - case arithmetic::mul: - case arithmetic::div: - break; - case arithmetic::rem: - inst_supports_complex = false; - break; - case arithmetic::and_: - case arithmetic::or_: - case arithmetic::xor_: - inst_supports_fp = false; - inst_supports_complex = false; - break; - case arithmetic::shl: - case arithmetic::shr: - inst_supports_i1 = false; - inst_supports_fp = false; - inst_supports_complex = false; - break; - } - if (!inst_supports_i1 && at->ty() == scalar_type::i1) { - throw compilation_error(loc(), status::ir_i1_unsupported); - } - if (!inst_supports_fp && is_floating_type(at->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - if (!inst_supports_complex && is_complex_type(at->ty())) { - throw compilation_error(loc(), status::ir_complex_unsupported); + if (a_ty != b_ty) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + bool inst_supports_fp = true; + bool inst_supports_complex = true; + bool inst_supports_i1 = true; + switch (operation) { + case arithmetic::add: + case arithmetic::sub: + case arithmetic::mul: + case arithmetic::div: + break; + case arithmetic::rem: + inst_supports_complex = false; + break; + case arithmetic::and_: + case arithmetic::or_: + case arithmetic::xor_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic::shl: + case arithmetic::shr: + inst_supports_i1 = false; + inst_supports_fp = false; + inst_supports_complex = false; + break; + } + if (!inst_supports_i1 && a_ty == scalar_type::i1) { + throw compilation_error(loc(), status::ir_i1_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty)) { + throw compilation_error(loc(), status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(a_ty)) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } } - result(0) = value_node{at, this, lc}; + + result(0) = value_node{a().ty(), this, lc}; } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, @@ -193,45 +214,55 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 op(op_a, a0); loc(lc); - auto a_ty = get_scalar_type(loc(), a()); - tinytc_data_type_t to_ty = a_ty; + tinytc_data_type_t to_ty = nullptr; - bool inst_supports_int = true; - bool inst_supports_fp = true; - bool inst_supports_complex = true; - switch (operation) { - case arithmetic_unary::abs: - case arithmetic_unary::neg: - break; - case arithmetic_unary::not_: - inst_supports_fp = false; - inst_supports_complex = false; - break; - case arithmetic_unary::conj: - case arithmetic_unary::im: - case arithmetic_unary::re: - inst_supports_int = false; - inst_supports_fp = false; - break; - } - if (!inst_supports_int && is_integer_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_int_unsupported); - } - if (!inst_supports_fp && is_floating_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - if (!inst_supports_complex && is_complex_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_complex_unsupported); - } - switch (operation) { - case arithmetic_unary::abs: - case arithmetic_unary::im: - case arithmetic_unary::re: - to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); - break; - default: - break; + if (isa(*a().ty())) { + if (operation_ != arithmetic_unary::neg) { + throw compilation_error(loc(), status::ir_coopmatrix_unsupported); + } + to_ty = a().ty(); + } else { + auto a_ty = get_scalar_type(loc(), a()); + to_ty = a_ty; + + bool inst_supports_int = true; + bool inst_supports_fp = true; + bool inst_supports_complex = true; + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::neg: + break; + case arithmetic_unary::not_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic_unary::conj: + case arithmetic_unary::im: + case arithmetic_unary::re: + inst_supports_int = false; + inst_supports_fp = false; + break; + } + if (!inst_supports_int && is_integer_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_int_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::im: + case arithmetic_unary::re: + to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); + break; + default: + break; + } } + result(0) = value_node{to_ty, this, lc}; } @@ -240,17 +271,33 @@ cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const op(op_a, a0); loc(lc); - auto rt = dyn_cast(to_ty); - if (rt == nullptr) { - throw compilation_error(lc, status::ir_expected_scalar); - } + auto const check_scalar_casting_rules = [](scalar_type a_ty, scalar_type r_ty, + location const &lc) { + if (is_complex_type(a_ty) && !is_complex_type(r_ty)) { + throw compilation_error(lc, status::ir_forbidden_cast); + } + }; - auto at = get_scalar_type(loc(), a()); - if (is_complex_type(at->ty()) && !is_complex_type(rt->ty())) { - throw compilation_error(lc, status::ir_forbidden_cast); + if (auto ct = dyn_cast(a().ty()); ct) { + auto rt = dyn_cast(to_ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); + } + if (ct->rows() != rt->rows() || ct->cols() != rt->cols() || ct->use() != rt->use()) { + throw compilation_error(lc, status::ir_forbidden_cast); + } + check_scalar_casting_rules(ct->component_ty(), rt->component_ty(), loc()); + } else { + auto rt = dyn_cast(to_ty); + if (rt == nullptr) { + throw compilation_error(lc, status::ir_expected_scalar); + } + + auto at = get_scalar_type(loc(), a()); + check_scalar_casting_rules(at->ty(), rt->ty(), loc()); } - result(0) = value_node{to_ty, this, lc}; + result(0) = value_node{to_ty, this, loc()}; } compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, @@ -291,17 +338,22 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc : standard_inst{IK::constant}, value_(value) { loc(lc); + const auto type_ok = [](value_type const &val, scalar_type ty) { + return (is_integer_type(ty) && std::holds_alternative(val)) || + (is_floating_type(ty) && std::holds_alternative(val)) || + (is_complex_type(ty) && std::holds_alternative>(val)); + }; + if (auto st = dyn_cast(ty); st) { - const auto type_ok = [](value_type const &val, scalar_type ty) { - return (is_integer_type(ty) && std::holds_alternative(val)) || - (is_floating_type(ty) && std::holds_alternative(val)) || - (is_complex_type(ty) && std::holds_alternative>(val)); - }; if (!type_ok(value_, st->ty())) { throw compilation_error(loc(), status::ir_scalar_mismatch); } + } else if (auto ct = dyn_cast(ty); ct) { + if (!type_ok(value_, ct->component_ty())) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } } else { - throw compilation_error(loc(), status::ir_expected_scalar); + throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); } result(0) = value_node{ty, this, lc}; diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 3832c4ce..58b57fe8 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -127,6 +127,7 @@ lex: auto t = lex_floating_type(b, YYCURSOR); return parser::make_FLOATING_TYPE(t, loc_); } + "coopmatrix" { adv_loc(); return parser::make_COOPMATRIX(loc_); } "memref" { adv_loc(); return parser::make_MEMREF(loc_); } "group" { adv_loc(); return parser::make_GROUP(loc_); } @@ -134,6 +135,11 @@ lex: "offset" { adv_loc(); return parser::make_OFFSET(loc_); } "strided" { adv_loc(); return parser::make_STRIDED(loc_); } + // matrix use + "matrix_a" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::a, loc_); } + "matrix_b" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::b, loc_); } + "matrix_acc" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::acc, loc_); } + // instructions "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } "arith" { adv_loc(); return parser::make_ARITH(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 13d7c99d..0ada8d08 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -101,6 +101,7 @@ GLOBAL "global" LOCAL_ATTR ".local" GLOBAL_ATTR ".global" + COOPMATRIX "coopmatrix" MEMREF "memref" GROUP "group" OFFSET "offset" @@ -144,6 +145,7 @@ %token ARITHMETIC %token ARITHMETIC_UNARY %token CMP_CONDITION +%token MATRIX_USE %nterm prog %nterm > func_list @@ -154,6 +156,7 @@ %nterm > attribute %nterm data_type %nterm scalar_type +%nterm coopmatrix_type %nterm memref_type %nterm optional_address_space %nterm > mode_list @@ -314,6 +317,7 @@ attribute: data_type: scalar_type + | coopmatrix_type | memref_type | group_type ; @@ -323,6 +327,17 @@ scalar_type: | FLOATING_TYPE { $$ = get_scalar(ctx.cctx(), $FLOATING_TYPE); } ; +coopmatrix_type: + COOPMATRIX LCHEV scalar_type TIMES INTEGER_CONSTANT[rows] TIMES INTEGER_CONSTANT[cols] COMMA MATRIX_USE RCHEV { + try { + $$ = get_coopmatrix($scalar_type, $rows, $cols, $MATRIX_USE, @coopmatrix_type); + } catch (compilation_error const& e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + memref_type: MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { try { @@ -775,7 +790,7 @@ alloca_inst: ; arith_inst: - ARITH ARITHMETIC var[a] COMMA var[b] COLON scalar_type[ty] { + ARITH ARITHMETIC var[a] COMMA var[b] COLON data_type[ty] { check_type($a, $ty, @a, @ty); check_type($b, $ty, @b, @ty); try { @@ -791,7 +806,7 @@ arith_inst: ; arith_unary_inst: - ARITH ARITHMETIC_UNARY var[a] COLON scalar_type[ty] { + ARITH ARITHMETIC_UNARY var[a] COLON data_type[ty] { check_type($a, $ty, @a, @ty); try { $$ = inst { diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index e14df6b6..bb495bcf 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -169,7 +169,14 @@ auto constant_folding::operator()(arith_inst &in) -> fold_result { auto at = dyn_cast(op_a.ty()); if (at == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_scalar); + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + } + at = dyn_cast(ct->ty()); } if (a_const != nullptr && b_const != nullptr) { @@ -200,7 +207,14 @@ auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { auto at = dyn_cast(op_a.ty()); if (at == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_scalar); + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + } + at = dyn_cast(ct->ty()); } auto computer = compute_unary_op{in.operation(), op_a.ty(), in.loc()}; @@ -218,7 +232,14 @@ auto constant_folding::operator()(cast_inst &in) -> fold_result { auto rt = dyn_cast(in.result(0).ty()); if (rt == nullptr) { - throw compilation_error(in.result(0).loc(), status::ir_expected_scalar); + // Cast on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(in.result(0).ty()); + if (ct == nullptr) { + throw compilation_error(in.result(0).loc(), status::ir_expected_coopmatrix_or_scalar); + } + rt = dyn_cast(ct->ty()); } return std::visit( diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 63d43bee..5bca1f2b 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -156,20 +156,20 @@ auto convert_to_opencl_pass::get_memref_type(value_node const &v) const } auto convert_to_opencl_pass::get_scalar_type(value_node const &v) -> scalar_type { - return visit(overloaded{[](scalar_data_type const &i) -> scalar_type { return i.ty(); }, - [](memref_data_type const &i) -> scalar_type { return i.element_ty(); }, - [&](auto const &) -> scalar_type { - throw compilation_error(v.loc(), - status::ir_expected_memref_or_scalar); - return scalar_type{}; - }}, - *v.ty()); + auto st = dyn_cast(v.ty()); + if (!st) { + throw compilation_error(v.loc(), status::ir_expected_scalar); + } + return st->ty(); }; /* Data type nodes */ clir::data_type convert_to_opencl_pass::operator()(void_data_type const &) { return clir::builtin_type::void_t; } +clir::data_type convert_to_opencl_pass::operator()(coopmatrix_data_type const &ct) { + return array_of(to_clir_ty(ct.component_ty()), ct.length(core_cfg_.subgroup_size)); +} clir::data_type convert_to_opencl_pass::operator()(group_data_type const &g) { auto ptr_ty = visit(*this, *g.ty()); ptr_ty = clir::visit(overloaded{[](clir::internal::pointer &t) { @@ -355,10 +355,27 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) } return {}; }; - auto sty = get_scalar_type(a.a()); - auto v = declare(*a.result()); - return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), val(a.a()), val(a.b()), sty))}; + + auto lhs = declare(a.result(0)); + auto lhs_ty = visit(*this, *a.result()->ty()); + auto av = val(a.a()); + auto bv = val(a.b()); + if (auto st = dyn_cast(a.result(0).ty()); st) { + auto op = make(a.operation(), av, bv, st->ty()); + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; + } else if (auto ct = dyn_cast(a.result(0).ty()); ct) { + auto clinst = std::vector{}; + auto const len = ct->length(core_cfg_.subgroup_size); + clinst.reserve(len + 1); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + const auto sty = ct->component_ty(); + for (std::int64_t i = 0; i < len; ++i) { + auto op = make(a.operation(), av[i], bv[i], sty); + clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); + } + return clinst; + } + throw compilation_error(a.loc(), status::ir_expected_coopmatrix_or_scalar); } std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { @@ -389,36 +406,72 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons } return {}; }; - auto sty = get_scalar_type(a.a()); - auto v = declare(*a.result()); - return {declaration_assignment(visit(*this, *a.result()->ty()), std::move(v), - make(a.operation(), val(a.a()), sty))}; + + auto lhs = declare(a.result(0)); + auto lhs_ty = visit(*this, *a.result()->ty()); + auto av = val(a.a()); + if (auto st = dyn_cast(a.a().ty()); st) { + auto op = make(a.operation(), av, st->ty()); + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; + } else if (auto ct = dyn_cast(a.a().ty()); ct) { + auto clinst = std::vector{}; + auto const len = ct->length(core_cfg_.subgroup_size); + clinst.reserve(len + 1); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + const auto sty = ct->component_ty(); + for (std::int64_t i = 0; i < len; ++i) { + auto op = make(a.operation(), av[i], sty); + clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); + } + return clinst; + } + throw compilation_error(a.loc(), status::ir_expected_coopmatrix_or_scalar); } std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { - auto v = declare(*c.result()); - auto aty = get_scalar_type(c.a().ty()); - auto rty = get_scalar_type(c.result(0).ty()); + auto const make = [](clir::expr a, scalar_type aty, scalar_type rty) -> clir::expr { + if (is_complex_type(aty) && is_complex_type(rty)) { + switch (rty) { + case scalar_type::c32: + return clir::call("convert_float2", {std::move(a)}); + case scalar_type::c64: + return clir::call("convert_double2", {std::move(a)}); + default: + throw status::internal_compiler_error; + } + } else if (is_complex_type(rty)) { + return clir::init_vector(to_clir_ty(rty), {std::move(a), 0}); + } + return cast(to_clir_ty(rty), std::move(a)); + }; + + auto lhs = declare(c.result(0)); + auto lhs_ty = visit(*this, *c.result(0).ty()); auto av = val(c.a()); - auto cst = clir::expr{}; - auto result_ty = visit(*this, *c.result()->ty()); - if (is_complex_type(aty) && is_complex_type(rty)) { - switch (rty) { - case scalar_type::c32: - cst = clir::call("convert_float2", {std::move(av)}); - break; - case scalar_type::c64: - cst = clir::call("convert_double2", {std::move(av)}); - break; - default: - throw status::internal_compiler_error; + + if (auto rt = dyn_cast(c.result(0).ty()); rt) { + auto aty = get_scalar_type(c.a()); + auto op = make(av, aty, rt->ty()); + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; + } else if (auto ct = dyn_cast(c.result(0).ty()); ct) { + const auto rty = ct->component_ty(); + auto at = dyn_cast(c.a().ty()); + if (!at) { + throw compilation_error(c.loc(), status::ir_expected_coopmatrix); } - } else if (is_complex_type(rty)) { - cst = clir::init_vector(result_ty, {std::move(av), 0}); - } else { - cst = cast(result_ty, std::move(av)); + const auto aty = at->component_ty(); + + auto clinst = std::vector{}; + auto const len = ct->length(core_cfg_.subgroup_size); + clinst.reserve(len + 1); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + for (std::int64_t i = 0; i < len; ++i) { + auto op = make(av[i], aty, rty); + clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); + } + return clinst; } - return {declaration_assignment(std::move(result_ty), std::move(v), std::move(cst))}; + throw compilation_error(c.loc(), status::ir_expected_coopmatrix_or_scalar); } std::vector convert_to_opencl_pass::operator()(compare_inst const &c) { @@ -445,20 +498,37 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c } std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { - auto v = declare(c.result(0)); - auto ty = get_scalar_type(c.result(0)); - auto ty_bits = static_cast(size(ty) * 8); - auto rhs = - std::visit(overloaded{ - [&](std::int64_t i) { return clir::expr(i, ty_bits); }, - [&](double d) { return clir::expr(d, ty_bits); }, - [&](std::complex d) { - return init_vector(to_clir_ty(ty), {clir::expr(d.real(), ty_bits), - clir::expr(d.imag(), ty_bits)}); - }, - }, - c.value()); - return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), std::move(rhs))}; + auto const get_rhs = [&c](scalar_type ty, short ty_bits) { + return std::visit(overloaded{ + [&](std::int64_t i) { return clir::expr(i, ty_bits); }, + [&](double d) { return clir::expr(d, ty_bits); }, + [&](std::complex d) { + return init_vector(to_clir_ty(ty), + {clir::expr(d.real(), ty_bits), + clir::expr(d.imag(), ty_bits)}); + }, + }, + c.value()); + }; + auto lhs = declare(c.result(0)); + auto lhs_ty = visit(*this, *c.result()->ty()); + if (auto st = dyn_cast(c.result(0).ty()); st) { + auto ty_bits = static_cast(size(st->ty()) * 8); + return { + declaration_assignment(std::move(lhs_ty), std::move(lhs), get_rhs(st->ty(), ty_bits))}; + } else if (auto ct = dyn_cast(c.result(0).ty()); ct) { + auto ty_bits = static_cast(size(ct->component_ty()) * 8); + auto rhs = get_rhs(ct->component_ty(), ty_bits); + auto clinst = std::vector{}; + auto const len = ct->length(core_cfg_.subgroup_size); + clinst.reserve(len + 1); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + for (std::int64_t i = 0; i < len; ++i) { + clinst.emplace_back(expression_statement(assignment(lhs[i], rhs))); + } + return clinst; + } + throw compilation_error(c.loc(), status::ir_expected_coopmatrix_or_scalar); } std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 6203fdf8..539cb4d7 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -61,6 +61,7 @@ class convert_to_opencl_pass { /* Data type nodes */ clir::data_type operator()(void_data_type const &); + clir::data_type operator()(coopmatrix_data_type const &ct); clir::data_type operator()(group_data_type const &g); clir::data_type operator()(memref_data_type const &m); clir::data_type operator()(scalar_data_type const &s); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 4c1bd425..e7f973ae 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -23,6 +23,11 @@ dump_ir_pass::dump_ir_pass(std::ostream &os, int level_limit) : os_(&os), lvl_li /* Data type nodes */ void dump_ir_pass::operator()(void_data_type const &) { *os_ << "void"; } +void dump_ir_pass::operator()(coopmatrix_data_type const &ct) { + *os_ << "coopmatrix<"; + visit(*this, *ct.ty()); + *os_ << "x" << ct.rows() << "x" << ct.cols() << "," << to_string(ct.use()) << ">"; +} void dump_ir_pass::operator()(group_data_type const &g) { *os_ << "group<"; visit(*this, *g.ty()); diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index 9d080acc..be273053 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -23,6 +23,7 @@ class dump_ir_pass { /* Data type nodes */ void operator()(void_data_type const &); + void operator()(coopmatrix_data_type const &ct); void operator()(group_data_type const &g); void operator()(memref_data_type const &m); void operator()(scalar_data_type const &s); From 17ab25a903381821c1bdf24eab802988d0ec8a7d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 21 Oct 2024 21:27:00 +0200 Subject: [PATCH 063/297] Add coopmatrix instructions Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 49 ++++ docs/api/builder_capi.yaml | 4 + docs/api/builder_cxxapi.rst | 57 ++++- docs/api/builder_cxxapi.yaml | 4 + docs/manual/tensor-ir.rst | 18 +- include/tinytc/tinytc.h | 76 ++++++ include/tinytc/tinytc.hpp | 80 +++++++ include/tinytc/types.h | 54 +++-- include/tinytc/types.hpp | 4 + src/error.cpp | 9 + src/inst.cpp | 52 +++++ src/node/data_type_node.hpp | 8 +- src/node/inst_node.cpp | 151 +++++++++++- src/node/inst_node.hpp | 83 ++++++- src/parser/lexer.re | 5 + src/parser/parser_impl.yy | 114 ++++++++- src/pass/convert_to_opencl.cpp | 357 ++++++++++++++++++++++++++++- src/pass/convert_to_opencl.hpp | 5 + src/pass/dump_ir.cpp | 71 ++++++ src/pass/dump_ir.hpp | 4 + test/codegen/coopmatrix_basic.ir | 56 +++++ test/codegen/coopmatrix_load.ir | 153 +++++++++++++ test/codegen/coopmatrix_mul_add.ir | 100 ++++++++ test/codegen/coopmatrix_store.ir | 55 +++++ test/codegen/for.ir | 10 +- test/opt/insert-barrier.ir | 4 +- 26 files changed, 1506 insertions(+), 77 deletions(-) create mode 100644 test/codegen/coopmatrix_basic.ir create mode 100644 test/codegen/coopmatrix_load.ir create mode 100644 test/codegen/coopmatrix_mul_add.ir create mode 100644 test/codegen/coopmatrix_store.ir diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index ee78cef0..5a50545c 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -18,6 +18,8 @@ Common * :ref:`tinytc_cmp_condition_t` + * :ref:`tinytc_matrix_use_t` + * :ref:`tinytc_scalar_type_t` * :ref:`tinytc_store_flag_t` @@ -38,6 +40,8 @@ Common * :ref:`tinytc_cmp_condition_to_string` + * :ref:`tinytc_matrix_use_to_string` + * :ref:`tinytc_scalar_type_size` * :ref:`tinytc_scalar_type_to_string` @@ -105,6 +109,11 @@ tinytc_cmp_condition_t .. doxygenenum:: tinytc_cmp_condition_t +tinytc_matrix_use_t +................... + +.. doxygenenum:: tinytc_matrix_use_t + tinytc_scalar_type_t .................... @@ -151,6 +160,11 @@ tinytc_cmp_condition_to_string .. doxygenfunction:: tinytc_cmp_condition_to_string +tinytc_matrix_use_to_string +........................... + +.. doxygenfunction:: tinytc_matrix_use_to_string + tinytc_scalar_type_size ....................... @@ -262,6 +276,8 @@ Data Type * Functions + * :ref:`tinytc_coopmatrix_type_get` + * :ref:`tinytc_group_type_get` * :ref:`tinytc_memref_type_get` @@ -271,6 +287,11 @@ Data Type Data Type Functions ------------------- +tinytc_coopmatrix_type_get +.......................... + +.. doxygenfunction:: tinytc_coopmatrix_type_get + tinytc_group_type_get ..................... @@ -356,6 +377,14 @@ Instruction * :ref:`tinytc_constant_inst_create_zero` + * :ref:`tinytc_cooperative_matrix_load_inst_create` + + * :ref:`tinytc_cooperative_matrix_mul_add_inst_create` + + * :ref:`tinytc_cooperative_matrix_scale_inst_create` + + * :ref:`tinytc_cooperative_matrix_store_inst_create` + * :ref:`tinytc_expand_inst_create` * :ref:`tinytc_for_inst_create` @@ -464,6 +493,26 @@ tinytc_constant_inst_create_zero .. doxygenfunction:: tinytc_constant_inst_create_zero +tinytc_cooperative_matrix_load_inst_create +.......................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_load_inst_create + +tinytc_cooperative_matrix_mul_add_inst_create +............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_mul_add_inst_create + +tinytc_cooperative_matrix_scale_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_scale_inst_create + +tinytc_cooperative_matrix_store_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_store_inst_create + tinytc_expand_inst_create ......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index c313a952..6d99df31 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -67,6 +67,10 @@ Builder C-API: - tinytc_constant_inst_create_int - tinytc_constant_inst_create_one - tinytc_constant_inst_create_zero + - tinytc_cooperative_matrix_load_inst_create + - tinytc_cooperative_matrix_mul_add_inst_create + - tinytc_cooperative_matrix_scale_inst_create + - tinytc_cooperative_matrix_store_inst_create - tinytc_expand_inst_create - tinytc_for_inst_create - tinytc_foreach_inst_create diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index ba128a65..9c10938f 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -18,6 +18,8 @@ Common * :ref:`cmp_condition` + * :ref:`matrix_use` + * :ref:`scalar_type` * :ref:`store_flag` @@ -36,6 +38,8 @@ Common * :ref:`to_string(cmp_condition)` + * :ref:`to_string(matrix_use)` + * :ref:`to_string(scalar_type)` * :ref:`to_string(store_flag)` @@ -81,6 +85,11 @@ cmp_condition .. doxygenenum:: tinytc::cmp_condition +matrix_use +.......... + +.. doxygenenum:: tinytc::matrix_use + scalar_type ........... @@ -124,6 +133,11 @@ to_string(cmp_condition) .. doxygenfunction:: tinytc::to_string(cmp_condition) +to_string(matrix_use) +..................... + +.. doxygenfunction:: tinytc::to_string(matrix_use) + to_string(scalar_type) ...................... @@ -178,10 +192,12 @@ Data Type * Functions - * :ref:`get_memref` + * :ref:`get_coopmatrix` * :ref:`get_group` + * :ref:`get_memref` + * :ref:`get_scalar` * Structures @@ -199,16 +215,21 @@ Data Type Data Type Functions ------------------- -get_memref -.......... +get_coopmatrix +.............. -.. doxygenfunction:: tinytc::get_memref +.. doxygenfunction:: tinytc::get_coopmatrix get_group ......... .. doxygenfunction:: tinytc::get_group +get_memref +.......... + +.. doxygenfunction:: tinytc::get_memref + get_scalar .......... @@ -294,6 +315,14 @@ Instruction * :ref:`make_constant_zero` + * :ref:`make_cooperative_matrix_load` + + * :ref:`make_cooperative_matrix_mul_add` + + * :ref:`make_cooperative_matrix_scale` + + * :ref:`make_cooperative_matrix_store` + * :ref:`make_expand` * :ref:`make_for` @@ -405,6 +434,26 @@ make_constant_zero .. doxygenfunction:: tinytc::make_constant_zero +make_cooperative_matrix_load +............................ + +.. doxygenfunction:: tinytc::make_cooperative_matrix_load + +make_cooperative_matrix_mul_add +............................... + +.. doxygenfunction:: tinytc::make_cooperative_matrix_mul_add + +make_cooperative_matrix_scale +............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_scale + +make_cooperative_matrix_store +............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_store + make_expand ........... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 5e488d39..f2c07869 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -60,6 +60,10 @@ Builder C++-API: - tinytc::make_constant(std::int64_t,data_type,location const&) - tinytc::make_constant_one - tinytc::make_constant_zero + - tinytc::make_cooperative_matrix_load + - tinytc::make_cooperative_matrix_mul_add + - tinytc::make_cooperative_matrix_scale + - tinytc::make_cooperative_matrix_store - tinytc::make_expand - tinytc::make_for - tinytc::make_foreach diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 75b221eb..8a1a89dd 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -268,6 +268,9 @@ and the second integer-constant the number of columns. The matrix-use may affect the distribution of the matrix in the subgroup, and the name refers to the position of the matrix in a matrix multiplication. +Not all matrix shapes need to be supported in the implementation. +The supported matrix shapes may depend on data type, matrix use, and target hardware. + Instructions ============ @@ -771,7 +774,7 @@ Cooperative matrix load .. code:: abnf - value-instruction =/ "cooperative_matrix_load" [".checked"] transpose + value-instruction =/ "cooperative_matrix_load" transpose [".checked"] local-identifier "[" local-identifier "," local-identifier "]" ":" memref-type "->" coopmatrix-type @@ -951,11 +954,12 @@ For multi-value-instruction = "for" local-identifier "=" local-identifier "," local-identifier ["," local-identifier] - ["init" "(" init-value-list ")" "->" "(" scalar-type-list ")" ] + ["init" "(" init-value-list ")" "->" "(" return-type-list ")" ] [":" integer-type] region init-value-list = init-value *("," init-value) init-value = local-identifier "=" local-identifier - scalar-type-list = scalar-type *("," scalar-type) + return-type-list = return-type *("," return-type) + return-type = scalar-type / coopmatrix-type Overview ~~~~~~~~ @@ -1090,7 +1094,7 @@ If .. code:: abnf - multi-value-instruction =/ "if" local-identifier ["->" "(" scalar-type-list ")"] + multi-value-instruction =/ "if" local-identifier ["->" "(" return-type-list ")"] region ["else" region] Overview @@ -1105,7 +1109,7 @@ Returns ~~~~~~~ The if instruction may return multiple values, where the number of values and the value types -are given by the scalar-type-list. +are given by the return-type-list. If values are returned, the last instruction in both the "then"-region and the "else"-region must be a yield instruction (the "else"-region cannot be omitted). @@ -1316,7 +1320,7 @@ Yield .. code:: abnf - instruction =/ "yield" [local-identifier-list] ":" [scalar-type-list] + instruction =/ "yield" [local-identifier-list] ":" [return-type-list] Overview ~~~~~~~~ @@ -1326,7 +1330,7 @@ Yield returns values from an if or for instruction. Arguments ~~~~~~~~~ -The length of the local identifier list must equal the length of the scalar type list. +The length of the local identifier list must equal the length of the return type list. Additional instructions ....................... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 97b09b31..139b5a8e 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -111,6 +111,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, * @param scalar_ty [in] component type * @param rows [in] number of rows * @param cols [in] number of cols + * @param use [in] matrix use * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -316,6 +317,81 @@ TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *in tinytc_data_type_t ty, const tinytc_location_t *loc); +/** + * @brief Create cooperative matrix load instruction + * + * @code %value = cooperative_matrix_load.transpose.checked %op[%p0, %p1] : type(%op) -> to_ty + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param transpose [in] transpose operation applied on load + * @param checked [in] true for out-of-bounds checks + * @param op [in] %op + * @param p0 [in] %p0 + * @param p1 [in] %p1 + * @param to_ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_load_inst_create( + tinytc_inst_t *instr, tinytc_transpose_t transpose, tinytc_bool_t checked, tinytc_value_t op, + tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix mul add instruction + * + * @code cooperative_matrix_mul_add %a, %b, %c : type(%a), type(%b), type(%c) -> to_ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param a [in] %a + * @param b [in] %b + * @param c [in] %c + * @param to_ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, tinytc_value_t c, + tinytc_data_type_t to_ty, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix scale instruction + * + * @code cooperative_matrix_scale %a, %b : type(%a), type(%b) @endcode + * + * @param instr [out] pointer to the inst object created + * @param a [in] %a + * @param b [in] %b + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix store instruction + * + * @code cooperative_matrix_store.checked.store_flag %val, %op[%p0, %p1] : type(%val), type(%op) + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param checked [in] true for out-of-bounds checks + * @param flag [in] store flag + * @param val [in] %val + * @param op [in] %op + * @param p0 [in] %p0 + * @param p1 [in] %p1 + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_store_inst_create( + tinytc_inst_t *instr, tinytc_bool_t checked, tinytc_store_flag_t flag, tinytc_value_t val, + tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, const tinytc_location_t *loc); + /** * @brief Create alloca instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 24eded3e..959527a4 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -979,6 +979,86 @@ inline inst make_constant_zero(data_type ty, location const &loc = {}) { return inst(instr); } +/** + * @brief Create cooperative matrix load instruction + * + * @param trans transpose operation applied on load + * @param checked true for out-of-bounds checks + * @param op %op + * @param p0 %p0 + * @param p1 %p1 + * @param to_ty result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cooperative_matrix_load(transpose trans, bool checked, value op, value p0, + value p1, data_type to_ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_cooperative_matrix_load_inst_create(&instr, static_cast(trans), + checked, op, p0, p1, to_ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Create cooperative matrix mul add instruction + * + * @param a %a + * @param b %b + * @param c %c + * @param to_ty result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cooperative_matrix_mul_add(value a, value b, value c, data_type to_ty, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cooperative_matrix_mul_add_inst_create(&instr, a, b, c, to_ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Create cooperative matrix scale instruction + * + * @param a %a + * @param b %b + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cooperative_matrix_scale(value a, value b, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cooperative_matrix_scale_inst_create(&instr, a, b, &loc), loc); + return inst(instr); +} + +/** + * @brief Create cooperative matrix store instruction + * + * @param checked true for out-of-bounds checks + * @param flag store flag + * @param val %val + * @param op %op + * @param p0 %p0 + * @param p1 %p1 + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cooperative_matrix_store(bool checked, store_flag flag, value val, value op, + value p0, value p1, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_cooperative_matrix_store_inst_create( + &instr, checked, static_cast(flag), val, op, p0, p1, &loc), + loc); + return inst(instr); +} + /** * @brief Make alloca instruction * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index dc430f25..68070f69 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -49,33 +49,37 @@ typedef enum { tinytc_status_ir_scalar_mismatch = 0x104, ///< Mismatch of scalar types tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices tinytc_status_ir_expected_scalar = 0x106, ///< Expected a value of scalar type - tinytc_status_ir_expected_coopmatrix = 0x107, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_index = 0x107, ///< Expected a value of index type + tinytc_status_ir_expected_coopmatrix = 0x108, ///< Expected a value of coopmatrix type tinytc_status_ir_expected_coopmatrix_or_scalar = - 0x108, ///< Expected a value of coopmatrix or scalar type - tinytc_status_ir_expected_memref = 0x109, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x10a, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x10b, ///< Expected a value of memref or group type - tinytc_status_ir_expected_vector_or_matrix = 0x10c, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x10d, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x10e, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x10f, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x110, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x111, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x112, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x113, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x114, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x115, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x116, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x117, ///< Expected global address space - tinytc_status_ir_invalid_offset = 0x118, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x119, ///< Instruction does not support int type - tinytc_status_ir_i1_unsupported = 0x11a, ///< Instruction does not support i1 type - tinytc_status_ir_complex_unsupported = 0x11b, ///< Instruction does not support complex type + 0x109, ///< Expected a value of coopmatrix or scalar type + tinytc_status_ir_expected_memref = 0x10a, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x10b, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x10c, ///< Expected a value of memref or group type + tinytc_status_ir_expected_matrix = 0x10d, ///< Expected a marix + tinytc_status_ir_expected_vector_or_matrix = 0x10e, ///< Expected a vector or marix + tinytc_status_ir_unexpected_yield = 0x10f, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x110, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x111, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x112, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x113, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x114, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x115, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x116, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x117, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x118, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x119, ///< Expected global address space + tinytc_status_ir_invalid_offset = 0x11a, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x11b, ///< Instruction does not support int type + tinytc_status_ir_i1_unsupported = 0x11c, ///< Instruction does not support i1 type + tinytc_status_ir_complex_unsupported = 0x11d, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x11c, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x11d, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x11e, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x11f, ///< Mismatch of init values and returned values + 0x11e, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x11f, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x120, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x121, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x122, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x123, ///< Unsupported coopmatrix shape // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index e9bf42c8..6f7a4ad9 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -59,11 +59,13 @@ enum class status { ir_scalar_mismatch = tinytc_status_ir_scalar_mismatch, ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, ir_expected_scalar = tinytc_status_ir_expected_scalar, + ir_expected_index = tinytc_status_ir_expected_index, ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, ir_expected_memref = tinytc_status_ir_expected_memref, ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, + ir_expected_matrix = tinytc_status_ir_expected_matrix, ir_expected_vector_or_matrix = tinytc_status_ir_expected_vector_or_matrix, ir_unexpected_yield = tinytc_status_ir_unexpected_yield, ir_yield_mismatch = tinytc_status_ir_yield_mismatch, @@ -84,6 +86,8 @@ enum class status { ir_forbidden_cast = tinytc_status_ir_forbidden_cast, ir_invalid_beta = tinytc_status_ir_invalid_beta, ir_init_return_mismatch = tinytc_status_ir_init_return_mismatch, + ir_invalid_matrix_use = tinytc_status_ir_invalid_matrix_use, + ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index ea9bfa36..a7d934ff 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -132,6 +132,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Number of indices must match memref order or must be 1 for group types"; case tinytc_status_ir_expected_scalar: return "Expected scalar type"; + case tinytc_status_ir_expected_index: + return "Expected index type"; case tinytc_status_ir_expected_coopmatrix: return "Expected coopmatrix type"; case tinytc_status_ir_expected_coopmatrix_or_scalar: @@ -142,6 +144,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Expected memref type or scalar type"; case tinytc_status_ir_expected_memref_or_group: return "Expected memref or group operand"; + case tinytc_status_ir_expected_matrix: + return "Expected matrix input"; case tinytc_status_ir_expected_vector_or_matrix: return "Expected vector or matrix input"; case tinytc_status_ir_unexpected_yield: @@ -183,6 +187,11 @@ char const *tinytc_error_string(tinytc_status_t status) { return "beta must be constant and 0 or 1 for atomic linear algebra operations"; case tinytc_status_ir_init_return_mismatch: return "The number or types of the initial values does not match the return type list"; + case tinytc_status_ir_invalid_matrix_use: + return "Operands have invalid matrix use"; + case tinytc_status_ir_unsupported_coopmatrix_shape: + return "Unsupported coopmatrix shape for the combination of scalar type, matrix use, and " + "target architecture"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/inst.cpp b/src/inst.cpp index 53b19cb6..4d17e713 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -265,6 +265,58 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da }); } +tinytc_status_t tinytc_cooperative_matrix_load_inst_create( + tinytc_inst_t *instr, tinytc_transpose_t trans, tinytc_bool_t checked, tinytc_value_t op, + tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, const tinytc_location_t *loc) { + if (instr == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr || to_ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique( + enum_cast(trans), checked, op, p0, p1, to_ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create(tinytc_inst_t *instr, + tinytc_value_t a, tinytc_value_t b, + tinytc_value_t c, + tinytc_data_type_t to_ty, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr || b == nullptr || c == nullptr || to_ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(a, b, c, to_ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_scale_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + tinytc_value_t b, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr || b == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(a, b, get_optional(loc)).release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_store_inst_create( + tinytc_inst_t *instr, tinytc_bool_t checked, tinytc_store_flag_t flag, tinytc_value_t val, + tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, const tinytc_location_t *loc) { + if (instr == nullptr || val == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique( + checked, enum_cast(flag), val, op, p0, p1, get_optional(loc)) + .release(); + }); +} + tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 617d3f74..288ee2eb 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -48,13 +48,17 @@ class coopmatrix_data_type : public data_type_node { inline auto ty() const -> tinytc_data_type_t { return ty_; } auto component_ty() const -> scalar_type; + inline auto shape(int mode) const -> std::int64_t { return mode == 1 ? cols_ : rows_; } inline auto rows() const -> std::int64_t { return rows_; } inline auto cols() const -> std::int64_t { return cols_; } inline auto use() const -> matrix_use { return use_; } + inline auto distributed_mode() const -> int { return use_ == matrix_use::b ? 1 : 0; } + inline auto num_blocks(std::int32_t subgroup_size) const -> std::int64_t { + return 1 + (shape(distributed_mode()) - 1) / subgroup_size; + } // Number of components per work-item inline auto length(std::int32_t subgroup_size) const -> std::int64_t { - const std::int64_t blocks = 1 + (rows_ - 1) / subgroup_size; - return blocks * cols_; + return num_blocks(subgroup_size) * shape(1 - distributed_mode()); } protected: diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index d5174430..3707ccf6 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -34,6 +34,12 @@ memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { return m; } +void check_index_ty(location const &loc, tinytc_data_type_t ty) { + if (auto sty = dyn_cast(ty); !sty || sty->ty() != scalar_type::index) { + throw compilation_error(loc, status::ir_expected_index); + } +} + blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, bool atomic) : standard_inst{tid}, atomic_(atomic) { @@ -76,6 +82,10 @@ loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_va result(i) = value_node{return_types[i], this, lc}; } for (std::size_t i = 0; i < init_values.size(); ++i) { + if (!isa(*return_types[i]) && + !isa(*return_types[i])) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); + } if (init_values[i]->ty() != return_types[i]) { throw compilation_error(loc(), status::ir_init_return_mismatch); } @@ -366,6 +376,131 @@ auto constant_inst::is_identity() const -> bool { return std::visit([](auto const &v) { return v == decltype(v){1}; }, value_); } +cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, bool checked, + tinytc_value_t op0, tinytc_value_t p0, + tinytc_value_t p1, + tinytc_data_type_t to_ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_load}, t_(t), checked_(checked) { + op(op_operand, op0); + op(op_pos0, p0); + op(op_pos1, p1); + loc(lc); + + auto ot = dyn_cast(operand().ty()); + if (!ot) { + throw compilation_error(loc(), status::ir_expected_memref); + } + auto rt = dyn_cast(to_ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); + } + if (ot->element_ty() != rt->component_ty()) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), status::ir_expected_matrix); + } + + check_index_ty(lc, pos0().ty()); + check_index_ty(lc, pos1().ty()); + + result(0) = value_node{to_ty, this, lc}; +} + +cooperative_matrix_mul_add_inst::cooperative_matrix_mul_add_inst(tinytc_value_t a0, + tinytc_value_t b0, + tinytc_value_t c0, + tinytc_data_type_t to_ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_mul_add} { + op(op_a, a0); + op(op_b, b0); + op(op_c, c0); + loc(lc); + + auto at = dyn_cast(a().ty()); + auto bt = dyn_cast(b().ty()); + auto ct = dyn_cast(c().ty()); + auto rt = dyn_cast(to_ty); + if (!at || !bt || !ct || !rt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + + auto M = rt->rows(); + auto N = rt->cols(); + auto K = at->cols(); + if (ct->rows() != M || ct->cols() != N || at->rows() != M || bt->rows() != K || + bt->cols() != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << at->rows() << "x" << at->cols() << ", "; + oss << "B=" << bt->rows() << "x" << bt->cols() << ", "; + oss << "C=" << ct->rows() << "x" << ct->cols() << ", "; + oss << "result=" << rt->rows() << "x" << rt->cols(); + throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + } + if (at->use() != matrix_use::a && bt->use() != matrix_use::b && ct->use() != matrix_use::acc && + rt->use() != matrix_use::acc) { + throw compilation_error(loc(), status::ir_invalid_matrix_use); + } + + result(0) = value_node{to_ty, this, lc}; +} + +cooperative_matrix_scale_inst::cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, + location const &lc) + : standard_inst{IK::cooperative_matrix_scale} { + op(op_a, a0); + op(op_b, b0); + loc(lc); + + auto at = dyn_cast(a().ty()); + if (!at) { + throw compilation_error(loc(), status::ir_expected_scalar); + } + auto bt = dyn_cast(b().ty()); + if (!bt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + + if (at->ty() != bt->component_ty()) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + + result(0) = value_node{b().ty(), this, lc}; +} + +cooperative_matrix_store_inst::cooperative_matrix_store_inst(bool checked, store_flag flag, + tinytc_value_t val0, + tinytc_value_t op0, tinytc_value_t p0, + tinytc_value_t p1, location const &lc) + : standard_inst{IK::cooperative_matrix_store}, checked_(checked), flag_(flag) { + op(op_val, val0); + op(op_operand, op0); + op(op_pos0, p0); + op(op_pos1, p1); + loc(lc); + + auto vt = dyn_cast(val().ty()); + if (!vt) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); + } + auto ot = dyn_cast(operand().ty()); + if (!ot) { + throw compilation_error(loc(), status::ir_expected_memref); + } + if (vt->component_ty() != ot->element_ty()) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), status::ir_expected_matrix); + } + + check_index_ty(lc, pos0().ty()); + check_index_ty(lc, pos1().ty()); +} + expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, array_view static_expand_shape0, array_view expand_shape0, location const &lc) @@ -457,6 +592,7 @@ load_inst::load_inst(tinytc_value_t op0, array_view index_list0, : standard_inst{IK::load, static_cast(1 + index_list0.size())} { op(0, op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { + check_index_ty(lc, index_list0[i]->ty()); op(1 + i, index_list0[i]); } loc(lc); @@ -491,7 +627,7 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_v auto c = get_memref_type(loc(), C()); if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, + throw compilation_error(loc(), status::ir_expected_matrix, "gemm only supported for memref of order 2 (matrices)"); } @@ -603,6 +739,10 @@ if_inst::if_inst(tinytc_value_t condition, array_view return op(0, condition); loc(lc); for (std::size_t i = 0; i < return_types.size(); ++i) { + if (!isa(*return_types[i]) && + !isa(*return_types[i])) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); + } result(i) = value_node{return_types[i], this, lc}; } } @@ -685,6 +825,7 @@ store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, { std::size_t i = op_operand; for (auto const &val : index_list0) { + check_index_ty(lc, val->ty()); op(++i, val); } } @@ -723,4 +864,12 @@ sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinyt } } +yield_inst::yield_inst(array_view vals, location const &lc) + : standard_inst{IK::yield, static_cast(vals.size())} { + loc(lc); + for (std::size_t i = 0; i < vals.size(); ++i) { + op(i, vals[i]); + } +} + } // namespace tinytc diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index b286b1c7..15f6bb6d 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -45,6 +45,10 @@ enum class IK { cast, compare, constant, + cooperative_matrix_load, + cooperative_matrix_mul_add, + cooperative_matrix_scale, + cooperative_matrix_store, expand, fuse, load, @@ -82,6 +86,8 @@ enum class IK { using inst_nodes = type_list case tinytc::IK::cast: case tinytc::IK::compare: case tinytc::IK::constant: + case tinytc::IK::cooperative_matrix_load: + case tinytc::IK::cooperative_matrix_mul_add: + case tinytc::IK::cooperative_matrix_scale: + case tinytc::IK::cooperative_matrix_store: case tinytc::IK::expand: case tinytc::IK::fuse: case tinytc::IK::load: @@ -501,6 +511,71 @@ class constant_inst : public standard_inst<0, 1> { value_type value_; }; +class cooperative_matrix_load_inst : public standard_inst<3, 1, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_load; + } + enum op_number { op_operand = 0, op_pos0 = 1, op_pos1 = 2 }; + cooperative_matrix_load_inst(transpose t, bool checked, tinytc_value_t op0, tinytc_value_t p0, + tinytc_value_t p1, tinytc_data_type_t to_ty, + location const &lc = {}); + inline auto t() const -> transpose { return t_; } + inline auto checked() const -> bool { return checked_; } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } + + private: + transpose t_; + bool checked_; +}; + +class cooperative_matrix_mul_add_inst : public standard_inst<3, 1, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_mul_add; + } + enum op_number { op_a = 0, op_b = 1, op_c = 2 }; + cooperative_matrix_mul_add_inst(tinytc_value_t a0, tinytc_value_t b0, tinytc_value_t c0, + tinytc_data_type_t to_ty, location const &lc = {}); + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() const -> tinytc_value const & { return op(op_b); } + inline auto c() const -> tinytc_value const & { return op(op_c); } +}; + +class cooperative_matrix_scale_inst : public standard_inst<2, 1, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_scale; + } + enum op_number { op_a = 0, op_b = 1 }; + cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, location const &lc = {}); + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() const -> tinytc_value const & { return op(op_b); } +}; + +class cooperative_matrix_store_inst : public standard_inst<4, 0, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_store; + } + enum op_number { op_val = 0, op_operand = 1, op_pos0 = 2, op_pos1 = 3 }; + cooperative_matrix_store_inst(bool checked, store_flag flag, tinytc_value_t val0, + tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, + location const &lc = {}); + inline auto checked() const -> bool { return checked_; } + inline auto flag() const -> store_flag { return flag_; } + inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } + + private: + bool checked_; + store_flag flag_; +}; + class expand_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } @@ -758,13 +833,7 @@ class sum_inst : public blas_a2_inst { class yield_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::yield; } - inline yield_inst(array_view vals, location const &lc = {}) - : standard_inst{IK::yield, static_cast(vals.size())} { - loc(lc); - for (std::size_t i = 0; i < vals.size(); ++i) { - op(i, vals[i]); - } - } + yield_inst(array_view vals, location const &lc = {}); }; } // namespace tinytc diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 58b57fe8..f580234c 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -96,6 +96,7 @@ lex: ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } + ".checked" { adv_loc(); return parser::make_CHECKED(loc_); } "init" { adv_loc(); return parser::make_INIT(loc_); } "local" { adv_loc(); return parser::make_LOCAL(loc_); } "global" { adv_loc(); return parser::make_GLOBAL(loc_); } @@ -152,6 +153,10 @@ lex: "cast" { adv_loc(); return parser::make_CAST(loc_); } "cmp" { adv_loc(); return parser::make_CMP(loc_); } "constant" { adv_loc(); return parser::make_CONSTANT(loc_); } + "cooperative_matrix_load" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_LOAD(loc_); } + "cooperative_matrix_mul_add" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_MUL_ADD(loc_); } + "cooperative_matrix_scale" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_SCALE(loc_); } + "cooperative_matrix_store" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_STORE(loc_); } "expand" { adv_loc(); return parser::make_EXPAND(loc_); } "fuse" { adv_loc(); return parser::make_FUSE(loc_); } "load" { adv_loc(); return parser::make_LOAD(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 0ada8d08..bf2d9e95 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -96,6 +96,7 @@ TRANS ".t" ATOMIC ".atomic" ATOMIC_ADD ".atomic_add" + CHECKED ".checked" INIT "init" LOCAL "local" GLOBAL "global" @@ -117,6 +118,10 @@ CAST "cast" CMP "cmp" CONSTANT "constant" + COOPERATIVE_MATRIX_LOAD "cooperative_matrix_load" + COOPERATIVE_MATRIX_MUL_ADD "cooperative_matrix_mul_add" + COOPERATIVE_MATRIX_SCALE "cooperative_matrix_scale" + COOPERATIVE_MATRIX_STORE "cooperative_matrix_store" EXPAND "expand" FUSE "fuse" LOAD "load" @@ -188,8 +193,8 @@ %nterm hadamard_inst %nterm if_inst %nterm > optional_returned_values -%nterm > optional_scalar_type_list -%nterm > scalar_type_list +%nterm > optional_return_type_list +%nterm > return_type_list %nterm sum_inst %nterm yield_inst %nterm for_loop_var_type @@ -202,6 +207,11 @@ %nterm cast_inst %nterm compare_inst %nterm constant_inst +%nterm cooperative_matrix_load_inst +%nterm cooperative_matrix_mul_add_inst +%nterm cooperative_matrix_scale_inst +%nterm cooperative_matrix_store_inst +%nterm checked %nterm expand_inst %nterm integer_constant_or_identifier %nterm > expand_shape @@ -428,6 +438,7 @@ instructions: instruction: axpby_inst { $$ = std::move($1); } | barrier_inst { $$ = std::move($1); } + | cooperative_matrix_store_inst { $$ = std::move($1); } | gemm_inst { $$ = std::move($1); } | gemv_inst { $$ = std::move($1); } | ger_inst { $$ = std::move($1); } @@ -622,9 +633,9 @@ optional_step: optional_loop_carried_values: %empty { $$ = {}; } - | INIT LPAREN init_value_list RPAREN RETURNS LPAREN scalar_type_list RPAREN { + | INIT LPAREN init_value_list RPAREN RETURNS LPAREN return_type_list RPAREN { $$ = std::make_tuple(std::move($init_value_list.first), std::move($init_value_list.second), - std::move($scalar_type_list)); + std::move($return_type_list)); } ; @@ -741,7 +752,7 @@ sum_inst: ; yield_inst: - YIELD optional_value_list[vals] COLON optional_scalar_type_list[tys] { + YIELD optional_value_list[vals] COLON optional_return_type_list[tys] { if ($vals.size() != $tys.size()) { location loc = @vals; loc.end = @tys.end; @@ -761,6 +772,9 @@ valued_inst: | cast_inst { $$ = std::move($1); } | compare_inst { $$ = std::move($1); } | constant_inst { $$ = std::move($1); } + | cooperative_matrix_load_inst { $$ = std::move($1); } + | cooperative_matrix_mul_add_inst { $$ = std::move($1); } + | cooperative_matrix_scale_inst { $$ = std::move($1); } | expand_inst { $$ = std::move($1); } | for_inst { $$ = std::move($1); } | fuse_inst { $$ = std::move($1); } @@ -885,6 +899,82 @@ constant_inst: } ; +cooperative_matrix_load_inst: + COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[op_ty] RETURNS data_type[result_ty] { + check_type($op, $op_ty, @op, @op_ty); + try { + $$ = inst { + std::make_unique( + $transpose, $checked, std::move($op), std::move($p0), std::move($p1), std::move($result_ty), + @cooperative_matrix_load_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + +checked: + %empty { $$ = false; } + | CHECKED { $$ = true; } +; + +cooperative_matrix_mul_add_inst: + COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[a_ty] COMMA data_type[b_ty] COMMA data_type[c_ty] RETURNS data_type[to_ty] { + check_type($a, $a_ty, @a, @a_ty); + check_type($b, $b_ty, @b, @b_ty); + check_type($c, $c_ty, @c, @c_ty); + try { + $$ = inst { + std::make_unique(std::move($a), std::move($b), + std::move($c), std::move($to_ty), + @cooperative_matrix_mul_add_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + +cooperative_matrix_scale_inst: + COOPERATIVE_MATRIX_SCALE var[a] COMMA var[b] COLON data_type[a_ty] COMMA data_type[b_ty] { + check_type($a, $a_ty, @a, @a_ty); + check_type($b, $b_ty, @b, @b_ty); + try { + $$ = inst { + std::make_unique(std::move($a), std::move($b), + @cooperative_matrix_scale_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + +cooperative_matrix_store_inst: + COOPERATIVE_MATRIX_STORE checked store_flag var[val] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[val_ty] COMMA data_type[op_ty] { + check_type($val, $val_ty, @val, @val_ty); + check_type($op, $op_ty, @op, @op_ty); + try { + $$ = inst { + std::make_unique( + $checked, $store_flag, std::move($val), std::move($op), std::move($p0), std::move($p1), + @cooperative_matrix_store_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + expand_inst: EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] RETURNS expand_shape RSQBR COLON memref_type { if ($var->ty() != $memref_type) { @@ -1039,18 +1129,18 @@ else_region: optional_returned_values: %empty { $$ = {}; } - | RETURNS LPAREN optional_scalar_type_list[tys] RPAREN { $$ = std::move($tys); } + | RETURNS LPAREN optional_return_type_list[tys] RPAREN { $$ = std::move($tys); } ; -optional_scalar_type_list: +optional_return_type_list: %empty {} - | scalar_type_list { $$ = std::move($1); } + | return_type_list { $$ = std::move($1); } ; -scalar_type_list: - scalar_type { $$.push_back($scalar_type); } - | scalar_type_list COMMA scalar_type { - $$ = std::move($1); $$.push_back($scalar_type); +return_type_list: + data_type { $$.push_back($data_type); } + | return_type_list COMMA data_type { + $$ = std::move($1); $$.push_back($data_type); } ; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 5bca1f2b..bb3f023c 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -146,6 +146,15 @@ clir::var convert_to_opencl_pass::declare(value_node const &v) { return declared_vars_.back()[u]; } +auto convert_to_opencl_pass::get_coopmatrix_type(value_node const &v) const + -> const coopmatrix_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_coopmatrix); + } + return t; +} + auto convert_to_opencl_pass::get_memref_type(value_node const &v) const -> const memref_data_type * { auto t = dyn_cast(v.ty()); @@ -531,6 +540,314 @@ std::vector convert_to_opencl_pass::operator()(constant_inst const & throw compilation_error(c.loc(), status::ir_expected_coopmatrix_or_scalar); } +std::vector convert_to_opencl_pass::operator()(cooperative_matrix_load_inst const &c) { + auto lhs = declare(c.result(0)); + auto lhs_ty = visit(*this, *c.result(0).ty()); + auto ot = get_memref_type(c.operand()); + auto rt = get_coopmatrix_type(c.result(0)); + auto &dv = get_dope_vector(c.operand()); + + const int rmode = rt->distributed_mode(); + const int omode = c.t() == transpose::T ? 1 - rmode : rmode; + const bool enable_sub_group_reads = core_cfg_.block_read_write_supported && !c.checked() && + c.t() == transpose::N && ot->stride(omode) == 1; + + auto clinst = std::vector{}; + auto const len = rt->length(core_cfg_.subgroup_size); + clinst.reserve(len + 5); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + + clir::expr pv[] = {val(c.pos0()), val(c.pos1())}; + auto pointer = clir::var{}; + clinst.emplace_back( + declaration_assignment(visit(*this, *c.operand().ty()), pointer, + val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); + clir::var rem[2] = {}; + if (c.checked()) { + clinst.emplace_back( + declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); + clinst.emplace_back( + declaration_assignment(to_clir_ty(scalar_type::index), rem[1], dv.shape(1) - pv[1])); + } + + const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); + for (std::int64_t block = 0; block < num_blocks; ++block) { + auto common_check = clir::var{}; + if (c.checked()) { + auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; + clinst.emplace_back(declaration_assignment(to_clir_ty(scalar_type::i1), common_check, + m >= -pv[omode] && m < rem[omode])); + } + for (std::int64_t k = 0; k < rt->shape(1 - rmode); ++k) { + auto const store = [&](clir::expr rhs) -> clir::stmt { + return expression_statement( + assignment(lhs[k + block * rt->shape(1 - rmode)], std::move(rhs))); + }; + auto const remainder = rt->shape(rmode) - core_cfg_.subgroup_size * block; + const bool needs_mask = remainder < core_cfg_.subgroup_size; + if (enable_sub_group_reads && !needs_mask) { + auto rhs = sub_group_block_read_helper( + pointer + block * core_cfg_.subgroup_size + k * ot->stride(1), ot->element_ty(), + to_clir_address_space(ot->addrspace())); + clinst.emplace_back(store(std::move(rhs))); + } else { + auto rhs = pointer[ot->stride(omode) * (clir::get_sub_group_local_id() + + block * core_cfg_.subgroup_size) + + k * ot->stride(1 - omode)]; + auto checked_cond = [&] { + return common_check && k >= -pv[1 - omode] && k < rem[1 - omode]; + }; + auto mask_cond = [&] { return clir::get_sub_group_local_id() < remainder; }; + clir::expr cond = {}; + if (c.checked() && needs_mask) { + cond = checked_cond() && mask_cond(); + } else if (c.checked()) { + cond = checked_cond(); + } else if (needs_mask) { + cond = mask_cond(); + } + if (cond) { + rhs = ternary_conditional(cond, std::move(rhs), 0); + } + clinst.emplace_back(store(std::move(rhs))); + } + } + } + return clinst; +} +std::vector +convert_to_opencl_pass::operator()(cooperative_matrix_mul_add_inst const &c) { + auto lhs = declare(c.result(0)); + auto lhs_ty = visit(*this, *c.result(0).ty()); + auto rt = get_coopmatrix_type(c.result(0)); + auto at = get_coopmatrix_type(c.a()); + auto bt = get_coopmatrix_type(c.b()); + auto ct = get_coopmatrix_type(c.c()); + auto av = val(c.a()); + auto bv = val(c.b()); + auto cv = val(c.c()); + + const auto a_ty = at->component_ty(); + const auto b_ty = bt->component_ty(); + const auto c_ty = ct->component_ty(); + const auto r_ty = rt->component_ty(); + const bool use_double_buffering = is_complex_type(a_ty) && is_complex_type(b_ty); + + const std::int64_t M = rt->rows(), N = rt->cols(), K = at->cols(); + auto clinst = std::vector{}; + clinst.reserve(M * N + 2); + clinst.emplace_back(declaration(lhs_ty, lhs)); + + auto c_acc_im = clir::var{}; + if (use_double_buffering) { + clinst.emplace_back(declaration(lhs_ty, c_acc_im)); + } + + const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); + const std::int64_t nbb = 4; + for (std::int64_t m_block = 0; m_block < num_blocks; ++m_block) { + for (std::int64_t nb = 0; nb < N; nb += nbb) { + for (std::int64_t k = 0; k < K; ++k) { + for (std::int64_t n = 0; n < nbb; ++n) { + if (nb + n < N) { + auto const n_block = (nb + n) / core_cfg_.subgroup_size; + auto const n_offset = (nb + n) % core_cfg_.subgroup_size; + + auto a = av[k + m_block * K]; + auto b = bv[k + n_block * K]; + auto c_next = lhs[nb + n + m_block * N]; + auto c = [&] { + if (k == 0) { + auto c = cv[nb + n + m_block * N]; + if (c_ty != r_ty) { + if (is_complex_type(r_ty) && !is_complex_type(c_ty)) { + c = clir::init_vector(to_clir_ty(r_ty), {c, 0}); + } else if (r_ty != c_ty) { + return clir::cast(to_clir_ty(r_ty), c); + } + } + return c; + } + return c_next; + }(); + const auto c_next_im = [&] { return c_acc_im[nb + n + m_block * N]; }; + const auto c_im = [&] { + if (k == 0) { + return init_vector(to_clir_ty(r_ty), {0, 0}); + } + return c_next_im(); + }; + + auto const add = [&](auto a_ty, auto b_ty, auto c_ty, auto a, auto b, + auto c, auto c_next) { + if (a_ty == b_ty && b_ty == c_ty) { + clinst.emplace_back(expression_statement( + assignment(std::move(c_next), + fma(std::move(a), std::move(b), std::move(c))))); + } else { + clinst.emplace_back(expression_statement( + assignment(std::move(c_next), + std::move(c) + std::move(a) * std::move(b)))); + } + }; + + if (is_complex_type(a_ty)) { + if (is_complex_type(b_ty)) { + auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); + auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); + add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_re), c, + c_next); + add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_im), c_im(), + c_next_im()); + } else { + auto b_bc = sub_group_broadcast(b, n_offset); + add(a_ty, b_ty, r_ty, std::move(a), std::move(b_bc), c, c_next); + } + } else if (is_complex_type(b_ty)) { + auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); + auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); + add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_re), c.s(0), + c_next.s(0)); + add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_im), c.s(1), + c_next.s(1)); + } else { + auto b_bc = sub_group_broadcast(std::move(b), n_offset); + add(a_ty, b_ty, r_ty, std::move(a), std::move(b_bc), std::move(c), + std::move(c_next)); + } + } + } + } + } + } + if (use_double_buffering) { + for (std::int64_t i = 0; i < rt->length(core_cfg_.subgroup_size); ++i) { + clinst.emplace_back(expression_statement( + add_into(lhs[i], clir::init_vector(to_clir_ty(r_ty), + {-c_acc_im[i].s(1), c_acc_im[i].s(0)})))); + } + } + return clinst; +} +std::vector convert_to_opencl_pass::operator()(cooperative_matrix_scale_inst const &c) { + auto lhs = declare(c.result(0)); + auto lhs_ty = visit(*this, *c.result()->ty()); + auto av = val(c.a()); + auto bv = val(c.b()); + auto at = get_scalar_type(c.a()); + auto bt = get_coopmatrix_type(c.b()); + + auto clinst = std::vector{}; + auto const len = bt->length(core_cfg_.subgroup_size); + clinst.reserve(len + 1); + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + for (std::int64_t i = 0; i < len; ++i) { + auto op = multiply(at, bt->component_ty(), av, bv[i]); + clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); + } + return clinst; +} +std::vector convert_to_opencl_pass::operator()(cooperative_matrix_store_inst const &c) { + auto ot = get_memref_type(c.operand()); + auto vt = get_coopmatrix_type(c.val()); + auto &dv = get_dope_vector(c.operand()); + auto valv = val(c.val()); + + const int vmode = vt->distributed_mode(); + const int omode = vmode; + + auto clinst = std::vector{}; + auto const len = vt->length(core_cfg_.subgroup_size); + clinst.reserve(len + 4); + + clir::expr pv[] = {val(c.pos0()), val(c.pos1())}; + auto pointer = clir::var{}; + clinst.emplace_back( + declaration_assignment(visit(*this, *c.operand().ty()), pointer, + val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); + clir::var rem[2] = {}; + if (c.checked()) { + clinst.emplace_back( + declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); + clinst.emplace_back( + declaration_assignment(to_clir_ty(scalar_type::index), rem[1], dv.shape(1) - pv[1])); + } + + auto atomic_pointer = + cast(pointer_to(to_clir_atomic_ty(ot->element_ty(), to_clir_address_space(ot->addrspace()), + clir::type_qualifier::volatile_t)), + pointer); + const std::int64_t num_blocks = vt->num_blocks(core_cfg_.subgroup_size); + auto const num_k = vt->shape(1 - vmode); + auto store_block = std::vector{}; + store_block.reserve(num_k); + for (std::int64_t block = 0; block < num_blocks; ++block) { + store_block.clear(); + for (std::int64_t k = 0; k < num_k; ++k) { + auto const store = [&](clir::expr offset, clir::expr rhs) -> clir::expr { + switch (c.flag()) { + case store_flag::regular: + return assignment(pointer[std::move(offset)], std::move(rhs)); + case store_flag::atomic: + return call_builtin(clir::builtin_function::atomic_store_explicit, + {atomic_pointer + std::move(offset), std::move(rhs), + clir::memory_order::relaxed, + clir::memory_scope::work_group}); + case store_flag::atomic_add: + return call_builtin(clir::builtin_function::atomic_fetch_add_explicit, + {atomic_pointer + std::move(offset), std::move(rhs), + clir::memory_order::relaxed, + clir::memory_scope::work_group}); + }; + return {}; + }; + auto const remainder = vt->shape(vmode) - core_cfg_.subgroup_size * block; + const bool needs_mask = remainder < core_cfg_.subgroup_size; + + auto offset = ot->stride(omode) * + (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + + k * ot->stride(1 - omode); + auto rhs = valv[k + block * vt->shape(1 - vmode)]; + auto checked_cond = [&] { return k >= -pv[1 - omode] && k < rem[1 - omode]; }; + auto mask_cond = [&] { return clir::get_sub_group_local_id() < remainder; }; + clir::expr cond = {}; + if (c.checked() && needs_mask) { + cond = checked_cond() && mask_cond(); + } else if (c.checked()) { + cond = checked_cond(); + } else if (needs_mask) { + cond = mask_cond(); + } + auto st = clir::expression_statement(store(std::move(offset), std::move(rhs))); + if (cond) { + store_block.emplace_back( + clir::if_selection_builder(cond) + .then([&](clir::block_builder &bb) { bb.add(std::move(st)); }) + .get_product()); + } else { + store_block.emplace_back(std::move(st)); + } + } + + if (c.checked()) { + auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; + clinst.emplace_back(clir::if_selection_builder(m >= -pv[omode] && m < rem[omode]) + .then([&](clir::block_builder &bb) { + for (auto &i : store_block) { + bb.add(i); + } + }) + .get_product()); + } else { + for (auto &i : store_block) { + clinst.emplace_back(i); + } + } + } + + return clinst; +} + std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { auto result_var = declare(*e.result()); auto m = get_memref_type(e.operand()); @@ -859,10 +1176,25 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &in) { yielded_vars_.push_back(std::vector{}); for (std::int64_t i = 0; i < in.num_results(); ++i) { - auto v = declare(in.iter_arg(i)); - clinst.emplace_back(clir::declaration_assignment(visit(*this, *in.iter_arg(i).ty()), v, - val(in.iter_init(i)))); - yielded_vars_.back().emplace_back(std::move(v)); + auto lhs_ty = visit(*this, *in.result(i).ty()); + auto lhs = declare(in.result(i)); + + // Link the iteration variable to the result variable + uintptr_t u = std::bit_cast(&in.result(i)); + uintptr_t v = std::bit_cast(&in.iter_arg(i)); + declared_vars_.back()[v] = declared_vars_.back()[u]; + + auto iinit = val(in.iter_init(i)); + if (auto ct = dyn_cast(in.result(i).ty()); ct) { + clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + auto const len = ct->length(core_cfg_.subgroup_size); + for (std::int64_t j = 0; j < len; ++j) { + clinst.emplace_back(expression_statement(assignment(lhs[j], iinit[j]))); + } + } else { + clinst.emplace_back(clir::declaration_assignment(lhs_ty, lhs, iinit)); + } + yielded_vars_.back().emplace_back(std::move(lhs)); } auto lv = declare(in.loop_var()); @@ -874,11 +1206,6 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &in) { clinst.emplace_back(clir::stmt(std::make_shared( std::move(start), std::move(condition), std::move(step), std::move(body)))); - for (std::int64_t i = 0; i < in.num_results(); ++i) { - clinst.emplace_back(clir::declaration_assignment( - visit(*this, *in.result(i).ty()), declare(in.result(i)), yielded_vars_.back()[i])); - } - yielded_vars_.pop_back(); return clinst; } @@ -1195,8 +1522,16 @@ std::vector convert_to_opencl_pass::operator()(yield_inst const &in) } std::vector clinst; for (std::int64_t i = 0; i < in.num_operands(); ++i) { - auto assign_yielded_var = clir::assignment(yielded_vars_.back()[i], val(in.op(i))); - clinst.push_back(clir::expression_statement(std::move(assign_yielded_var))); + auto &yielded_var = yielded_vars_.back()[i]; + auto ov = val(in.op(i)); + if (auto ct = dyn_cast(in.op(i).ty()); ct) { + auto const len = ct->length(core_cfg_.subgroup_size); + for (std::int64_t j = 0; j < len; ++j) { + clinst.push_back(expression_statement(assignment(yielded_var[j], ov[j]))); + } + } else { + clinst.push_back(expression_statement(assignment(yielded_var, ov))); + } } return clinst; } diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 539cb4d7..f8c82c64 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -75,6 +75,10 @@ class convert_to_opencl_pass { std::vector operator()(cast_inst const &c); std::vector operator()(compare_inst const &c); std::vector operator()(constant_inst const &c); + std::vector operator()(cooperative_matrix_load_inst const &c); + std::vector operator()(cooperative_matrix_mul_add_inst const &c); + std::vector operator()(cooperative_matrix_scale_inst const &c); + std::vector operator()(cooperative_matrix_store_inst const &c); std::vector operator()(expand_inst const &e); std::vector operator()(fuse_inst const &f); std::vector operator()(load_inst const &e); @@ -109,6 +113,7 @@ class convert_to_opencl_pass { auto get_dope_vector(value_node const &v) -> dope_vector &; void set_dope_vector(value_node const &v, dope_vector dv); clir::var declare(value_node const &v); + auto get_coopmatrix_type(value_node const &v) const -> const coopmatrix_data_type *; auto get_memref_type(value_node const &v) const -> const memref_data_type *; static auto get_scalar_type(value_node const &v) -> scalar_type; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index e7f973ae..3636c702 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -202,6 +202,77 @@ void dump_ir_pass::operator()(constant_inst const &c) { visit(*this, *c.result()->ty()); } +void dump_ir_pass::operator()(cooperative_matrix_load_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_load"; + *os_ << "." << to_string(c.t()); + if (c.checked()) { + *os_ << ".checked"; + } + *os_ << " "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "] : "; + visit(*this, *c.operand().ty()); + *os_ << " -> "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_mul_add_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_mul_add "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << ", "; + dump_val(c.c()); + *os_ << " : "; + visit(*this, *c.a().ty()); + *os_ << ", "; + visit(*this, *c.b().ty()); + *os_ << ", "; + visit(*this, *c.b().ty()); + *os_ << " -> "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_scale_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_scale "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << " : "; + visit(*this, *c.a().ty()); + *os_ << ", "; + visit(*this, *c.b().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_store_inst const &c) { + *os_ << "cooperative_matrix_store"; + if (c.checked()) { + *os_ << ".checked"; + } + if (c.flag() != store_flag::regular) { + *os_ << '.' << to_string(c.flag()); + } + *os_ << " "; + dump_val(c.val()); + *os_ << ", "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "] : "; + visit(*this, *c.val().ty()); + *os_ << ", "; + visit(*this, *c.operand().ty()); +} + void dump_ir_pass::operator()(expand_inst const &e) { dump_val(e.result(0)); *os_ << " = expand "; diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index be273053..fa6275f5 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -37,6 +37,10 @@ class dump_ir_pass { void operator()(cast_inst const &c); void operator()(compare_inst const &c); void operator()(constant_inst const &c); + void operator()(cooperative_matrix_load_inst const &c); + void operator()(cooperative_matrix_mul_add_inst const &c); + void operator()(cooperative_matrix_scale_inst const &c); + void operator()(cooperative_matrix_store_inst const &c); void operator()(expand_inst const &e); void operator()(fuse_inst const &f); void operator()(load_inst const &e); diff --git a/test/codegen/coopmatrix_basic.ir b/test/codegen/coopmatrix_basic.ir new file mode 100644 index 00000000..f3d690cc --- /dev/null +++ b/test/codegen/coopmatrix_basic.ir @@ -0,0 +1,56 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @coopmatrix_constant() { + %0 = constant 1.0 -> coopmatrix +; CHECK-LABEL: void coopmatrix_constant({{.*}} +; CHECK: double x[5]; +; CHECK-NEXT: x[0] = 0x1p+0; +; CHECK-NEXT: x[1] = 0x1p+0; +; CHECK-NEXT: x[2] = 0x1p+0; +; CHECK-NEXT: x[3] = 0x1p+0; +; CHECK-NEXT: x[4] = 0x1p+0; +} + +func @coopmatrix_add() { + %0 = constant 1.0 -> coopmatrix + %1 = constant 1.0 -> coopmatrix + %2 = arith.add %0, %1 : coopmatrix +; CHECK-LABEL: void coopmatrix_add({{.*}} +; CHECK: double x2[4]; +; CHECK-NEXT: x2[0] = x[0] + x1[0]; +; CHECK-NEXT: x2[1] = x[1] + x1[1]; +; CHECK-NEXT: x2[2] = x[2] + x1[2]; +; CHECK-NEXT: x2[3] = x[3] + x1[3]; +} + +func @coopmatrix_neg() subgroup_size(16) { + %0 = constant 1.0 -> coopmatrix + %1 = arith.neg %0 : coopmatrix +; CHECK-LABEL: void coopmatrix_neg({{.*}} +; CHECK: double x1[8]; +; CHECK-NEXT: x1[0] = -x[0]; +; CHECK-NEXT: x1[1] = -x[1]; +; CHECK-NEXT: x1[2] = -x[2]; +; CHECK-NEXT: x1[3] = -x[3]; +; CHECK-NEXT: x1[4] = -x[4]; +; CHECK-NEXT: x1[5] = -x[5]; +; CHECK-NEXT: x1[6] = -x[6]; +; CHECK-NEXT: x1[7] = -x[7]; +} + +func @coopmatrix_cast() subgroup_size(16) { + %0 = constant 1 -> coopmatrix + %1 = cast %0 : coopmatrix -> coopmatrix +; CHECK-LABEL: void coopmatrix_cast({{.*}} +; CHECK: float2 x1[8]; +; CHECK-NEXT: x1[0] = (float2) (x[0], 0); +; CHECK-NEXT: x1[1] = (float2) (x[1], 0); +; CHECK-NEXT: x1[2] = (float2) (x[2], 0); +; CHECK-NEXT: x1[3] = (float2) (x[3], 0); +; CHECK-NEXT: x1[4] = (float2) (x[4], 0); +; CHECK-NEXT: x1[5] = (float2) (x[5], 0); +; CHECK-NEXT: x1[6] = (float2) (x[6], 0); +; CHECK-NEXT: x1[7] = (float2) (x[7], 0); +} diff --git a/test/codegen/coopmatrix_load.ir b/test/codegen/coopmatrix_load.ir new file mode 100644 index 00000000..aa604a9b --- /dev/null +++ b/test/codegen/coopmatrix_load.ir @@ -0,0 +1,153 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_n({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: x1[0] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0))); +; CHECK-NEXT: x1[1] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 64))); +; CHECK-NEXT: x1[2] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 128))); +; CHECK-NEXT: x1[3] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 192))); +; CHECK-NEXT: x1[4] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 256))); +; CHECK-NEXT: x1[5] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 320))); +; CHECK-NEXT: x1[6] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 384))); +; CHECK-NEXT: x1[7] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 448))); +} + +func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.checked %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_n_checked({{.*}} +; CHECK: float x1[16]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; +; CHECK-NEXT: x1[0] = x5 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: x1[2] = x5 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; +; CHECK-NEXT: x1[3] = x5 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; +; CHECK-NEXT: x1[4] = x5 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; +; CHECK-NEXT: x1[5] = x5 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; +; CHECK-NEXT: x1[6] = x5 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; +; CHECK-NEXT: x1[7] = x5 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; +; CHECK-NEXT: x1[8] = x6 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; +; CHECK-NEXT: x1[9] = x6 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; +; CHECK-NEXT: x1[10] = x6 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 128] : 0; +; CHECK-NEXT: x1[11] = x6 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 192] : 0; +; CHECK-NEXT: x1[12] = x6 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 256] : 0; +; CHECK-NEXT: x1[13] = x6 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 320] : 0; +; CHECK-NEXT: x1[14] = x6 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 384] : 0; +; CHECK-NEXT: x1[15] = x6 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 448] : 0; +} + +func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.t %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_t({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0]; +; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1]; +; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2]; +; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3]; +; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4]; +; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5]; +; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6]; +; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7]; +} + +func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.t.checked %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_t_checked({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; +; CHECK-NEXT: x1[0] = x5 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; +; CHECK-NEXT: x1[2] = x5 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; +; CHECK-NEXT: x1[3] = x5 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; +; CHECK-NEXT: x1[4] = x5 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; +; CHECK-NEXT: x1[5] = x5 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; +; CHECK-NEXT: x1[6] = x5 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; +; CHECK-NEXT: x1[7] = x5 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +} + +func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n %B[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_b_load_n({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; +; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0]; +; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1]; +; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2]; +; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3]; +; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4]; +; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5]; +; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6]; +; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7]; +} + +func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.checked %B[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_b_load_n_checked({{.*}} +; CHECK: float x1[16]; +; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; +; CHECK-NEXT: x1[0] = x5 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; +; CHECK-NEXT: x1[2] = x5 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; +; CHECK-NEXT: x1[3] = x5 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; +; CHECK-NEXT: x1[4] = x5 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; +; CHECK-NEXT: x1[5] = x5 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; +; CHECK-NEXT: x1[6] = x5 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; +; CHECK-NEXT: x1[7] = x5 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -y && get_sub_group_local_id() + 16 < x4; +; CHECK-NEXT: x1[8] = x6 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 0] : 0; +; CHECK-NEXT: x1[9] = x6 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 1] : 0; +; CHECK-NEXT: x1[10] = x6 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 2] : 0; +; CHECK-NEXT: x1[11] = x6 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 3] : 0; +; CHECK-NEXT: x1[12] = x6 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 4] : 0; +; CHECK-NEXT: x1[13] = x6 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 5] : 0; +; CHECK-NEXT: x1[14] = x6 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 6] : 0; +; CHECK-NEXT: x1[15] = x6 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 7] : 0; +} + +func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.t %B[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_b_load_t({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; +; CHECK-NEXT: x1[0] = x2[1 * (get_sub_group_local_id() + 0) + 0]; +; CHECK-NEXT: x1[1] = x2[1 * (get_sub_group_local_id() + 0) + 64]; +; CHECK-NEXT: x1[2] = x2[1 * (get_sub_group_local_id() + 0) + 128]; +; CHECK-NEXT: x1[3] = x2[1 * (get_sub_group_local_id() + 0) + 192]; +; CHECK-NEXT: x1[4] = x2[1 * (get_sub_group_local_id() + 0) + 256]; +; CHECK-NEXT: x1[5] = x2[1 * (get_sub_group_local_id() + 0) + 320]; +; CHECK-NEXT: x1[6] = x2[1 * (get_sub_group_local_id() + 0) + 384]; +; CHECK-NEXT: x1[7] = x2[1 * (get_sub_group_local_id() + 0) + 448]; +} + +func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.t.checked %B[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_b_load_t_checked({{.*}} +; CHECK: float x1[8]; +; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; +; CHECK-NEXT: x1[0] = x5 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: x1[2] = x5 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; +; CHECK-NEXT: x1[3] = x5 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; +; CHECK-NEXT: x1[4] = x5 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; +; CHECK-NEXT: x1[5] = x5 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; +; CHECK-NEXT: x1[6] = x5 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; +; CHECK-NEXT: x1[7] = x5 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +} diff --git a/test/codegen/coopmatrix_mul_add.ir b/test/codegen/coopmatrix_mul_add.ir new file mode 100644 index 00000000..62e0e181 --- /dev/null +++ b/test/codegen/coopmatrix_mul_add.ir @@ -0,0 +1,100 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @coopmatrix_mul_add_ff() subgroup_size(16) { + %a = constant 1.0 -> coopmatrix + %b = constant 1.0 -> coopmatrix + %c = constant 1.0 -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: void coopmatrix_mul_add_ff({{.*}} +; CHECK: float c_next[4]; +; CHECK-NEXT: c_next[0] = fma(a[0], sub_group_broadcast(b[0], 0), c[0]); +; CHECK-NEXT: c_next[1] = fma(a[0], sub_group_broadcast(b[0], 1), c[1]); +; CHECK-NEXT: c_next[2] = fma(a[0], sub_group_broadcast(b[0], 2), c[2]); +; CHECK-NEXT: c_next[3] = fma(a[0], sub_group_broadcast(b[0], 3), c[3]); +; CHECK-NEXT: c_next[0] = fma(a[1], sub_group_broadcast(b[1], 0), c_next[0]); +; CHECK-NEXT: c_next[1] = fma(a[1], sub_group_broadcast(b[1], 1), c_next[1]); +; CHECK-NEXT: c_next[2] = fma(a[1], sub_group_broadcast(b[1], 2), c_next[2]); +; CHECK-NEXT: c_next[3] = fma(a[1], sub_group_broadcast(b[1], 3), c_next[3]); +} + +func @coopmatrix_mul_add_cf() subgroup_size(16) { + %a = constant [1.0, 0.0] -> coopmatrix + %b = constant 1.0 -> coopmatrix + %c = constant 1.0 -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: void coopmatrix_mul_add_cf({{.*}} +; CHECK: float2 c_next[4]; +; CHECK-NEXT: c_next[0] = (float2) (c[0], 0) + a[0] * sub_group_broadcast(b[0], 0); +; CHECK-NEXT: c_next[1] = (float2) (c[1], 0) + a[0] * sub_group_broadcast(b[0], 1); +; CHECK-NEXT: c_next[2] = (float2) (c[2], 0) + a[0] * sub_group_broadcast(b[0], 2); +; CHECK-NEXT: c_next[3] = (float2) (c[3], 0) + a[0] * sub_group_broadcast(b[0], 3); +; CHECK-NEXT: c_next[0] = c_next[0] + a[1] * sub_group_broadcast(b[1], 0); +; CHECK-NEXT: c_next[1] = c_next[1] + a[1] * sub_group_broadcast(b[1], 1); +; CHECK-NEXT: c_next[2] = c_next[2] + a[1] * sub_group_broadcast(b[1], 2); +; CHECK-NEXT: c_next[3] = c_next[3] + a[1] * sub_group_broadcast(b[1], 3); +} + +func @coopmatrix_mul_add_fc() subgroup_size(16) { + %a = constant 1.0 -> coopmatrix + %b = constant [1.0, 0.0] -> coopmatrix + %c = constant 1.0 -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: void coopmatrix_mul_add_fc({{.*}} +; CHECK: float2 c_next[4]; +; CHECK-NEXT: c_next[0].x = ((float2) (c[0], 0)).x + a[0] * sub_group_broadcast(b[0].x, 0); +; CHECK-NEXT: c_next[0].y = ((float2) (c[0], 0)).y + a[0] * sub_group_broadcast(b[0].y, 0); +; CHECK-NEXT: c_next[1].x = ((float2) (c[1], 0)).x + a[0] * sub_group_broadcast(b[0].x, 1); +; CHECK-NEXT: c_next[1].y = ((float2) (c[1], 0)).y + a[0] * sub_group_broadcast(b[0].y, 1); +; CHECK-NEXT: c_next[2].x = ((float2) (c[2], 0)).x + a[0] * sub_group_broadcast(b[0].x, 2); +; CHECK-NEXT: c_next[2].y = ((float2) (c[2], 0)).y + a[0] * sub_group_broadcast(b[0].y, 2); +; CHECK-NEXT: c_next[3].x = ((float2) (c[3], 0)).x + a[0] * sub_group_broadcast(b[0].x, 3); +; CHECK-NEXT: c_next[3].y = ((float2) (c[3], 0)).y + a[0] * sub_group_broadcast(b[0].y, 3); +; CHECK-NEXT: c_next[0].x = c_next[0].x + a[1] * sub_group_broadcast(b[1].x, 0); +; CHECK-NEXT: c_next[0].y = c_next[0].y + a[1] * sub_group_broadcast(b[1].y, 0); +; CHECK-NEXT: c_next[1].x = c_next[1].x + a[1] * sub_group_broadcast(b[1].x, 1); +; CHECK-NEXT: c_next[1].y = c_next[1].y + a[1] * sub_group_broadcast(b[1].y, 1); +; CHECK-NEXT: c_next[2].x = c_next[2].x + a[1] * sub_group_broadcast(b[1].x, 2); +; CHECK-NEXT: c_next[2].y = c_next[2].y + a[1] * sub_group_broadcast(b[1].y, 2); +; CHECK-NEXT: c_next[3].x = c_next[3].x + a[1] * sub_group_broadcast(b[1].x, 3); +; CHECK-NEXT: c_next[3].y = c_next[3].y + a[1] * sub_group_broadcast(b[1].y, 3); +} + +func @coopmatrix_mul_add_cc() subgroup_size(16) { + %a = constant [1.0, 0.0] -> coopmatrix + %b = constant [1.0, 0.0] -> coopmatrix + %c = constant [1.0, 0.0] -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: void coopmatrix_mul_add_cc({{.*}} +; CHECK: float2 c_next[4]; +; CHECK-NEXT: float2 x[4]; +; CHECK-NEXT: c_next[0] = c[0] + a[0] * sub_group_broadcast(b[0].x, 0); +; CHECK-NEXT: x[0] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 0); +; CHECK-NEXT: c_next[1] = c[1] + a[0] * sub_group_broadcast(b[0].x, 1); +; CHECK-NEXT: x[1] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 1); +; CHECK-NEXT: c_next[2] = c[2] + a[0] * sub_group_broadcast(b[0].x, 2); +; CHECK-NEXT: x[2] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 2); +; CHECK-NEXT: c_next[3] = c[3] + a[0] * sub_group_broadcast(b[0].x, 3); +; CHECK-NEXT: x[3] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 3); +; CHECK-NEXT: c_next[0] = c_next[0] + a[1] * sub_group_broadcast(b[1].x, 0); +; CHECK-NEXT: x[0] = x[0] + a[1] * sub_group_broadcast(b[1].y, 0); +; CHECK-NEXT: c_next[1] = c_next[1] + a[1] * sub_group_broadcast(b[1].x, 1); +; CHECK-NEXT: x[1] = x[1] + a[1] * sub_group_broadcast(b[1].y, 1); +; CHECK-NEXT: c_next[2] = c_next[2] + a[1] * sub_group_broadcast(b[1].x, 2); +; CHECK-NEXT: x[2] = x[2] + a[1] * sub_group_broadcast(b[1].y, 2); +; CHECK-NEXT: c_next[3] = c_next[3] + a[1] * sub_group_broadcast(b[1].x, 3); +; CHECK-NEXT: x[3] = x[3] + a[1] * sub_group_broadcast(b[1].y, 3); +; CHECK-NEXT: c_next[0] += (float2) (-x[0].y, x[0].x); +; CHECK-NEXT: c_next[1] += (float2) (-x[1].y, x[1].x); +; CHECK-NEXT: c_next[2] += (float2) (-x[2].y, x[2].x); +; CHECK-NEXT: c_next[3] += (float2) (-x[3].y, x[3].x); +} diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir new file mode 100644 index 00000000..e0988ff5 --- /dev/null +++ b/test/codegen/coopmatrix_store.ir @@ -0,0 +1,55 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_n({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +} + +func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store.checked %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_n_checked({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: long x2 = 64 - x; +; CHECK-NEXT: long x3 = 48 - y; +; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { +; CHECK-NEXT: if (0 >= -y && 0 < x3) { +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: } +; CHECK-NEXT: if (1 >= -y && 1 < x3) { +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: } +; CHECK-NEXT: } +} + +func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store.atomic_add %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_atomic_add({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 64), c0[1], memory_order_relaxed, memory_scope_work_group); +} + +func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store.checked.atomic_add %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_checked_atomic_add({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: long x2 = 64 - x; +; CHECK-NEXT: long x3 = 48 - y; +; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { +; CHECK-NEXT: if (0 >= -y && 0 < x3) { +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: } +; CHECK-NEXT: if (1 >= -y && 1 < x3) { +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 64), c0[1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: } +; CHECK-NEXT: } +} diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 6f57d11b..d58f9060 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -29,14 +29,12 @@ func @for2(%fib: memref) { ; CHECK-LABEL: void for2({{.*}} ; CHECK: long f0 = 0ll; ; CHECK-NEXT: long f1 = 1ll; -; CHECK-NEXT: long fn_2 = f0; -; CHECK-NEXT: long fn_1 = f1; +; CHECK-NEXT: long fn_1 = f0; +; CHECK-NEXT: long fn = f1; ; CHECK-NEXT: for (int n = from; n < to; ++n) { -; CHECK-NEXT: long fn = fn_2 + fn_1; -; CHECK-NEXT: fn_2 = fn_1; +; CHECK-NEXT: long fn1 = fn_1 + fn; ; CHECK-NEXT: fn_1 = fn; +; CHECK-NEXT: fn = fn1; ; CHECK-NEXT: } -; CHECK-NEXT: long fn_11 = fn_2; -; CHECK-NEXT: long fn = fn_1; ; CHECK-NEXT: *fib = fn; } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 824ada77..06ebd842 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -158,8 +158,8 @@ func @region1() { func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref) { %c0 = constant 0 -> i32 - %c3 = constant 3 -> i32 - %c4 = constant 4 -> i32 + %c3 = constant 3 -> index + %c4 = constant 4 -> index parallel { %0 = subgroup_id %1 = cmp.eq %0, %c0 : i32 From 3638962423106b968b4526f0595dc294d1e84d0d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 22 Oct 2024 10:09:59 +0200 Subject: [PATCH 064/297] Fix includes and move a few things around Signed-off-by: Carsten Uphoff --- examples/builder/main.cpp | 1 + src/CMakeLists.txt | 2 + src/binary.hpp | 1 + src/gemm_generator.cpp | 41 +-- src/gemm_generator.hpp | 16 -- src/gemm_tools.cpp | 45 +++ src/gemm_tools.hpp | 31 ++ src/pass/constant_folding.cpp | 259 +++++++++++++++++ ...olding_helper.hpp => constant_folding.hpp} | 30 +- src/pass/constant_propagation.cpp | 265 +----------------- src/pass/convert_to_opencl.cpp | 6 +- src/source.cpp | 1 + src/source.hpp | 2 +- src/tiling.cpp | 2 +- test/generator.cpp | 1 + test/smm.hpp | 2 + tools/argparser/argparser_common.cpp | 1 + tools/argparser/argparser_common.hpp | 2 +- tools/argparser/test.cpp | 2 + tools/offline_compiler/main.cpp | 1 - tools/opt/main.cpp | 1 - 21 files changed, 382 insertions(+), 330 deletions(-) create mode 100644 src/gemm_tools.cpp create mode 100644 src/gemm_tools.hpp create mode 100644 src/pass/constant_folding.cpp rename src/pass/{constant_folding_helper.hpp => constant_folding.hpp} (94%) diff --git a/examples/builder/main.cpp b/examples/builder/main.cpp index 3aeba1ea..0bab3c9b 100644 --- a/examples/builder/main.cpp +++ b/examples/builder/main.cpp @@ -7,6 +7,7 @@ #include #include #include +#include using namespace tinytc; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a8e8d654..c61333ad 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,7 @@ set(SOURCES error.cpp func.cpp gemm_generator.cpp + gemm_tools.cpp inst.cpp location.cpp node/data_type_node.cpp @@ -41,6 +42,7 @@ set(SOURCES parser/parse_context.cpp parser.cpp pass/check_ir.cpp + pass/constant_folding.cpp pass/constant_propagation.cpp pass/convert_to_opencl.cpp pass/dead_code_elimination.cpp diff --git a/src/binary.hpp b/src/binary.hpp index 2b1950f3..ac5d678c 100644 --- a/src/binary.hpp +++ b/src/binary.hpp @@ -4,6 +4,7 @@ #ifndef BINARY_20240308_HPP #define BINARY_20240308_HPP +#include "compiler_context.hpp" #include "reference_counted.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index 7e96ee66..a537ed2e 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -4,6 +4,7 @@ #include "gemm_generator.hpp" #include "codegen_tools.hpp" #include "device_info.hpp" +#include "gemm_tools.hpp" #include "scalar_type.hpp" #include "tiling.hpp" #include "tinytc/tinytc.hpp" @@ -26,6 +27,7 @@ #include #include #include +#include using namespace clir; @@ -84,45 +86,6 @@ std::string gemm_configuration::identifier(std::string_view prefix) const { return oss.str(); } -constexpr static int max_K_unrolling = 8; - -auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, - std::int32_t register_space, - std::pair max_fill_fraction) - -> std::pair { - auto const arithmetic_intensity = [&sgs](std::int32_t row_blocks, std::int32_t cols) { - return (row_blocks * sgs * cols) / static_cast(row_blocks * sgs + cols); - }; - - auto const max_scalars = register_space * max_fill_fraction.first / - (max_fill_fraction.second * C_scalar_type_size_in_bytes); - - // The required number of scalars is given by - // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling - auto const max_row_blocks = [&sgs, &max_scalars](std::int32_t cols) { - return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); - }; - auto const max_cols = [&sgs, &max_scalars](std::int32_t row_blocks) { - return (max_scalars - row_blocks * sgs * max_K_unrolling) / - (row_blocks * sgs + max_K_unrolling); - }; - - double max_ai = 0.0; - std::int32_t row_blocks = 1, cols = 1; - for (std::int32_t r = 1; r <= max_row_blocks(1); ++r) { - for (std::int32_t c = 1; c <= max_cols(r); ++c) { - auto const ai = arithmetic_intensity(r, c); - if (ai > max_ai) { - max_ai = ai; - row_blocks = r; - cols = c; - } - } - } - - return std::make_pair(row_blocks, cols); -} - class generator { public: generator(gemm_configuration const &gemm_cfg, local_tiling const &tiling, diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp index b24e9b45..7a89d7d5 100644 --- a/src/gemm_generator.hpp +++ b/src/gemm_generator.hpp @@ -16,7 +16,6 @@ #include #include #include -#include namespace tinytc { @@ -96,21 +95,6 @@ ::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling cons ::clir::address_space Bs = ::clir::address_space::global_t, ::clir::address_space Cs = ::clir::address_space::global_t); -/** - * @brief Calculate maximum register blocking size of GEMM - * - * @param C_scalar_type_size_in_bytes Size of scalar type of result matrix in bytes - * @param sgs Subgroup size - * @param register_space Size of register file per core in bytes - * @param max_fill_fraction Fraction of register file that shall be blocked at most - * - * @return {number of row-blocks (block size = subgroup size), number of columns} - */ -auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, - std::int32_t register_space, - std::pair max_fill_fraction = { - 1, 2}) -> std::pair; - } // namespace tinytc #endif // GEMM_GENERATOR_20240314_HPP diff --git a/src/gemm_tools.cpp b/src/gemm_tools.cpp new file mode 100644 index 00000000..33e2ceb0 --- /dev/null +++ b/src/gemm_tools.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "gemm_tools.hpp" + +namespace tinytc { + +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, + std::int32_t register_space, + std::pair max_fill_fraction) + -> std::pair { + auto const arithmetic_intensity = [&sgs](std::int32_t row_blocks, std::int32_t cols) { + return (row_blocks * sgs * cols) / static_cast(row_blocks * sgs + cols); + }; + + auto const max_scalars = register_space * max_fill_fraction.first / + (max_fill_fraction.second * C_scalar_type_size_in_bytes); + + // The required number of scalars is given by + // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling + auto const max_row_blocks = [&sgs, &max_scalars](std::int32_t cols) { + return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); + }; + auto const max_cols = [&sgs, &max_scalars](std::int32_t row_blocks) { + return (max_scalars - row_blocks * sgs * max_K_unrolling) / + (row_blocks * sgs + max_K_unrolling); + }; + + double max_ai = 0.0; + std::int32_t row_blocks = 1, cols = 1; + for (std::int32_t r = 1; r <= max_row_blocks(1); ++r) { + for (std::int32_t c = 1; c <= max_cols(r); ++c) { + auto const ai = arithmetic_intensity(r, c); + if (ai > max_ai) { + max_ai = ai; + row_blocks = r; + cols = c; + } + } + } + + return std::make_pair(row_blocks, cols); +} + +} // namespace tinytc diff --git a/src/gemm_tools.hpp b/src/gemm_tools.hpp new file mode 100644 index 00000000..1e03bda5 --- /dev/null +++ b/src/gemm_tools.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_TOOLS_20241022_HPP +#define GEMM_TOOLS_20241022_HPP + +#include +#include + +namespace tinytc { + +constexpr static int max_K_unrolling = 8; + +/** + * @brief Calculate maximum register blocking size of GEMM + * + * @param C_scalar_type_size_in_bytes Size of scalar type of result matrix in bytes + * @param sgs Subgroup size + * @param register_space Size of register file per core in bytes + * @param max_fill_fraction Fraction of register file that shall be blocked at most + * + * @return {number of row-blocks (block size = subgroup size), number of columns} + */ +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, + std::int32_t register_space, + std::pair max_fill_fraction = { + 1, 2}) -> std::pair; + +} // namespace tinytc + +#endif // GEMM_TOOLS_20241022_HPP diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp new file mode 100644 index 00000000..9676321f --- /dev/null +++ b/src/pass/constant_folding.cpp @@ -0,0 +1,259 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_folding.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +template class unary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + unary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(std::int64_t const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::i1: + return computer.template operator()(A); + case scalar_type::i8: + return computer.template operator()(A); + case scalar_type::i16: + return computer.template operator()(A); + case scalar_type::i32: + return computer.template operator()(A); + case scalar_type::i64: + return computer.template operator()(A); + case scalar_type::index: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::f32: + return computer.template operator()(A); + case scalar_type::f64: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A); + case scalar_type::c64: + return computer.template operator()>(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } +}; + +template class binary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + binary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(std::int64_t const &A, std::int64_t const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::i1: + return computer.template operator()(A, B); + case scalar_type::i8: + return computer.template operator()(A, B); + case scalar_type::i16: + return computer.template operator()(A, B); + case scalar_type::i32: + return computer.template operator()(A, B); + case scalar_type::i64: + return computer.template operator()(A, B); + case scalar_type::index: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A, double const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::f32: + return computer.template operator()(A, B); + case scalar_type::f64: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A, B); + case scalar_type::c64: + return computer.template operator()>(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + template auto operator()(T const &, U const &) -> fold_result { + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + } +}; + +constant_folding::constant_folding(bool unsafe_fp_math) : unsafe_fp_math_(unsafe_fp_math) {} + +auto constant_folding::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +auto constant_folding::operator()(inst_node &) -> fold_result { return {}; } + +auto constant_folding::operator()(arith_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + } + at = dyn_cast(ct->ty()); + } + + if (a_const != nullptr && b_const != nullptr) { + auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); + } else if (a_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); + } else if (b_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), b_const->value()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + } + at = dyn_cast(ct->ty()); + } + + auto computer = compute_unary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); +} + +auto constant_folding::operator()(cast_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return tinytc_value_t{}; + } + + auto rt = dyn_cast(in.result(0).ty()); + if (rt == nullptr) { + // Cast on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(in.result(0).ty()); + if (ct == nullptr) { + throw compilation_error(in.result(0).loc(), status::ir_expected_coopmatrix_or_scalar); + } + rt = dyn_cast(ct->ty()); + } + + return std::visit( + overloaded{[&](auto A) -> fold_result { return compute_cast(rt, A, in.loc()); }}, + a_const->value()); +} + +auto constant_folding::operator()(compare_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + if (a_const == nullptr || b_const == nullptr) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + auto computer = compute_compare{in.cond(), in.result(0).ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); +} + +auto constant_folding::operator()(size_inst &in) -> fold_result { + auto ct = get_memref_type(in.operand()); + + auto mode_size = ct->shape(in.mode()); + if (!is_dynamic_value(mode_size)) { + return make_constant( + mode_size, scalar_data_type::get(in.operand().context(), scalar_type::index), in.loc()); + } + + return tinytc_value_t{}; +} + +} // namespace tinytc diff --git a/src/pass/constant_folding_helper.hpp b/src/pass/constant_folding.hpp similarity index 94% rename from src/pass/constant_folding_helper.hpp rename to src/pass/constant_folding.hpp index 4b66ea8a..baee354c 100644 --- a/src/pass/constant_folding_helper.hpp +++ b/src/pass/constant_folding.hpp @@ -4,25 +4,49 @@ #ifndef CONSTANT_FOLDING_HELPER_20241011_HPP #define CONSTANT_FOLDING_HELPER_20241011_HPP +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include #include -#include +#include #include #include namespace tinytc { +using fold_result = std::variant; + +class constant_folding { + public: + constant_folding(bool unsafe_fp_math); + + auto operator()(inst_node &) -> fold_result; + auto operator()(arith_inst &) -> fold_result; + auto operator()(arith_unary_inst &) -> fold_result; + auto operator()(cast_inst &) -> fold_result; + auto operator()(compare_inst &) -> fold_result; + auto operator()(size_inst &in) -> fold_result; + + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; + + bool unsafe_fp_math_; +}; + template struct is_complex : public std::false_type {}; template requires(std::is_floating_point_v) struct is_complex> : public std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; -using fold_result = std::variant; - struct compute_unary_op { arithmetic_unary operation; data_type ty; diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp index bb495bcf..8ae6e5e9 100644 --- a/src/pass/constant_propagation.cpp +++ b/src/pass/constant_propagation.cpp @@ -2,283 +2,20 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/constant_propagation.hpp" -#include "error.hpp" -#include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" -#include "pass/constant_folding_helper.hpp" -#include "scalar_type.hpp" -#include "support/casting.hpp" +#include "pass/constant_folding.hpp" #include "support/ilist.hpp" #include "support/ilist_base.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include -#include -#include -#include #include namespace tinytc { -template class unary_op_dispatcher { - private: - scalar_type switch_ty; - F computer; - - public: - unary_op_dispatcher(scalar_type sw_ty, F &&f) - : switch_ty{sw_ty}, computer{std::forward(f)} {} - - auto operator()(std::int64_t const &A) -> fold_result { - switch (switch_ty) { - case scalar_type::i1: - return computer.template operator()(A); - case scalar_type::i8: - return computer.template operator()(A); - case scalar_type::i16: - return computer.template operator()(A); - case scalar_type::i32: - return computer.template operator()(A); - case scalar_type::i64: - return computer.template operator()(A); - case scalar_type::index: - return computer.template operator()(A); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - }; - } - auto operator()(double const &A) -> fold_result { - switch (switch_ty) { - case scalar_type::f32: - return computer.template operator()(A); - case scalar_type::f64: - return computer.template operator()(A); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - } - } - auto operator()(std::complex const &A) -> fold_result { - switch (switch_ty) { - case scalar_type::c32: - return computer.template operator()>(A); - case scalar_type::c64: - return computer.template operator()>(A); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - } - } -}; - -template class binary_op_dispatcher { - private: - scalar_type switch_ty; - F computer; - - public: - binary_op_dispatcher(scalar_type sw_ty, F &&f) - : switch_ty{sw_ty}, computer{std::forward(f)} {} - - auto operator()(std::int64_t const &A, std::int64_t const &B) -> fold_result { - switch (switch_ty) { - case scalar_type::i1: - return computer.template operator()(A, B); - case scalar_type::i8: - return computer.template operator()(A, B); - case scalar_type::i16: - return computer.template operator()(A, B); - case scalar_type::i32: - return computer.template operator()(A, B); - case scalar_type::i64: - return computer.template operator()(A, B); - case scalar_type::index: - return computer.template operator()(A, B); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - }; - } - auto operator()(double const &A, double const &B) -> fold_result { - switch (switch_ty) { - case scalar_type::f32: - return computer.template operator()(A, B); - case scalar_type::f64: - return computer.template operator()(A, B); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - } - } - auto operator()(std::complex const &A, std::complex const &B) -> fold_result { - switch (switch_ty) { - case scalar_type::c32: - return computer.template operator()>(A, B); - case scalar_type::c64: - return computer.template operator()>(A, B); - default: - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - break; - } - } - template auto operator()(T const &, U const &) -> fold_result { - throw compilation_error(computer.loc, status::ir_scalar_mismatch); - } -}; - -class constant_folding { - public: - constant_folding(bool unsafe_fp_math); - - auto operator()(inst_node &) -> fold_result; - auto operator()(arith_inst &) -> fold_result; - auto operator()(arith_unary_inst &) -> fold_result; - auto operator()(cast_inst &) -> fold_result; - auto operator()(compare_inst &) -> fold_result; - auto operator()(size_inst &in) -> fold_result; - - private: - auto get_memref_type(value_node const &v) const -> const memref_data_type *; - - bool unsafe_fp_math_; -}; - -constant_folding::constant_folding(bool unsafe_fp_math) : unsafe_fp_math_(unsafe_fp_math) {} - -auto constant_folding::get_memref_type(value_node const &v) const -> const memref_data_type * { - auto t = dyn_cast(v.ty()); - if (t == nullptr) { - throw compilation_error(v.loc(), status::ir_expected_memref); - } - return t; -} - -auto constant_folding::operator()(inst_node &) -> fold_result { return {}; } - -auto constant_folding::operator()(arith_inst &in) -> fold_result { - auto &op_a = in.a(); - auto &op_b = in.b(); - - constant_inst *a_const = dyn_cast(op_a.defining_inst()); - constant_inst *b_const = dyn_cast(op_b.defining_inst()); - - auto at = dyn_cast(op_a.ty()); - if (at == nullptr) { - // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all - // elements have the same value. Thus, constant folding on coopmatrix types is identical to - // constant folding on scalar types. - auto ct = dyn_cast(op_a.ty()); - if (ct == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); - } - at = dyn_cast(ct->ty()); - } - - if (a_const != nullptr && b_const != nullptr) { - auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; - auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); - } else if (a_const != nullptr) { - auto computer = - compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, in.loc()}; - auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), a_const->value()); - } else if (b_const != nullptr) { - auto computer = - compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, in.loc()}; - auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), b_const->value()); - } - return tinytc_value_t{}; -} - -auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { - auto &op_a = in.a(); - - constant_inst *a_const = dyn_cast(op_a.defining_inst()); - if (a_const == nullptr) { - return tinytc_value_t{}; - } - - auto at = dyn_cast(op_a.ty()); - if (at == nullptr) { - // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all - // elements have the same value. Thus, constant folding on coopmatrix types is identical to - // constant folding on scalar types. - auto ct = dyn_cast(op_a.ty()); - if (ct == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); - } - at = dyn_cast(ct->ty()); - } - - auto computer = compute_unary_op{in.operation(), op_a.ty(), in.loc()}; - auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), a_const->value()); -} - -auto constant_folding::operator()(cast_inst &in) -> fold_result { - auto &op_a = in.a(); - - constant_inst *a_const = dyn_cast(op_a.defining_inst()); - if (a_const == nullptr) { - return tinytc_value_t{}; - } - - auto rt = dyn_cast(in.result(0).ty()); - if (rt == nullptr) { - // Cast on coopmatrix is component-wise and if a coopmatrix is constant, then all - // elements have the same value. Thus, constant folding on coopmatrix types is identical to - // constant folding on scalar types. - auto ct = dyn_cast(in.result(0).ty()); - if (ct == nullptr) { - throw compilation_error(in.result(0).loc(), status::ir_expected_coopmatrix_or_scalar); - } - rt = dyn_cast(ct->ty()); - } - - return std::visit( - overloaded{[&](auto A) -> fold_result { return compute_cast(rt, A, in.loc()); }}, - a_const->value()); -} - -auto constant_folding::operator()(compare_inst &in) -> fold_result { - auto &op_a = in.a(); - auto &op_b = in.b(); - - constant_inst *a_const = dyn_cast(op_a.defining_inst()); - constant_inst *b_const = dyn_cast(op_b.defining_inst()); - if (a_const == nullptr || b_const == nullptr) { - return tinytc_value_t{}; - } - - auto at = dyn_cast(op_a.ty()); - if (at == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_scalar); - } - - auto computer = compute_compare{in.cond(), in.result(0).ty(), in.loc()}; - auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; - return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); -} - -auto constant_folding::operator()(size_inst &in) -> fold_result { - auto ct = get_memref_type(in.operand()); - - auto mode_size = ct->shape(in.mode()); - if (!is_dynamic_value(mode_size)) { - return make_constant( - mode_size, scalar_data_type::get(in.operand().context(), scalar_type::index), in.loc()); - } - - return tinytc_value_t{}; -} - void constant_propagation_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } void constant_propagation_pass::run_on_region(region_node ®) { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index bb3f023c..0c6730b5 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -681,9 +681,9 @@ convert_to_opencl_pass::operator()(cooperative_matrix_mul_add_inst const &c) { auto const add = [&](auto a_ty, auto b_ty, auto c_ty, auto a, auto b, auto c, auto c_next) { if (a_ty == b_ty && b_ty == c_ty) { - clinst.emplace_back(expression_statement( - assignment(std::move(c_next), - fma(std::move(a), std::move(b), std::move(c))))); + clinst.emplace_back(expression_statement(assignment( + std::move(c_next), + clir::fma(std::move(a), std::move(b), std::move(c))))); } else { clinst.emplace_back(expression_statement( assignment(std::move(c_next), diff --git a/src/source.cpp b/src/source.cpp index f4f9fca6..fcea0153 100644 --- a/src/source.cpp +++ b/src/source.cpp @@ -6,6 +6,7 @@ #include "tinytc/tinytc.h" #include +#include using namespace tinytc; diff --git a/src/source.hpp b/src/source.hpp index ade28c25..13e4797a 100644 --- a/src/source.hpp +++ b/src/source.hpp @@ -4,13 +4,13 @@ #ifndef SOURCE_20240412_HPP #define SOURCE_20240412_HPP +#include "compiler_context.hpp" #include "reference_counted.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include #include -#include #include struct tinytc_source : tinytc::reference_counted { diff --git a/src/tiling.cpp b/src/tiling.cpp index 6d33e14f..0589a429 100644 --- a/src/tiling.cpp +++ b/src/tiling.cpp @@ -3,7 +3,7 @@ #include "tiling.hpp" #include "device_info.hpp" -#include "gemm_generator.hpp" +#include "gemm_tools.hpp" #include "tinytc/tinytc.hpp" #include diff --git a/test/generator.cpp b/test/generator.cpp index 8b88d2dc..e5ba8887 100644 --- a/test/generator.cpp +++ b/test/generator.cpp @@ -3,6 +3,7 @@ #include "device_info.hpp" #include "gemm_generator.hpp" +#include "gemm_tools.hpp" #include "reference_counted.hpp" #include "scalar_type.hpp" #include "support/util.hpp" diff --git a/test/smm.hpp b/test/smm.hpp index 717b4a5f..83ca2f55 100644 --- a/test/smm.hpp +++ b/test/smm.hpp @@ -10,12 +10,14 @@ #include +#include #include #include #include #include #include #include +#include #include #define DOCTEST_TENSOR4_TEST(MM, NN, KK, HH) \ diff --git a/tools/argparser/argparser_common.cpp b/tools/argparser/argparser_common.cpp index 00bbd543..f18e64f4 100644 --- a/tools/argparser/argparser_common.cpp +++ b/tools/argparser/argparser_common.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "argparser_common.hpp" +#include "argparser.hpp" #include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" diff --git a/tools/argparser/argparser_common.hpp b/tools/argparser/argparser_common.hpp index 24462cb9..5ebe42a4 100644 --- a/tools/argparser/argparser_common.hpp +++ b/tools/argparser/argparser_common.hpp @@ -4,9 +4,9 @@ #ifndef ARGPARSER_COMMON_20241010_HPP #define ARGPARSER_COMMON_20241010_HPP -#include "argparser.hpp" #include "tinytc/types.hpp" +#include #include #include #include diff --git a/tools/argparser/test.cpp b/tools/argparser/test.cpp index c73dd071..6a722b1b 100644 --- a/tools/argparser/test.cpp +++ b/tools/argparser/test.cpp @@ -3,7 +3,9 @@ #include "argparser.hpp" +#include #include +#include using namespace tinytc; diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index cacc816f..2a8c5cd0 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index dc6157c5..aa9d35bc 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include using namespace tinytc; From 6d8b1cb194c47dedcb816347c26b722760fe834d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 23 Oct 2024 09:41:26 +0200 Subject: [PATCH 065/297] Add checked flag and lower gemm Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.yaml | 2 + docs/api/builder_cxxapi.yaml | 1 + docs/manual/tensor-ir.rst | 42 +++++-- include/tinytc/tinytc.h | 18 +-- include/tinytc/tinytc.hpp | 37 +++--- include/tinytc/types.h | 8 ++ include/tinytc/types.hpp | 8 ++ src/codegen_tools.cpp | 129 +++++++++++++++------ src/codegen_tools.hpp | 5 + src/gemm_generator.cpp | 3 +- src/gemm_tools.cpp | 70 +++++++++--- src/gemm_tools.hpp | 15 ++- src/inst.cpp | 59 +++++++--- src/node/inst_node.cpp | 17 ++- src/node/inst_node.hpp | 19 ++-- src/parser/lexer.re | 6 +- src/parser/parser_impl.yy | 8 +- src/pass/constant_folding.cpp | 2 +- src/pass/convert_to_opencl.cpp | 62 +++++----- src/pass/dump_ir.cpp | 8 +- src/pass/lower_linalg.cpp | 187 +++++++++++++++++++++++++++++-- src/tiling.cpp | 6 +- test/codegen/coopmatrix_load.ir | 132 +++++++++++++--------- test/codegen/coopmatrix_store.ir | 32 +++++- test/generator.cpp | 26 ++++- 25 files changed, 685 insertions(+), 217 deletions(-) diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 6d99df31..5aaa570c 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -6,6 +6,7 @@ Builder C-API: - tinytc_address_space_t - tinytc_arithmetic_t - tinytc_arithmetic_unary_t + - tinytc_checked_flag_t - tinytc_cmp_condition_t - tinytc_matrix_use_t - tinytc_scalar_type_t @@ -17,6 +18,7 @@ Builder C-API: - tinytc_address_space_to_string - tinytc_arithmetic_to_string - tinytc_arithmetic_unary_to_string + - tinytc_checked_flag_to_string - tinytc_cmp_condition_to_string - tinytc_matrix_use_to_string - tinytc_scalar_type_size diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index f2c07869..0797a739 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -16,6 +16,7 @@ Builder C++-API: - tinytc::to_string(address_space) - tinytc::to_string(arithmetic) - tinytc::to_string(arithmetic_unary) + - tinytc::to_string(checked_flag) - tinytc::to_string(cmp_condition) - tinytc::to_string(matrix_use) - tinytc::to_string(scalar_type) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 8a1a89dd..67af2edb 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -28,6 +28,9 @@ An SPMD instruction follows the OpenCL execution model, where local variables ma for each work-item. Mixed instructions accept both varying and uniform local variables. +In an SPMD region, we call an argument *dynamically uniform* if all work-items in a subgroup have +the same value. + Regions come in two different kinds: collective and SPMD. A collective instructions must only appear in a collective region, and an SPMD instruction must only appear in a SPMD region. Mixed instructions might appear in both kinds of regions. @@ -271,6 +274,8 @@ position of the matrix in a matrix multiplication. Not all matrix shapes need to be supported in the implementation. The supported matrix shapes may depend on data type, matrix use, and target hardware. +An argument to any instruction that has coopmatrix type **must** be dynamically uniform. + Instructions ============ @@ -441,7 +446,7 @@ GEMV implements the well-known GEMM BLAS-2 operation. .. math:: - c := \alpha \text{op}_1(A) b + \beta C + c := \alpha \text{op}_1(A) b + \beta c If the atomic flag is set, c is updated atomically. @@ -774,9 +779,10 @@ Cooperative matrix load .. code:: abnf - value-instruction =/ "cooperative_matrix_load" transpose [".checked"] + value-instruction =/ "cooperative_matrix_load" transpose checked-flag local-identifier "[" local-identifier "," local-identifier "]" ":" memref-type "->" coopmatrix-type + checked-flag = ".rows_checked" / ".cols_checked" / ".both_checked" Overview ~~~~~~~~ @@ -790,15 +796,25 @@ position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + i) S_1 + (y + j) S_2] -When the checked flag is set, memory loads that would be out of bounds are not executed and the corresponding -value in the cooperative matrix are set to 0. - When the transpose modifier ".t" is given, we have .. math:: \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + j) S_1 + (y + i) S_2] +When the checked flag is set, the following out-of-bound checks are added: + +=============== ======================================================================================================= +Flag Description +=============== ======================================================================================================= +.rows_checked.n :math:`A_{ij} := M[...] \text{ if } 0 \leq x+i < X \text{ else } 0` +.rows_checked.t :math:`A_{ij} := M[...] \text{ if } 0 \leq y+i < Y \text{ else } 0` +.cols_checked.n :math:`A_{ij} := M[...] \text{ if } 0 \leq y+j < Y \text{ else } 0` +.cols_checked.t :math:`A_{ij} := M[...] \text{ if } 0 \leq x+j < X \text{ else } 0` +.both_checked.n .rows_checked.n + .cols_checked.n +.both_checked.t .rows_checked.t + .cols_checked.t +=============== ======================================================================================================= + Arguments ~~~~~~~~~ @@ -806,6 +822,8 @@ The first operand must have memref type of dimension 2 with the same component t as the coopmatrix type. The indices must be of ``index`` type. +All arguments **must** be dynamically uniform. + Cooperative matrix mul add .......................... @@ -853,7 +871,7 @@ Cooperative matrix store .. code:: abnf - instruction =/ "cooperative_matrix_store" [".checked"] [store-flag] + instruction =/ "cooperative_matrix_store" checked-flag [store-flag] local-identifier "," local-identifier "[" local-identifier "," local-identifier "]" ":" coopmatrix-type "," memref-type @@ -869,7 +887,15 @@ position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are \forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2] := A_{ij} -If the checked flag is set, only memory locations that are in-bounds are written. +When the checked flag is set, the following out-of-bound checks are added: + +============= ======================================================================================================= +Flag Description +============= ======================================================================================================= +.rows_checked Only execute store if :math:`0 \leq x+i < X` +.cols_checked Only execute store if :math:`0 \leq y+j < Y` +.both_checked .rows_checked + .cols_checked +============= ======================================================================================================= The store is atomic when the atomic flag is set with relaxed memory ordering. When the atomic_add flag is set, the coopmatrix is added to the memref atomically. @@ -883,6 +909,8 @@ Arguments The first operand must have cooperative matrix type with the same component type as the memref type. The indices must be of ``index`` type. +All arguments **must** be dynamically uniform. + Expand ...... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 139b5a8e..37e86284 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -170,6 +170,8 @@ TINYTC_EXPORT char const *tinytc_address_space_to_string(tinytc_address_space_t TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); //! Convert arithmetic operation type to string (unary) TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); +//! Convert checked flag to string +TINYTC_EXPORT char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag); //! Convert cmp condition to string TINYTC_EXPORT char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond); //! Convert matrix use to string @@ -325,7 +327,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *in * * @param instr [out] pointer to the inst object created * @param transpose [in] transpose operation applied on load - * @param checked [in] true for out-of-bounds checks + * @param flag [in] out-of-bounds checks type * @param op [in] %op * @param p0 [in] %p0 * @param p1 [in] %p1 @@ -335,8 +337,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *in * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_load_inst_create( - tinytc_inst_t *instr, tinytc_transpose_t transpose, tinytc_bool_t checked, tinytc_value_t op, - tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, const tinytc_location_t *loc); + tinytc_inst_t *instr, tinytc_transpose_t transpose, tinytc_checked_flag_t flag, + tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, + const tinytc_location_t *loc); /** * @brief Create cooperative matrix mul add instruction @@ -378,8 +381,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( * @endcode * * @param instr [out] pointer to the inst object created - * @param checked [in] true for out-of-bounds checks - * @param flag [in] store flag + * @param cflag [in] out-of-bounds checks type + * @param sflag [in] store flag * @param val [in] %val * @param op [in] %op * @param p0 [in] %p0 @@ -389,8 +392,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_store_inst_create( - tinytc_inst_t *instr, tinytc_bool_t checked, tinytc_store_flag_t flag, tinytc_value_t val, - tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, const tinytc_location_t *loc); + tinytc_inst_t *instr, tinytc_checked_flag_t cflag, tinytc_store_flag_t sflag, + tinytc_value_t val, tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, + const tinytc_location_t *loc); /** * @brief Create alloca instruction diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 959527a4..816a4536 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -693,6 +693,17 @@ inline char const *to_string(arithmetic_unary op) { return ::tinytc_arithmetic_unary_to_string(static_cast<::tinytc_arithmetic_unary_t>(op)); } +/** + * @brief Convert checked flag string + * + * @param flag Flag + * + * @return C-string + */ +inline char const *to_string(checked_flag flag) { + return ::tinytc_checked_flag_to_string(static_cast<::tinytc_checked_flag_t>(flag)); +} + /** * @brief Convert cmp condition to string * @@ -983,7 +994,7 @@ inline inst make_constant_zero(data_type ty, location const &loc = {}) { * @brief Create cooperative matrix load instruction * * @param trans transpose operation applied on load - * @param checked true for out-of-bounds checks + * @param flag out-of-bounds checks type * @param op %op * @param p0 %p0 * @param p1 %p1 @@ -992,13 +1003,13 @@ inline inst make_constant_zero(data_type ty, location const &loc = {}) { * * @return Instruction */ -inline inst make_cooperative_matrix_load(transpose trans, bool checked, value op, value p0, +inline inst make_cooperative_matrix_load(transpose trans, checked_flag flag, value op, value p0, value p1, data_type to_ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_cooperative_matrix_load_inst_create(&instr, static_cast(trans), - checked, op, p0, p1, to_ty, &loc), - loc); + CHECK_STATUS_LOC(tinytc_cooperative_matrix_load_inst_create( + &instr, static_cast(trans), + static_cast(flag), op, p0, p1, to_ty, &loc), + loc); return inst(instr); } @@ -1039,8 +1050,8 @@ inline inst make_cooperative_matrix_scale(value a, value b, location const &loc /** * @brief Create cooperative matrix store instruction * - * @param checked true for out-of-bounds checks - * @param flag store flag + * @param cflag out-of-bounds checks type + * @param sflag store flag * @param val %val * @param op %op * @param p0 %p0 @@ -1049,13 +1060,13 @@ inline inst make_cooperative_matrix_scale(value a, value b, location const &loc * * @return Instruction */ -inline inst make_cooperative_matrix_store(bool checked, store_flag flag, value val, value op, +inline inst make_cooperative_matrix_store(checked_flag cflag, store_flag sflag, value val, value op, value p0, value p1, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_cooperative_matrix_store_inst_create( - &instr, checked, static_cast(flag), val, op, p0, p1, &loc), - loc); + CHECK_STATUS_LOC(tinytc_cooperative_matrix_store_inst_create( + &instr, static_cast(cflag), + static_cast(sflag), val, op, p0, p1, &loc), + loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 68070f69..a084cabe 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -292,6 +292,14 @@ typedef enum { tinytc_address_space_local = 0x2 ///< Local memory, returned by alloca } tinytc_address_space_t; +//! Checked flag +typedef enum { + tinytc_checked_flag_none = 0, ///< Perform no checks + tinytc_checked_flag_rows = 1, ///< Check for out-of-bound rows + tinytc_checked_flag_cols = 2, ///< Check for out-of-bound cols + tinytc_checked_flag_both = 3 ///< Check for out-of-bound rows and cols +} tinytc_checked_flag_t; + //! Store flag typedef enum { tinytc_store_flag_regular = 0, ///< Non-atomic store diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 6f7a4ad9..b4c99b5e 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -272,6 +272,14 @@ enum class address_space { local = tinytc_address_space_local ///< Local memory, returned by alloca }; +//! Checked flag +enum class checked_flag { + none = tinytc_checked_flag_none, ///< Perform no checks + rows = tinytc_checked_flag_rows, ///< Check for out-of-bound rows + cols = tinytc_checked_flag_cols, ///< Check for out-of-bound cols + both = tinytc_checked_flag_both ///< Check for out-of-bound rows and cols +}; + //! Store flag enum class store_flag { regular = tinytc_store_flag_regular, ///< Non-atomic store diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 68ead93f..8f066d4a 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -6,6 +6,7 @@ #include "node/data_type_node.hpp" #include "node/inst_node.hpp" #include "node/value_node.hpp" +#include "pass/constant_folding.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" #include "support/visit.hpp" @@ -452,23 +453,26 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in auto c0 = bb.add(make_constant(0, index_ty)); auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); - auto blocks = bb.add(make_arith(arithmetic::div, loop_trip_count, c_sgs)); - auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, c_sgs)); + auto blocks = + instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, c_sgs)); + auto rem = instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, c_sgs)); - auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); - auto is_blocks_gt_0 = bb.add(make_cmp(cmp_condition::gt, blocks, c0)); + auto sg_id_index = instant_constant_fold_add(bb, make_cast(sg_id, index_ty)); + auto is_blocks_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, blocks, c0)); bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { - auto block_start = bb.add(make_arith(arithmetic::mul, c_sgs, sg_id_index)); - auto block_end = bb.add(make_arith(arithmetic::mul, c_sgs, blocks)); + auto block_start = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_index)); + auto block_end = instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, blocks)); bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); }); - auto condition0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); + auto condition0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { - auto condition1 = bb.add(make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); + auto condition1 = + instant_constant_fold_add(bb, make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); bb.if_condition(condition1, [&](region_builder &bb) { - auto block = bb.add(make_arith(arithmetic::mul, blocks, c_sgs)); + auto block = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks, c_sgs)); body(bb, block, true, rem); }); }); @@ -483,9 +487,10 @@ void tile_loop_by_sgs_standard(region_builder &bb, value loop_trip_count, int sg tile_loop_by_sgs_new( bb, loop_trip_count, sgs, num_tiles, sg_id, [&m_index, &body](region_builder &bb, value block, bool is_remainder, value trip_count) { - auto mm = bb.add(make_arith(arithmetic::add, block, m_index)); + auto mm = instant_constant_fold_add(bb, make_arith(arithmetic::add, block, m_index)); if (is_remainder) { - auto cond = bb.add(make_cmp(cmp_condition::lt, m_index, trip_count)); + auto cond = + instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, m_index, trip_count)); bb.if_condition(cond, [&](region_builder &bb) { body(bb, mm); }); } else { body(bb, mm); @@ -505,36 +510,38 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * num_tiles auto c_block_size = bb.add(make_constant(block_size, index_ty)); - auto blocks0 = bb.add(make_arith(arithmetic::sub, loop_trip_count, c1)); - auto blocks1 = bb.add(make_arith(arithmetic::div, blocks0, c_block_size)); - auto blocks2 = bb.add(make_arith(arithmetic::div, blocks1, c_tiles)); - auto blocks3 = bb.add(make_arith(arithmetic::add, c1, blocks2)); - auto blocks = bb.add(make_arith(arithmetic::mul, blocks3, c_tiles)); - - auto bs = bb.add(make_arith(arithmetic::div, loop_trip_count, blocks)); - auto bs_1 = bb.add(make_arith(arithmetic::add, bs, c1)); - auto rem = bb.add(make_arith(arithmetic::rem, loop_trip_count, blocks)); - - auto sg_id_index = bb.add(make_cast(sg_id, index_ty)); + auto blocks0 = instant_constant_fold_add(bb, make_arith(arithmetic::sub, loop_trip_count, c1)); + auto blocks1 = + instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks0, c_block_size)); + auto blocks2 = instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks1, c_tiles)); + auto blocks3 = instant_constant_fold_add(bb, make_arith(arithmetic::add, c1, blocks2)); + auto blocks = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks3, c_tiles)); + + auto bs = instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, blocks)); + auto bs_1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, bs, c1)); + auto rem = instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, blocks)); + + auto sg_id_index = instant_constant_fold_add(bb, make_cast(sg_id, index_ty)); // The following if makes it easy to eliminate the remainder handler in optimization if rem == 0 // is known at compile time. Without the if, we would need to prove that block_start_1 is // non-negative to eliminate the for-loop. - auto is_rem_gt_0 = bb.add(make_cmp(cmp_condition::gt, rem, c0)); + auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { - auto block_start_1 = bb.add(make_arith(arithmetic::mul, bs_1, sg_id_index)); - auto block_end_1 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); - auto step_1 = bb.add(make_arith(arithmetic::mul, bs_1, c_tiles)); + auto block_start_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, sg_id_index)); + auto block_end_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem)); + auto step_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, c_tiles)); bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, [&](region_builder &bb, value block) { body(bb, block, bs_1); }); }); - auto tmp0 = bb.add(make_arith(arithmetic::rem, rem, c_tiles)); - auto tmp1 = bb.add(make_arith(arithmetic::add, sg_id_index, tmp0)); - auto sg_id_1 = bb.add(make_arith(arithmetic::rem, tmp1, c_tiles)); - auto tmp2 = bb.add(make_arith(arithmetic::mul, bs, sg_id_1)); - auto tmp3 = bb.add(make_arith(arithmetic::mul, bs_1, rem)); - auto block_start = bb.add(make_arith(arithmetic::add, tmp3, tmp2)); - auto step = bb.add(make_arith(arithmetic::mul, bs, c_tiles)); + auto tmp0 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, rem, c_tiles)); + auto tmp1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, sg_id_index, tmp0)); + auto sg_id_1 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, tmp1, c_tiles)); + auto tmp2 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, sg_id_1)); + auto tmp3 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem)); + auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::add, tmp3, tmp2)); + auto step = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, c_tiles)); bb.for_loop(std::move(block_start), loop_trip_count, std::move(step), index_ty, [&](region_builder &bb, value block) { body(bb, block, bs); }); } @@ -557,7 +564,29 @@ auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value b = bb.add(make_cast(b, compatible_ty, loc)); } } - return bb.add(make_arith(operation, a, b)); + return bb.add(make_arith(operation, a, b, loc)); +} +auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, + location const &loc) -> value { + scalar_data_type *at = dyn_cast(a->ty()); + if (at == nullptr) { + throw compilation_error(loc, status::ir_expected_scalar); + } + coopmatrix_data_type *bt = dyn_cast(b->ty()); + if (bt == nullptr) { + throw compilation_error(loc, status::ir_expected_coopmatrix); + } + const auto a_ty = at->ty(); + const auto b_ty = bt->component_ty(); + if (a_ty != b_ty) { + const auto compatible_scalar_ty = compatible_type(a_ty, b_ty); + + if (a_ty != compatible_type(a_ty, b_ty)) { + auto compatible_ty = scalar_data_type::get(at->context(), compatible_scalar_ty); + a = bb.add(make_cast(a, compatible_ty, loc)); + } + } + return bb.add(make_cooperative_matrix_scale(a, b, loc)); } auto get_atomic_store_flag(value beta) -> std::optional { @@ -589,4 +618,36 @@ void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value b } } +auto instant_constant_fold_add(region_builder &bb, inst i) -> value { + auto ctx = i->context(); + if (!ctx) { + throw compilation_error(i->loc(), status::internal_compiler_error); + } + + auto fold = visit(constant_folding{ctx->opt_flag(optflag::unsafe_fp_math)}, *i); + auto val = std::visit(overloaded{[](tinytc_value_t &v) -> tinytc_value_t { return v; }, + [&bb](inst &j) -> tinytc_value_t { + if (j) { + return bb.add(std::move(j)); + } + return nullptr; + }}, + fold); + if (val) { + return val; + } + return bb.add(std::move(i)); +} + +auto get_int_constant(tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto *ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci->value())) { + return std::get(ci->value()); + } + } + } + return std::nullopt; +} + } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index a49db00f..8ede8a62 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -143,11 +143,16 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value a, value b, location const &loc) -> value; +auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, + location const &loc) -> value; auto get_atomic_store_flag(value beta) -> std::optional; void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value beta, value C, array_view index_list, location const &loc); +auto instant_constant_fold_add(region_builder &bb, inst i) -> value; +auto get_int_constant(tinytc_value_t val) -> std::optional; + } // namespace tinytc #endif // CODEGEN_TOOLS_20240229_HPP diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp index a537ed2e..e6c8608c 100644 --- a/src/gemm_generator.cpp +++ b/src/gemm_generator.cpp @@ -299,8 +299,9 @@ void generator::add_function_body(block_builder &bb, var A, var B, var C, expr a // available for one of the buffers register_space /= 2; } - auto [max_row_blocks, max_cols] = + auto [max_rows, max_cols] = max_register_block_gemm(size(gemm_cfg.ty.C), core_cfg.subgroup_size, register_space); + const auto max_row_blocks = max_rows / core_cfg.subgroup_size; row_blocks_in_register = max_row_blocks; cols_in_register = max_cols; if (!is_dynamic_value(gemm_cfg.M)) { diff --git a/src/gemm_tools.cpp b/src/gemm_tools.cpp index 33e2ceb0..c9b73793 100644 --- a/src/gemm_tools.cpp +++ b/src/gemm_tools.cpp @@ -5,41 +5,83 @@ namespace tinytc { -auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, - std::int32_t register_space, +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t subgroup_size, + std::int32_t register_space, std::int32_t C_blocks, std::pair max_fill_fraction) -> std::pair { - auto const arithmetic_intensity = [&sgs](std::int32_t row_blocks, std::int32_t cols) { - return (row_blocks * sgs * cols) / static_cast(row_blocks * sgs + cols); + auto const arithmetic_intensity = [](std::int32_t rows, std::int32_t cols) { + return (rows * cols) / static_cast(rows + cols); }; auto const max_scalars = register_space * max_fill_fraction.first / (max_fill_fraction.second * C_scalar_type_size_in_bytes); // The required number of scalars is given by - // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling - auto const max_row_blocks = [&sgs, &max_scalars](std::int32_t cols) { - return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); + // num_scalars = rows * (cols * C_blocks + max_K_unrolling) + cols * max_K_unrolling. + // Thus + // rows <= (max_scalars - cols * max_K_unrolling) / (cols * C_blocks + max_K_unrolling). + // Moreover, we require rows % subgroup_size = 0, so we set rows = k * subgroup_size and get + // k <= (max_scalars - cols * max_K_unrolling) / (subgroup_size * (cols * C_blocks + + // max_K_unrolling)). + auto const max_rows = [&subgroup_size, &max_scalars, &C_blocks](std::int32_t cols) { + const auto k = (max_scalars - cols * max_K_unrolling) / + (subgroup_size * (cols * C_blocks + max_K_unrolling)); + return k * subgroup_size; }; - auto const max_cols = [&sgs, &max_scalars](std::int32_t row_blocks) { - return (max_scalars - row_blocks * sgs * max_K_unrolling) / - (row_blocks * sgs + max_K_unrolling); + // Here, we have + // cols <= (num_scalars - rows * max_K_unrolling) / (rows * C_blocks + max_K_unrolling). + auto const max_cols = [&max_scalars, &C_blocks](std::int32_t rows) { + return (max_scalars - rows * max_K_unrolling) / (rows * C_blocks + max_K_unrolling); }; double max_ai = 0.0; - std::int32_t row_blocks = 1, cols = 1; - for (std::int32_t r = 1; r <= max_row_blocks(1); ++r) { + std::int32_t rows = subgroup_size, cols = 1; + for (std::int32_t r = subgroup_size; r <= max_rows(1); r += subgroup_size) { for (std::int32_t c = 1; c <= max_cols(r); ++c) { auto const ai = arithmetic_intensity(r, c); if (ai > max_ai) { max_ai = ai; - row_blocks = r; + rows = r; cols = c; } } } - return std::make_pair(row_blocks, cols); + return std::make_pair(rows, cols); +} + +// We have block_size(k) = k * subgroup_size, where k is a positive integer, +// and num_blocks(k) = ceil(size / block_size(k)) +// We want to solve +// max_k block_size(k) s.t. +// block_size(k) <= max_block_size ; must not exceed max block size +// and num_blocks(k) % num_tiles == 0 ; no load imbalance +// and block_size(k) - size < sgs ; no excessive block size +// +// If the above optimization does not have a solution, the minimum block size (= subgroup size is +// returned) +// +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t { + auto const block_size = [&subgroup_size](std::int32_t k) -> std::int32_t { + return k * subgroup_size; + }; + auto const num_blocks = [&block_size, &size](std::int32_t k) -> std::int64_t { + return 1 + (size - 1) / block_size(k); + }; + std::int32_t k = max_block_size / subgroup_size; + while (k > 1 && (num_blocks(k) % num_tiles != 0 || block_size(k) - size >= subgroup_size)) { + --k; + } + return k * subgroup_size; +} + +auto compute_k_block_size(std::int64_t K) -> std::int32_t { + auto k_block_size = max_K_unrolling; + while (K < k_block_size && k_block_size > 1) { + k_block_size /= 2; + } + return k_block_size; } } // namespace tinytc diff --git a/src/gemm_tools.hpp b/src/gemm_tools.hpp index 1e03bda5..4e8c204b 100644 --- a/src/gemm_tools.hpp +++ b/src/gemm_tools.hpp @@ -9,23 +9,28 @@ namespace tinytc { -constexpr static int max_K_unrolling = 8; +constexpr static std::int32_t max_K_unrolling = 8; /** * @brief Calculate maximum register blocking size of GEMM * * @param C_scalar_type_size_in_bytes Size of scalar type of result matrix in bytes - * @param sgs Subgroup size + * @param subgroup_size Subgroup size * @param register_space Size of register file per core in bytes + * @param C_blocks Number of register blocks needed for C, usually 1, 2 for complex * @param max_fill_fraction Fraction of register file that shall be blocked at most * - * @return {number of row-blocks (block size = subgroup size), number of columns} + * @return {number of rows, number of columns} */ -auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t sgs, - std::int32_t register_space, +auto max_register_block_gemm(std::int32_t C_scalar_type_size_in_bytes, std::int32_t subgroup_size, + std::int32_t register_space, std::int32_t C_blocks = 1, std::pair max_fill_fraction = { 1, 2}) -> std::pair; +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t; +auto compute_k_block_size(std::int64_t K) -> std::int32_t; + } // namespace tinytc #endif // GEMM_TOOLS_20241022_HPP diff --git a/src/inst.cpp b/src/inst.cpp index 4d17e713..1712ac77 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -80,6 +80,20 @@ char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op) { return "unknown"; } +char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag) { + switch (flag) { + case tinytc_checked_flag_none: + return ""; + case tinytc_checked_flag_rows: + return "rows_checked"; + case tinytc_checked_flag_cols: + return "cols_checked"; + case tinytc_checked_flag_both: + return "both_checked"; + } + return "unknown"; +} + char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond) { switch (cond) { case tinytc_cmp_condition_eq: @@ -204,12 +218,18 @@ tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_dat if (instr == nullptr) { return tinytc_status_invalid_arguments; } - const auto *st = dyn_cast(ty); - if (st == nullptr) { + + scalar_type sty; + if (const auto *st = dyn_cast(ty); st != nullptr) { + sty = st->ty(); + } else if (const auto *ct = dyn_cast(ty); ct != nullptr) { + sty = ct->component_ty(); + } else { return tinytc_status_invalid_arguments; } + return exception_to_status_code([&] { - switch (st->ty()) { + switch (sty) { case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: @@ -237,12 +257,18 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da if (instr == nullptr) { return tinytc_status_invalid_arguments; } - const auto *st = dyn_cast(ty); - if (st == nullptr) { + + scalar_type sty; + if (const auto *st = dyn_cast(ty); st != nullptr) { + sty = st->ty(); + } else if (const auto *ct = dyn_cast(ty); ct != nullptr) { + sty = ct->component_ty(); + } else { return tinytc_status_invalid_arguments; } + return exception_to_status_code([&] { - switch (st->ty()) { + switch (sty) { case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: @@ -266,14 +292,15 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da } tinytc_status_t tinytc_cooperative_matrix_load_inst_create( - tinytc_inst_t *instr, tinytc_transpose_t trans, tinytc_bool_t checked, tinytc_value_t op, + tinytc_inst_t *instr, tinytc_transpose_t trans, tinytc_checked_flag_t flag, tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, const tinytc_location_t *loc) { if (instr == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr || to_ty == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique( - enum_cast(trans), checked, op, p0, p1, to_ty, get_optional(loc)) + *instr = std::make_unique(enum_cast(trans), + enum_cast(flag), op, + p0, p1, to_ty, get_optional(loc)) .release(); }); } @@ -304,15 +331,19 @@ tinytc_status_t tinytc_cooperative_matrix_scale_inst_create(tinytc_inst_t *instr }); } -tinytc_status_t tinytc_cooperative_matrix_store_inst_create( - tinytc_inst_t *instr, tinytc_bool_t checked, tinytc_store_flag_t flag, tinytc_value_t val, - tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, const tinytc_location_t *loc) { +tinytc_status_t tinytc_cooperative_matrix_store_inst_create(tinytc_inst_t *instr, + tinytc_checked_flag_t cflag, + tinytc_store_flag_t sflag, + tinytc_value_t val, tinytc_value_t op, + tinytc_value_t p0, tinytc_value_t p1, + const tinytc_location_t *loc) { if (instr == nullptr || val == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique( - checked, enum_cast(flag), val, op, p0, p1, get_optional(loc)) + *instr = std::make_unique(enum_cast(cflag), + enum_cast(sflag), val, + op, p0, p1, get_optional(loc)) .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 3707ccf6..6fc66f09 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -16,6 +16,15 @@ #include #include +auto tinytc_inst::context() const -> tinytc_compiler_context_t { + if (num_results() > 0) { + return result(0).context(); + } else if (num_operands() > 0) { + return op(0).context(); + } + return nullptr; +} + namespace tinytc { scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { @@ -376,12 +385,12 @@ auto constant_inst::is_identity() const -> bool { return std::visit([](auto const &v) { return v == decltype(v){1}; }, value_); } -cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, bool checked, +cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, checked_flag flag, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, location const &lc) - : standard_inst{IK::cooperative_matrix_load}, t_(t), checked_(checked) { + : standard_inst{IK::cooperative_matrix_load}, t_(t), flag_(flag) { op(op_operand, op0); op(op_pos0, p0); op(op_pos1, p1); @@ -471,11 +480,11 @@ cooperative_matrix_scale_inst::cooperative_matrix_scale_inst(tinytc_value_t a0, result(0) = value_node{b().ty(), this, lc}; } -cooperative_matrix_store_inst::cooperative_matrix_store_inst(bool checked, store_flag flag, +cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, tinytc_value_t val0, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, location const &lc) - : standard_inst{IK::cooperative_matrix_store}, checked_(checked), flag_(flag) { + : standard_inst{IK::cooperative_matrix_store}, cflag_(cflag), sflag_(sflag) { op(op_val, val0); op(op_operand, op0); op(op_pos0, p0); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 15f6bb6d..94951bb9 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -129,6 +129,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent tinytc_inst &operator=(tinytc_inst const &other) = delete; tinytc_inst &operator=(tinytc_inst &&other) = delete; + auto context() const -> tinytc_compiler_context_t; inline auto type_id() const -> tinytc::IK { return tid_; } inline auto loc() const noexcept -> tinytc::location const & { return loc_; } @@ -517,18 +518,18 @@ class cooperative_matrix_load_inst : public standard_inst<3, 1, 0> { return i.type_id() == IK::cooperative_matrix_load; } enum op_number { op_operand = 0, op_pos0 = 1, op_pos1 = 2 }; - cooperative_matrix_load_inst(transpose t, bool checked, tinytc_value_t op0, tinytc_value_t p0, - tinytc_value_t p1, tinytc_data_type_t to_ty, + cooperative_matrix_load_inst(transpose t, checked_flag flag, tinytc_value_t op0, + tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, location const &lc = {}); inline auto t() const -> transpose { return t_; } - inline auto checked() const -> bool { return checked_; } + inline auto checked() const -> checked_flag { return flag_; } inline auto operand() const -> tinytc_value const & { return op(op_operand); } inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } private: transpose t_; - bool checked_; + checked_flag flag_; }; class cooperative_matrix_mul_add_inst : public standard_inst<3, 1, 0> { @@ -561,19 +562,19 @@ class cooperative_matrix_store_inst : public standard_inst<4, 0, 0> { return i.type_id() == IK::cooperative_matrix_store; } enum op_number { op_val = 0, op_operand = 1, op_pos0 = 2, op_pos1 = 3 }; - cooperative_matrix_store_inst(bool checked, store_flag flag, tinytc_value_t val0, + cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, tinytc_value_t val0, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, location const &lc = {}); - inline auto checked() const -> bool { return checked_; } - inline auto flag() const -> store_flag { return flag_; } + inline auto checked() const -> checked_flag { return cflag_; } + inline auto flag() const -> store_flag { return sflag_; } inline auto val() const -> tinytc_value const & { return op(op_val); } inline auto operand() const -> tinytc_value const & { return op(op_operand); } inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } private: - bool checked_; - store_flag flag_; + checked_flag cflag_; + store_flag sflag_; }; class expand_inst : public standard_inst { diff --git a/src/parser/lexer.re b/src/parser/lexer.re index f580234c..e67c1e23 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -96,7 +96,6 @@ lex: ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } - ".checked" { adv_loc(); return parser::make_CHECKED(loc_); } "init" { adv_loc(); return parser::make_INIT(loc_); } "local" { adv_loc(); return parser::make_LOCAL(loc_); } "global" { adv_loc(); return parser::make_GLOBAL(loc_); } @@ -141,6 +140,11 @@ lex: "matrix_b" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::b, loc_); } "matrix_acc" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::acc, loc_); } + // checked flag + ".rows_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::rows, loc_); } + ".cols_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::cols, loc_); } + ".both_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::both, loc_); } + // instructions "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } "arith" { adv_loc(); return parser::make_ARITH(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index bf2d9e95..b83421d7 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -96,7 +96,6 @@ TRANS ".t" ATOMIC ".atomic" ATOMIC_ADD ".atomic_add" - CHECKED ".checked" INIT "init" LOCAL "local" GLOBAL "global" @@ -151,6 +150,7 @@ %token ARITHMETIC_UNARY %token CMP_CONDITION %token MATRIX_USE +%token CHECKED %nterm prog %nterm > func_list @@ -211,7 +211,7 @@ %nterm cooperative_matrix_mul_add_inst %nterm cooperative_matrix_scale_inst %nterm cooperative_matrix_store_inst -%nterm checked +%nterm checked %nterm expand_inst %nterm integer_constant_or_identifier %nterm > expand_shape @@ -917,8 +917,8 @@ cooperative_matrix_load_inst: ; checked: - %empty { $$ = false; } - | CHECKED { $$ = true; } + %empty { $$ = checked_flag::none; } + | CHECKED { $$ = $CHECKED; } ; cooperative_matrix_mul_add_inst: diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp index 9676321f..3e97b72b 100644 --- a/src/pass/constant_folding.cpp +++ b/src/pass/constant_folding.cpp @@ -135,7 +135,7 @@ auto constant_folding::get_memref_type(value_node const &v) const -> const memre return t; } -auto constant_folding::operator()(inst_node &) -> fold_result { return {}; } +auto constant_folding::operator()(inst_node &) -> fold_result { return tinytc_value_t{}; } auto constant_folding::operator()(arith_inst &in) -> fold_result { auto &op_a = in.a(); diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 0c6730b5..d5679654 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -547,10 +547,12 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo auto rt = get_coopmatrix_type(c.result(0)); auto &dv = get_dope_vector(c.operand()); + const bool check_rows = c.checked() == checked_flag::rows || c.checked() == checked_flag::both; + const bool check_cols = c.checked() == checked_flag::cols || c.checked() == checked_flag::both; const int rmode = rt->distributed_mode(); const int omode = c.t() == transpose::T ? 1 - rmode : rmode; - const bool enable_sub_group_reads = core_cfg_.block_read_write_supported && !c.checked() && - c.t() == transpose::N && ot->stride(omode) == 1; + const bool enable_sub_group_reads = + core_cfg_.block_read_write_supported && c.t() == transpose::N && ot->stride(omode) == 1; auto clinst = std::vector{}; auto const len = rt->length(core_cfg_.subgroup_size); @@ -563,7 +565,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo declaration_assignment(visit(*this, *c.operand().ty()), pointer, val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); clir::var rem[2] = {}; - if (c.checked()) { + if (check_rows || check_cols) { clinst.emplace_back( declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); clinst.emplace_back( @@ -572,40 +574,45 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); for (std::int64_t block = 0; block < num_blocks; ++block) { - auto common_check = clir::var{}; - if (c.checked()) { + auto row_in_bounds = clir::var{}; + if (check_rows) { auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; - clinst.emplace_back(declaration_assignment(to_clir_ty(scalar_type::i1), common_check, + clinst.emplace_back(declaration_assignment(to_clir_ty(scalar_type::i1), row_in_bounds, m >= -pv[omode] && m < rem[omode])); } for (std::int64_t k = 0; k < rt->shape(1 - rmode); ++k) { + auto col_cond = [&] { return k >= -pv[1 - omode] && k < rem[1 - omode]; }; + auto const store = [&](clir::expr rhs) -> clir::stmt { return expression_statement( assignment(lhs[k + block * rt->shape(1 - rmode)], std::move(rhs))); }; auto const remainder = rt->shape(rmode) - core_cfg_.subgroup_size * block; const bool needs_mask = remainder < core_cfg_.subgroup_size; - if (enable_sub_group_reads && !needs_mask) { + if (enable_sub_group_reads && !needs_mask && !check_rows) { auto rhs = sub_group_block_read_helper( pointer + block * core_cfg_.subgroup_size + k * ot->stride(1), ot->element_ty(), to_clir_address_space(ot->addrspace())); + if (check_cols) { + rhs = ternary_conditional(col_cond(), std::move(rhs), 0); + } clinst.emplace_back(store(std::move(rhs))); } else { auto rhs = pointer[ot->stride(omode) * (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + k * ot->stride(1 - omode)]; - auto checked_cond = [&] { - return common_check && k >= -pv[1 - omode] && k < rem[1 - omode]; - }; - auto mask_cond = [&] { return clir::get_sub_group_local_id() < remainder; }; clir::expr cond = {}; - if (c.checked() && needs_mask) { - cond = checked_cond() && mask_cond(); - } else if (c.checked()) { - cond = checked_cond(); - } else if (needs_mask) { - cond = mask_cond(); + if (check_rows) { + cond = row_in_bounds; + } + if (check_cols) { + cond = cond ? cond && col_cond() : col_cond(); } + if (needs_mask) { + auto mask_cond = clir::get_sub_group_local_id() < remainder; + cond = cond ? cond && mask_cond : mask_cond; + } + if (cond) { rhs = ternary_conditional(cond, std::move(rhs), 0); } @@ -753,6 +760,9 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st auto &dv = get_dope_vector(c.operand()); auto valv = val(c.val()); + const bool check_rows = c.checked() == checked_flag::rows || c.checked() == checked_flag::both; + const bool check_cols = c.checked() == checked_flag::cols || c.checked() == checked_flag::both; + const int vmode = vt->distributed_mode(); const int omode = vmode; @@ -766,7 +776,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st declaration_assignment(visit(*this, *c.operand().ty()), pointer, val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); clir::var rem[2] = {}; - if (c.checked()) { + if (check_rows || check_cols) { clinst.emplace_back( declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); clinst.emplace_back( @@ -808,15 +818,13 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + k * ot->stride(1 - omode); auto rhs = valv[k + block * vt->shape(1 - vmode)]; - auto checked_cond = [&] { return k >= -pv[1 - omode] && k < rem[1 - omode]; }; - auto mask_cond = [&] { return clir::get_sub_group_local_id() < remainder; }; clir::expr cond = {}; - if (c.checked() && needs_mask) { - cond = checked_cond() && mask_cond(); - } else if (c.checked()) { - cond = checked_cond(); - } else if (needs_mask) { - cond = mask_cond(); + if (check_cols) { + cond = k >= -pv[1 - omode] && k < rem[1 - omode]; + } + if (needs_mask) { + auto mask_cond = clir::get_sub_group_local_id() < remainder; + cond = cond ? cond && mask_cond : mask_cond; } auto st = clir::expression_statement(store(std::move(offset), std::move(rhs))); if (cond) { @@ -829,7 +837,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st } } - if (c.checked()) { + if (check_rows) { auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; clinst.emplace_back(clir::if_selection_builder(m >= -pv[omode] && m < rem[omode]) .then([&](clir::block_builder &bb) { diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 3636c702..ab6723a7 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -206,8 +206,8 @@ void dump_ir_pass::operator()(cooperative_matrix_load_inst const &c) { dump_val(c.result(0)); *os_ << " = cooperative_matrix_load"; *os_ << "." << to_string(c.t()); - if (c.checked()) { - *os_ << ".checked"; + if (c.checked() != checked_flag::none) { + *os_ << "." << to_string(c.checked()); } *os_ << " "; dump_val(c.operand()); @@ -253,8 +253,8 @@ void dump_ir_pass::operator()(cooperative_matrix_scale_inst const &c) { void dump_ir_pass::operator()(cooperative_matrix_store_inst const &c) { *os_ << "cooperative_matrix_store"; - if (c.checked()) { - *os_ << ".checked"; + if (c.checked() != checked_flag::none) { + *os_ << "." << to_string(c.checked()); } if (c.flag() != store_flag::regular) { *os_ << '.' << to_string(c.flag()); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 3c4e1326..74c8a909 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -5,11 +5,13 @@ #include "codegen_tools.hpp" #include "device_info.hpp" #include "error.hpp" +#include "gemm_tools.hpp" #include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" +#include "scalar_type.hpp" #include "support/casting.hpp" #include "support/ilist.hpp" #include "support/ilist_base.hpp" @@ -19,13 +21,107 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" -#include #include +#include #include #include +#include namespace tinytc { +void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomic, value alpha, + value A, value B, value beta, value C, value K, value m_block, + std::int64_t m_block_size, bool m_check, value n_block, + std::int64_t n_block_size, bool n_check, data_type a_ty, data_type b_ty, + data_type c_ty, location const &loc) { + auto ctx = m_block->context(); + auto index_ty = scalar_data_type::get(ctx, scalar_type::index); + + const auto check_a = m_check ? checked_flag::rows : checked_flag::none; + const auto check_b = n_check ? checked_flag::cols : checked_flag::none; + const auto check_c = [&] { + if (m_check && n_check) { + return checked_flag::both; + } else if (m_check) { + return checked_flag::rows; + } else if (n_check) { + return checked_flag::cols; + } + return checked_flag::none; + }(); + auto coopmatrix_c_ty = get_coopmatrix(c_ty, m_block_size, n_block_size, matrix_use::acc, loc); + auto const compute_c = [&](region_builder &bb, std::int32_t k_block_size, value K0, value K1, + value c_init) -> value { + auto c_step = bb.add(make_constant(k_block_size, index_ty, loc)); + auto return_values = bb.for_loop( + K0, K1, c_step, {c_init}, {coopmatrix_c_ty}, index_ty, + [&](region_builder &bb, array_view p) { + const auto k = p[0]; + + value pos_a[2] = {m_block, k}; + if (tA == transpose::T) { + std::swap(pos_a[0], pos_a[1]); + } + auto coopmatrix_a_ty = + get_coopmatrix(a_ty, m_block_size, k_block_size, matrix_use::a, loc); + auto a = bb.add(make_cooperative_matrix_load(tA, check_a, A, pos_a[0], pos_a[1], + coopmatrix_a_ty)); + + value pos_b[2] = {k, n_block}; + if (tB == transpose::T) { + std::swap(pos_b[0], pos_b[1]); + } + auto coopmatrix_b_ty = + get_coopmatrix(b_ty, k_block_size, n_block_size, matrix_use::b, loc); + auto b = bb.add(make_cooperative_matrix_load(tB, check_b, B, pos_b[0], pos_b[1], + coopmatrix_b_ty)); + auto c_next = + bb.add(make_cooperative_matrix_mul_add(a, b, p[1], coopmatrix_c_ty, loc)); + bb.add(make_yield(c_next, loc)); + }); + return return_values[0]; + }; + + auto c_init = bb.add(make_constant_zero(coopmatrix_c_ty, loc)); + + auto k_block_size = max_K_unrolling; + + const auto const_K = get_int_constant(K); + if (const_K) { + k_block_size = compute_k_block_size(*const_K); + } + + auto c_zero = bb.add(make_constant_zero(index_ty, loc)); + auto c_k_block_size = bb.add(make_constant(k_block_size, index_ty, loc)); + auto tmp = bb.add(make_arith(arithmetic::div, K, c_k_block_size, loc)); + auto K0 = bb.add(make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); + c_init = compute_c(bb, k_block_size, c_zero, K0, c_init); + auto needs_remainder = bb.add(make_cmp(cmp_condition::lt, K0, K, loc)); + bb.if_condition( + needs_remainder, + [&](region_builder &bb) { + auto c_next = compute_c(bb, 1, K0, K, c_init); + bb.add(make_yield(c_next, loc)); + }, + {coopmatrix_c_ty}, loc); + + auto alpha_ab = mixed_precision_coopmatrix_scale(bb, alpha, c_init, loc); + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); + } + bb.add(make_cooperative_matrix_store(check_c, *flag, alpha_ab, C, m_block, n_block, loc)); + } else { + auto c_load = bb.add(make_cooperative_matrix_load(transpose::N, check_c, C, m_block, + n_block, coopmatrix_c_ty)); + auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load, loc); + auto alpha_ab_plus_beta_c = bb.add(make_arith(arithmetic::add, alpha_ab, beta_c, loc)); + bb.add(make_cooperative_matrix_store(check_c, store_flag::regular, alpha_ab_plus_beta_c, C, + m_block, n_block, loc)); + } +} + class linalg_generator { public: linalg_generator(local_tiling tiling, core_config core_cfg) @@ -33,6 +129,7 @@ class linalg_generator { auto operator()(inst_node &) -> inst { return inst{}; } auto operator()(axpby_inst &in) -> inst; auto operator()(ger_inst &in) -> inst; + auto operator()(gemm_inst &in) -> inst; auto operator()(hadamard_inst &in) -> inst; auto operator()(sum_inst &in) -> inst; @@ -85,15 +182,15 @@ auto linalg_generator::operator()(axpby_inst &in) -> inst { blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); }); } else if (bt->dim() == 1) { - auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); inner_loop(bb, &in.A(), &in.B(), c_shape0, tiling_.m_tiles() * tiling_.n_tiles(), sgid); } else if (bt->dim() == 2) { auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); - auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); - auto c_shape1 = bb.add(make_size(&in.B(), 1, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.B(), 1, in.loc())); tile_loop_uniformly_new( bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, [&](region_builder &bb, value block, value trip_count) { @@ -132,8 +229,8 @@ auto linalg_generator::operator()(ger_inst &in) -> inst { auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); - auto c_shape0 = bb.add(make_size(&in.C(), 0, in.loc())); - auto c_shape1 = bb.add(make_size(&in.C(), 1, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, in.loc())); tile_loop_uniformly_new( bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, [&](region_builder &bb, value block, value trip_count) { @@ -155,6 +252,77 @@ auto linalg_generator::operator()(ger_inst &in) -> inst { return parallel; } +auto linalg_generator::operator()(gemm_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + + auto ctx = compiler_context{in.alpha().context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); + + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); + + auto [max_rows, max_cols] = max_register_block_gemm( + size(ct->element_ty()), core_cfg_.subgroup_size, core_cfg_.register_space, + is_complex_type(ct->element_ty()) ? 2 : 1); + + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, in.loc())); + auto K = instant_constant_fold_add( + bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + + auto const_shape0 = get_int_constant(c_shape0); + auto const_shape1 = get_int_constant(c_shape1); + + const auto block_size0 = const_shape0 ? compute_m_block_size(core_cfg_.subgroup_size, max_rows, + tiling_.m_tiles(), *const_shape0) + : max_rows; + const auto block_size1 = max_cols; + + if (const_shape1) { + tile_loop_uniformly_new( + bb, c_shape1, block_size1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value n_block, value trip_count) { + auto const_trip_count = get_int_constant(trip_count); + if (!const_trip_count) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + tile_loop_by_sgs_new(bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel( + bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), + &in.A(), &in.B(), &in.beta(), &in.C(), K, m_block, + block_size0, m_check, n_block, *const_trip_count, + false, at->element_data_ty(), bt->element_data_ty(), + ct->element_data_ty(), in.loc()); + }); + }); + } else { + tile_loop_by_sgs_new(bb, c_shape1, block_size1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value n_block, bool n_check, value) { + tile_loop_by_sgs_new( + bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel( + bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), + &in.A(), &in.B(), &in.beta(), &in.C(), K, m_block, + block_size0, m_check, n_block, block_size1, n_check, + at->element_data_ty(), bt->element_data_ty(), + ct->element_data_ty(), in.loc()); + }); + }); + } + + return parallel; +} + auto linalg_generator::operator()(hadamard_inst &in) -> inst { auto parallel = make_parallel(in.loc()); tinytc_region_t body = ¶llel->child_region(0); @@ -163,7 +331,7 @@ auto linalg_generator::operator()(hadamard_inst &in) -> inst { auto ctx = compiler_context{in.alpha().context(), true}; auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - auto c_shape0 = bb.add(make_size(&in.C(), 0, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); tile_loop_by_sgs_standard( bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, [&](region_builder &bb, value mm) { @@ -191,8 +359,9 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { if (bt->dim() == 0) { // @todo } else if (bt->dim() == 1) { - auto c_shape0 = bb.add(make_size(&in.B(), 0, in.loc())); - auto c_trip_count = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); + auto c_trip_count = instant_constant_fold_add( + bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); tile_loop_by_sgs_standard( bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, [&](region_builder &bb, value mm) { diff --git a/src/tiling.cpp b/src/tiling.cpp index 0589a429..546ff32e 100644 --- a/src/tiling.cpp +++ b/src/tiling.cpp @@ -95,7 +95,7 @@ auto suggest_local_tiling(std::vector const &shapes, } auto suggest_local_tiling(blas_shape const &bshape, core_config const &core_cfg) -> local_tiling { - auto [row_blocks, cols] = + auto [rows, cols] = max_register_block_gemm(size(bshape.ty), core_cfg.subgroup_size, core_cfg.register_space); auto const num_tile_limit = [](std::int64_t mode, std::int32_t block_size) { auto limit = std::numeric_limits::max(); @@ -104,7 +104,7 @@ auto suggest_local_tiling(blas_shape const &bshape, core_config const &core_cfg) } return limit; }; - auto const m_limit = num_tile_limit(bshape.shape[0], row_blocks * core_cfg.subgroup_size); + auto const m_limit = num_tile_limit(bshape.shape[0], rows); auto const n_limit = num_tile_limit(bshape.shape[1], cols); auto const max_threads = core_cfg.max_work_group_size / core_cfg.subgroup_size; @@ -119,7 +119,7 @@ auto suggest_local_tiling(blas_shape const &bshape, core_config const &core_cfg) while (2 * n <= std::min(n_limit, max_threads / m)) { n *= 2; } - auto const LM = m * row_blocks * core_cfg.subgroup_size; + auto const LM = m * rows; auto const LN = n * cols; double const ratio = LM * LN / static_cast(LM + LN); if (ratio > best_ratio) { diff --git a/test/codegen/coopmatrix_load.ir b/test/codegen/coopmatrix_load.ir index aa604a9b..4feddee4 100644 --- a/test/codegen/coopmatrix_load.ir +++ b/test/codegen/coopmatrix_load.ir @@ -17,31 +17,59 @@ func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_ ; CHECK-NEXT: x1[7] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 448))); } +func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_n_rows_checked({{.*}} +; CHECK: float x1[4]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; +; CHECK-NEXT: x1[0] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; +; CHECK-NEXT: x1[2] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; +; CHECK-NEXT: x1[3] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; +} + +func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : memref -> coopmatrix +; CHECK-LABEL: void coopmatrix_a_load_n_cols_checked({{.*}} +; CHECK: float x1[4]; +; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; +; CHECK-NEXT: long x3 = 64 - x; +; CHECK-NEXT: long x4 = 48 - y; +; CHECK-NEXT: x1[0] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0))) : 0; +; CHECK-NEXT: x1[1] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 64))) : 0; +; CHECK-NEXT: x1[2] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 0))) : 0; +; CHECK-NEXT: x1[3] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 64))) : 0; +} + func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.both_checked %A[%x,%y] : memref -> coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_n_checked({{.*}} ; CHECK: float x1[16]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; -; CHECK-NEXT: x1[2] = x5 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; -; CHECK-NEXT: x1[3] = x5 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; -; CHECK-NEXT: x1[4] = x5 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; -; CHECK-NEXT: x1[5] = x5 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; -; CHECK-NEXT: x1[6] = x5 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; -; CHECK-NEXT: x1[7] = x5 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; ; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; -; CHECK-NEXT: x1[8] = x6 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; -; CHECK-NEXT: x1[9] = x6 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; -; CHECK-NEXT: x1[10] = x6 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 128] : 0; -; CHECK-NEXT: x1[11] = x6 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 192] : 0; -; CHECK-NEXT: x1[12] = x6 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 256] : 0; -; CHECK-NEXT: x1[13] = x6 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 320] : 0; -; CHECK-NEXT: x1[14] = x6 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 384] : 0; -; CHECK-NEXT: x1[15] = x6 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 16) + 448] : 0; +; CHECK-NEXT: x1[8] = x6 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; +; CHECK-NEXT: x1[9] = x6 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; +; CHECK-NEXT: x1[10] = x6 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 128] : 0; +; CHECK-NEXT: x1[11] = x6 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 192] : 0; +; CHECK-NEXT: x1[12] = x6 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 256] : 0; +; CHECK-NEXT: x1[13] = x6 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 320] : 0; +; CHECK-NEXT: x1[14] = x6 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 384] : 0; +; CHECK-NEXT: x1[15] = x6 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 448] : 0; } func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -60,21 +88,21 @@ func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t.both_checked %A[%x,%y] : memref -> coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_t_checked({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; -; CHECK-NEXT: x1[2] = x5 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; -; CHECK-NEXT: x1[3] = x5 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; -; CHECK-NEXT: x1[4] = x5 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; -; CHECK-NEXT: x1[5] = x5 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; -; CHECK-NEXT: x1[6] = x5 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; -; CHECK-NEXT: x1[7] = x5 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; } func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -93,30 +121,30 @@ func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.checked %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.both_checked %B[%x,%y] : memref -> coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_n_checked({{.*}} ; CHECK: float x1[16]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; -; CHECK-NEXT: x1[2] = x5 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; -; CHECK-NEXT: x1[3] = x5 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; -; CHECK-NEXT: x1[4] = x5 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; -; CHECK-NEXT: x1[5] = x5 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; -; CHECK-NEXT: x1[6] = x5 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; -; CHECK-NEXT: x1[7] = x5 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; ; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -y && get_sub_group_local_id() + 16 < x4; -; CHECK-NEXT: x1[8] = x6 && 0 >= -x && 0 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 0] : 0; -; CHECK-NEXT: x1[9] = x6 && 1 >= -x && 1 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 1] : 0; -; CHECK-NEXT: x1[10] = x6 && 2 >= -x && 2 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 2] : 0; -; CHECK-NEXT: x1[11] = x6 && 3 >= -x && 3 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 3] : 0; -; CHECK-NEXT: x1[12] = x6 && 4 >= -x && 4 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 4] : 0; -; CHECK-NEXT: x1[13] = x6 && 5 >= -x && 5 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 5] : 0; -; CHECK-NEXT: x1[14] = x6 && 6 >= -x && 6 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 6] : 0; -; CHECK-NEXT: x1[15] = x6 && 7 >= -x && 7 < x3 ? x2[64 * (get_sub_group_local_id() + 16) + 7] : 0; +; CHECK-NEXT: x1[8] = x6 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 0] : 0; +; CHECK-NEXT: x1[9] = x6 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 1] : 0; +; CHECK-NEXT: x1[10] = x6 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 2] : 0; +; CHECK-NEXT: x1[11] = x6 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 3] : 0; +; CHECK-NEXT: x1[12] = x6 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 4] : 0; +; CHECK-NEXT: x1[13] = x6 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 5] : 0; +; CHECK-NEXT: x1[14] = x6 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 6] : 0; +; CHECK-NEXT: x1[15] = x6 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 7] : 0; } func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -135,19 +163,19 @@ func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.checked %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t.both_checked %B[%x,%y] : memref -> coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_t_checked({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && 0 >= -y && 0 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && 1 >= -y && 1 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; -; CHECK-NEXT: x1[2] = x5 && 2 >= -y && 2 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; -; CHECK-NEXT: x1[3] = x5 && 3 >= -y && 3 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; -; CHECK-NEXT: x1[4] = x5 && 4 >= -y && 4 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; -; CHECK-NEXT: x1[5] = x5 && 5 >= -y && 5 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; -; CHECK-NEXT: x1[6] = x5 && 6 >= -y && 6 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; -; CHECK-NEXT: x1[7] = x5 && 7 >= -y && 7 < x4 ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; } diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir index e0988ff5..60b4694b 100644 --- a/test/codegen/coopmatrix_store.ir +++ b/test/codegen/coopmatrix_store.ir @@ -11,9 +11,37 @@ func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup ; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; } +func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store.rows_checked %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_n_rows_checked({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: long x2 = 64 - x; +; CHECK-NEXT: long x3 = 48 - y; +; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: } +} + +func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %c0 = constant 1.0 -> coopmatrix + cooperative_matrix_store.cols_checked %c0, %A[%x,%y] : coopmatrix, memref +; CHECK-LABEL: void coopmatrix_a_store_n_cols_checked({{.*}} +; CHECK: global float* x1 = A + x * 1 + y * 64; +; CHECK-NEXT: long x2 = 64 - x; +; CHECK-NEXT: long x3 = 48 - y; +; CHECK-NEXT: if (0 >= -y && 0 < x3) { +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: } +; CHECK-NEXT: if (1 >= -y && 1 < x3) { +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: } +} + func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.checked %c0, %A[%x,%y] : coopmatrix, memref + cooperative_matrix_store.both_checked %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_n_checked({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; @@ -39,7 +67,7 @@ func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.checked.atomic_add %c0, %A[%x,%y] : coopmatrix, memref + cooperative_matrix_store.both_checked.atomic_add %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_checked_atomic_add({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; diff --git a/test/generator.cpp b/test/generator.cpp index e5ba8887..b0be7a2b 100644 --- a/test/generator.cpp +++ b/test/generator.cpp @@ -66,25 +66,39 @@ TEST_CASE("routine names") { TEST_CASE("max register block") { auto s1 = max_register_block_gemm(4, 16, 8192); - CHECK(s1.first == 2); + CHECK(s1.first == 32); CHECK(s1.second == 19); auto s2 = max_register_block_gemm(4, 16, 16384); - CHECK(s2.first == 2); + CHECK(s2.first == 32); CHECK(s2.second == 44); auto s3 = max_register_block_gemm(4, 32, 8192); - CHECK(s3.first == 1); + CHECK(s3.first == 32); CHECK(s3.second == 19); auto s4 = max_register_block_gemm(4, 32, 16384); - CHECK(s4.first == 1); + CHECK(s4.first == 32); CHECK(s4.second == 44); auto d1 = max_register_block_gemm(8, 16, 8192); - CHECK(d1.first == 1); + CHECK(d1.first == 16); CHECK(d1.second == 16); auto d2 = max_register_block_gemm(8, 16, 16384); - CHECK(d2.first == 2); + CHECK(d2.first == 32); CHECK(d2.second == 19); } +TEST_CASE("block size") { + CHECK(compute_m_block_size(16, 48, 1, 5) == 16); + CHECK(compute_m_block_size(16, 48, 1, 17) == 32); + CHECK(compute_m_block_size(16, 48, 1, 32) == 32); + CHECK(compute_m_block_size(16, 48, 1, 48) == 48); + CHECK(compute_m_block_size(16, 48, 3, 144) == 48); + CHECK(compute_m_block_size(16, 48, 3, 143) == 48); + CHECK(compute_m_block_size(16, 48, 3, 145) == 16); + CHECK(compute_m_block_size(16, 48, 3, 288) == 48); + CHECK(compute_m_block_size(16, 48, 3, 286) == 48); + CHECK(compute_m_block_size(16, 48, 3, 290) == 16); + CHECK(compute_m_block_size(16, 48, 7, 224) == 32); +} + TEST_CASE("compatible scalar type") { for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { const auto si = enum_cast(i); From d435c5bea4bf02287c03aa2a00334952da3a5751 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 23 Oct 2024 10:20:52 +0200 Subject: [PATCH 066/297] Add compatible types rules Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 34 ++++++++++++++++++++------ include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/error.cpp | 2 ++ src/node/inst_node.cpp | 50 ++++++++++++++++++++++++++++----------- src/node/inst_node.hpp | 4 ++-- 6 files changed, 69 insertions(+), 23 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 67af2edb..90841708 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -157,6 +157,15 @@ The number behind the scalar type prefix denotes the number of bits, e.g. "f64" are double precision floating point numbers. The "index" type is an integer type whose width is platform-specific. +Scalar types are ordered as +:math:`i1 \prec \text{i8} \prec \text{i16} \prec \text{i32} \prec \text{i64} \prec \text{f32} \prec \text{f64} \prec \text{c32} \prec \text{c64}`. +A scalar type :math:`\alpha` is called *compatible to* a scalar type :math:`\beta` if +:math:`\alpha \preceq \beta`. +If an arithmetic operation involves mixed types :math:`\alpha` and :math:`\beta` and +:math:`\alpha \preceq \beta`, then :math:`\alpha` is casted to :math:`\beta` and the arithmetic operation +is done with type :math:`\beta`. + + Memref type ----------- @@ -361,7 +370,9 @@ or 2 (matrix). Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. Foreach ....... @@ -428,7 +439,9 @@ If :math:`\text{op}_1(A)` has the shape MxK and Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. +* :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(A), \text{element_type}(B))` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. GEMV .... @@ -464,7 +477,9 @@ The transpose modifier for A as in GEMM. Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. +* :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(A), \text{element_type}(b))` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. GER ... @@ -498,7 +513,9 @@ a and b must be vectors. If the size of a is M and the size of b is N the shape Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. +* :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(a), \text{element_type}(b))` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. Hadamard product @@ -534,7 +551,9 @@ a, b, and c must be vectors and have equal shape. Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. +* :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(a), \text{element_type}(b))` +* :math:`\text{type}(\beta) \preceq \text{element_type}(c)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. Parallel ........ @@ -590,8 +609,9 @@ The transpose op is defined as in the axpby instruction. Restrictions ~~~~~~~~~~~~ -If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. - +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. Mixed instructions diff --git a/include/tinytc/types.h b/include/tinytc/types.h index a084cabe..49e6b15c 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -80,6 +80,7 @@ typedef enum { tinytc_status_ir_init_return_mismatch = 0x121, ///< Mismatch of init values and returned values tinytc_status_ir_invalid_matrix_use = 0x122, ///< Invalid matrix use tinytc_status_ir_unsupported_coopmatrix_shape = 0x123, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x124, ///< Incompatible scalar types // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index b4c99b5e..d1bf425e 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -88,6 +88,7 @@ enum class status { ir_init_return_mismatch = tinytc_status_ir_init_return_mismatch, ir_invalid_matrix_use = tinytc_status_ir_invalid_matrix_use, ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, + ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index a7d934ff..33d6d976 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -192,6 +192,8 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_ir_unsupported_coopmatrix_shape: return "Unsupported coopmatrix shape for the combination of scalar type, matrix use, and " "target architecture"; + case tinytc_status_ir_incompatible_scalar_types: + return "Scalar types violate compatibility rules"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 6fc66f09..10c67125 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -50,22 +50,50 @@ void check_index_ty(location const &loc, tinytc_data_type_t ty) { } blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, - tinytc_value_t B, bool atomic) + tinytc_value_t B, bool atomic, location const &lc) : standard_inst{tid}, atomic_(atomic) { op(op_alpha, alpha); op(op_A, A); op(op_beta, beta); op(op_B, B); + loc(lc); + + auto At = get_memref_type(loc(), op(op_A)); + auto Bt = get_memref_type(loc(), op(op_B)); + auto alphat = get_scalar_type(loc(), op(op_alpha)); + auto betat = get_scalar_type(loc(), op(op_beta)); + + if (compatible_type(alphat->ty(), At->element_ty()) != At->element_ty()) { + throw compilation_error(loc(), status::ir_incompatible_scalar_types); + } + if (compatible_type(betat->ty(), Bt->element_ty()) != Bt->element_ty()) { + throw compilation_error(loc(), status::ir_incompatible_scalar_types); + } } blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, bool atomic) + tinytc_value_t beta, tinytc_value_t C, bool atomic, location const &lc) : standard_inst{tid}, atomic_(atomic) { op(op_alpha, alpha); op(op_A, A); op(op_B, B); op(op_beta, beta); op(op_C, C); + loc(lc); + + auto At = get_memref_type(loc(), op(op_A)); + auto Bt = get_memref_type(loc(), op(op_B)); + auto Ct = get_memref_type(loc(), op(op_C)); + auto alphat = get_scalar_type(loc(), op(op_alpha)); + auto betat = get_scalar_type(loc(), op(op_beta)); + + const auto AB_ty = compatible_type(At->element_ty(), Bt->element_ty()); + if (compatible_type(alphat->ty(), AB_ty) != AB_ty) { + throw compilation_error(loc(), status::ir_incompatible_scalar_types); + } + if (compatible_type(betat->ty(), Ct->element_ty()) != Ct->element_ty()) { + throw compilation_error(loc(), status::ir_incompatible_scalar_types); + } } loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, @@ -134,9 +162,8 @@ alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, tinytc_value_t B0, bool atomic, location const &lc) : blas_a2_inst(IK::axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), - std::move(B0), atomic), + std::move(B0), atomic, lc), tA_(tA) { - loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); @@ -628,9 +655,8 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_v tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), - std::move(beta0), std::move(C0), atomic), + std::move(beta0), std::move(C0), atomic, lc), tA_(tA), tB_(tB) { - loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); @@ -658,9 +684,8 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_v gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), - std::move(beta0), std::move(C0), atomic), + std::move(beta0), std::move(C0), atomic, lc), tA_(tA) { - loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); @@ -686,8 +711,7 @@ gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tin ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), - std::move(beta0), std::move(C0), atomic) { - loc(lc); + std::move(beta0), std::move(C0), atomic, lc) { auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); @@ -720,8 +744,7 @@ hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_va tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) : blas_a3_inst(IK::hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), - std::move(beta0), std::move(C0), atomic) { - loc(lc); + std::move(beta0), std::move(C0), atomic, lc) { auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); @@ -855,9 +878,8 @@ store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, tinytc_value_t B0, bool atomic, location const &lc) : blas_a2_inst(IK::sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), - std::move(B0), atomic), + std::move(B0), atomic, lc), tA_(tA) { - loc(lc); auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 94951bb9..463e2c60 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -331,7 +331,7 @@ class blas_a2_inst : public standard_inst<4, 0> { } enum op_number { op_alpha = 0, op_A = 1, op_beta = 2, op_B = 3 }; blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, - tinytc_value_t B, bool atomic); + tinytc_value_t B, bool atomic, location const &lc); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } @@ -355,7 +355,7 @@ class blas_a3_inst : public standard_inst<5, 0> { } enum op_number { op_alpha = 0, op_A = 1, op_B = 2, op_beta = 3, op_C = 4 }; blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, bool atomic); + tinytc_value_t beta, tinytc_value_t C, bool atomic, location const &lc); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } From a61692f599f2f811bbfa6e0c7c133c7e4b795e71 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 23 Oct 2024 10:28:20 +0200 Subject: [PATCH 067/297] More compatible types rules Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 8 +++++++- src/node/inst_node.cpp | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 90841708..5c989c71 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -861,7 +861,7 @@ Matrix mul add returns the value of .. math:: - AB + C, + D := AB + C, where A, B, and C are matrices given by the three operands. @@ -871,6 +871,12 @@ and the third operand and the result have shape :math:`M\times N` with use "matr The component types of the operands and the result do not need to match. +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{compatible_type}(\text{component_type}(A), \text{component_type}(B)) \preceq \text{component_type}(C)` +* Cast of :math:`\text{component_type}(C)` to :math:`\text{component_type}(D)` must be allowed + Cooperative matrix scale ........................ diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 10c67125..1b677ad7 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -481,6 +481,11 @@ cooperative_matrix_mul_add_inst::cooperative_matrix_mul_add_inst(tinytc_value_t throw compilation_error(loc(), status::ir_invalid_matrix_use); } + const auto AB_ty = compatible_type(at->component_ty(), bt->component_ty()); + if (compatible_type(AB_ty, ct->component_ty()) != ct->component_ty()) { + throw compilation_error(loc(), status::ir_incompatible_scalar_types); + } + result(0) = value_node{to_ty, this, lc}; } From be1e2238af0affcd30b5d167c1a7a3f602d4a653 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 23 Oct 2024 17:08:52 +0200 Subject: [PATCH 068/297] New test infrastructure Signed-off-by: Carsten Uphoff --- src/cl/argument_handler.hpp | 15 ++- src/recipe/small_gemm_batched.cpp | 2 +- test/CMakeLists.txt | 4 + test/cl/linalg.cpp | 9 ++ test/cl/smm.cpp | 24 ---- test/cl/test_runtime.cpp | 22 ++++ test/cl/test_runtime.hpp | 11 ++ test/doctest_util.hpp | 58 ++++++++++ test/linalg.hpp | 94 ++++++++++++++++ test/linalg_ops.cpp | 67 ++++++++++++ test/linalg_ops.hpp | 95 ++++++++++++++++ test/linalg_runner.hpp | 114 +++++++++++++++++++ test/linalg_types.cpp | 52 +++++++++ test/linalg_types.hpp | 39 +++++++ test/runtime_concept.hpp | 49 +++++++++ test/smm.hpp | 176 ------------------------------ test/sycl/CMakeLists.txt | 8 +- test/sycl/linalg.cpp | 9 ++ test/sycl/smm.cpp | 55 ---------- test/sycl/test_runtime.cpp | 33 ++++++ test/sycl/test_runtime.hpp | 14 +++ test/tensor3.hpp | 63 ----------- test/ze/CMakeLists.txt | 8 +- test/ze/linalg.cpp | 9 ++ test/ze/smm.cpp | 46 -------- test/ze/test_runtime.cpp | 20 ++++ test/ze/test_runtime.hpp | 9 ++ 27 files changed, 726 insertions(+), 379 deletions(-) create mode 100644 test/cl/linalg.cpp delete mode 100644 test/cl/smm.cpp create mode 100644 test/doctest_util.hpp create mode 100644 test/linalg.hpp create mode 100644 test/linalg_ops.cpp create mode 100644 test/linalg_ops.hpp create mode 100644 test/linalg_runner.hpp create mode 100644 test/linalg_types.cpp create mode 100644 test/linalg_types.hpp create mode 100644 test/runtime_concept.hpp delete mode 100644 test/smm.hpp create mode 100644 test/sycl/linalg.cpp delete mode 100644 test/sycl/smm.cpp delete mode 100644 test/tensor3.hpp create mode 100644 test/ze/linalg.cpp delete mode 100644 test/ze/smm.cpp diff --git a/src/cl/argument_handler.hpp b/src/cl/argument_handler.hpp index 7cc8d5f2..ef1da91e 100644 --- a/src/cl/argument_handler.hpp +++ b/src/cl/argument_handler.hpp @@ -19,13 +19,16 @@ class opencl_argument_handler { using clSetKernelArgMemPointerINTEL_t = cl_int (*)(cl_kernel kernel, cl_uint arg_index, const void *arg_value); //! ctor - inline opencl_argument_handler() : clSetKernelArgMemPointerINTEL_(nullptr) {} + inline opencl_argument_handler() = default; //! ctor; checks whether cl_intel_unified_shared_memory is available and gets //! clSetKernelArgMemPointerINTEL - inline opencl_argument_handler(cl_platform_id plat) - : clSetKernelArgMemPointerINTEL_( - (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( - plat, "clSetKernelArgMemPointerINTEL")) {} + inline opencl_argument_handler(cl_platform_id plat) { set_platform(plat); } + + inline void set_platform(cl_platform_id plat) { + clSetKernelArgMemPointerINTEL_ = + (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( + plat, "clSetKernelArgMemPointerINTEL"); + } /** * @brief Set single kernel argument @@ -69,7 +72,7 @@ class opencl_argument_handler { } private: - clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_; + clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_ = nullptr; }; } // namespace tinytc diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 5d71bab3..56dff324 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -72,7 +72,7 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( ++l.end.column; return l; }; - auto const make_static_sizes = [](transpose t, int64_t A, std::int64_t B) { + auto const make_static_sizes = [](transpose t, std::int64_t A, std::int64_t B) { auto s = std::array{A, B, 0}; if (t == transpose::T) { std::swap(s[0], s[1]); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ca72a24f..85e7844c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -26,6 +26,10 @@ target_link_libraries(test-visitor PRIVATE test-lib) doctest_discover_tests(test-visitor) set_cxx_common_options(test-visitor) +add_library(test-lib-linalg STATIC linalg_ops.cpp linalg_types.cpp) +target_link_libraries(test-lib-linalg PUBLIC tinytc) +set_cxx_common_options(test-lib-linalg) + configure_file(lit.site.cfg.py.in lit.site.cfg.py @ONLY) file(READ ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in LIT_SITE_CONFIG) string(CONFIGURE "${LIT_SITE_CONFIG}" LIT_SITE_CONFIG @ONLY) diff --git a/test/cl/linalg.cpp b/test/cl/linalg.cpp new file mode 100644 index 00000000..9820a9e7 --- /dev/null +++ b/test/cl/linalg.cpp @@ -0,0 +1,9 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "test_runtime.hpp" + +#define RUNTIME_CLASS opencl_test_runtime +#define RUNTIME_NAME "opencl" +#include "../linalg.hpp" +#undef RUNTIME_CLASS diff --git a/test/cl/smm.cpp b/test/cl/smm.cpp deleted file mode 100644 index 83c9a187..00000000 --- a/test/cl/smm.cpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "../smm.hpp" -#include "test_runtime.hpp" - -#include "doctest/doctest.h" - -#include - -using namespace tinytc; - -TEST_CASE_TEMPLATE("opencl packed alpha=1 beta=0", T, TEST_PRECISIONS) { - auto KK = std::vector{71}; - auto MM = std::vector{27, 43}; - auto NN = std::vector{3, 33}; - auto HH = std::vector{1, 51}; - - std::uint32_t M, N, K, howmany; - DOCTEST_TENSOR4_TEST(MM, NN, KK, HH); - - check_small_gemm_batched(transpose::N, transpose::N, M, N, K, M, M * K, - K, K * N, M, M * N, 1.0, 0.0, howmany); -} diff --git a/test/cl/test_runtime.cpp b/test/cl/test_runtime.cpp index 0f63ca79..f51875ba 100644 --- a/test/cl/test_runtime.cpp +++ b/test/cl/test_runtime.cpp @@ -26,6 +26,7 @@ opencl_test_runtime::opencl_test_runtime() { device_count = 1; CL_CHECK_STATUS( clGetDeviceIDs(platforms[p], CL_DEVICE_TYPE_GPU, device_count, &dev_, NULL)); + arg_handler_.set_platform(platforms[p]); break; } } @@ -76,6 +77,27 @@ auto opencl_test_runtime::get_command_list() -> command_list_t { return q_; } auto opencl_test_runtime::get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t { return tinytc::make_recipe_handler(ctx_, dev_, rec); } +auto opencl_test_runtime::get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t { + return ::tinytc::make_kernel_bundle(ctx_, dev_, std::move(p)); +} +auto opencl_test_runtime::get_kernel(kernel_bundle_t const &bundle, char const *name) -> kernel_t { + return ::tinytc::make_kernel(bundle.get(), name); +} +void opencl_test_runtime::set_arg(kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value) { + arg_handler_.set_arg(kernel.get(), arg_index, arg_size, arg_value); +} +void opencl_test_runtime::set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, + const void *arg_value, tinytc::mem_type type) { + arg_handler_.set_mem_arg(kernel.get(), arg_index, arg_value, + static_cast(type)); +} +void opencl_test_runtime::submit(kernel_t &kernel, std::int64_t howmany) { + auto ls = ::tinytc::get_group_size(kernel.get()); + auto gs = ::tinytc::get_global_size(howmany, ls); + CL_CHECK_STATUS( + clEnqueueNDRangeKernel(q_, kernel.get(), 3u, NULL, gs.data(), ls.data(), 0, NULL, NULL)); +} void opencl_test_runtime::synchronize() { CL_CHECK_STATUS(clFinish(q_)); } bool opencl_test_runtime::supports_fp64() { diff --git a/test/cl/test_runtime.hpp b/test/cl/test_runtime.hpp index f47678bc..d645356b 100644 --- a/test/cl/test_runtime.hpp +++ b/test/cl/test_runtime.hpp @@ -4,6 +4,7 @@ #ifndef CL_TEST_RUNTIME_20240314_HPP #define CL_TEST_RUNTIME_20240314_HPP +#include "cl/argument_handler.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.hpp" @@ -16,6 +17,8 @@ class opencl_test_runtime { using context_t = cl_context; using command_list_t = cl_command_queue; using recipe_handler_t = tinytc::opencl_recipe_handler; + using kernel_bundle_t = tinytc::shared_handle; + using kernel_t = tinytc::shared_handle; using mem_t = cl_mem; using const_mem_t = const cl_mem; @@ -38,6 +41,13 @@ class opencl_test_runtime { auto get_context() -> context_t; auto get_command_list() -> command_list_t; auto get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t; + auto get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t; + auto get_kernel(kernel_bundle_t const &bundle, char const *name) -> kernel_t; + void set_arg(kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value); + void set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, const void *arg_value, + tinytc::mem_type type); + void submit(kernel_t &kernel, std::int64_t howmany = 1); void synchronize(); bool supports_fp64(); @@ -46,6 +56,7 @@ class opencl_test_runtime { device_t dev_; context_t ctx_; command_list_t q_; + tinytc::opencl_argument_handler arg_handler_; }; #endif // CL_TEST_RUNTIME_20240314_HPP diff --git a/test/doctest_util.hpp b/test/doctest_util.hpp new file mode 100644 index 00000000..c9edf215 --- /dev/null +++ b/test/doctest_util.hpp @@ -0,0 +1,58 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DOCTEST_UTIL_20241023_HPP +#define DOCTEST_UTIL_20241023_HPP + +#define DOCTEST_TENSOR2_TEST(MM, NN) \ + do { \ + for (auto nn : NN) { \ + for (auto mm : MM) { \ + DOCTEST_SUBCASE((std::to_string(mm) + "x" + std::to_string(nn)).c_str()) { \ + N = nn; \ + M = mm; \ + } \ + } \ + } \ + } while (false) + +#define DOCTEST_TENSOR3_TEST(MM, NN, KK) \ + do { \ + for (auto kk : KK) { \ + for (auto nn : NN) { \ + for (auto mm : MM) { \ + DOCTEST_SUBCASE( \ + (std::to_string(mm) + "x" + std::to_string(nn) + "x" + std::to_string(kk)) \ + .c_str()) { \ + K = kk; \ + N = nn; \ + M = mm; \ + } \ + } \ + } \ + } \ + } while (false) + +#define DOCTEST_TENSOR4_TEST(MM, NN, KK, HH) \ + do { \ + for (auto hh : HH) { \ + for (auto kk : KK) { \ + for (auto nn : NN) { \ + for (auto mm : MM) { \ + DOCTEST_SUBCASE((std::to_string(mm) + "x" + std::to_string(nn) + "x" + \ + std::to_string(kk) + "*" + std::to_string(hh)) \ + .c_str()) { \ + howmany = hh; \ + K = kk; \ + N = nn; \ + M = mm; \ + } \ + } \ + } \ + } \ + } \ + } while (false) + +#define TEST_PRECISIONS float, double + +#endif // DOCTEST_UTIL_20241023_HPP diff --git a/test/linalg.hpp b/test/linalg.hpp new file mode 100644 index 00000000..19e75629 --- /dev/null +++ b/test/linalg.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "doctest_util.hpp" +#include "linalg_ops.hpp" +#include "linalg_runner.hpp" +#include "linalg_types.hpp" + +#include "doctest/doctest.h" + +#include +#include + +using runtime_class = RUNTIME_CLASS; +using namespace tinytc; + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed alpha=1 beta=0", T, TEST_PRECISIONS) { + auto KK = std::vector{56}; + auto MM = std::vector{20, 32, 53}; + auto NN = std::vector{5, 16, 23}; + + std::int64_t M, N, K; + DOCTEST_TENSOR3_TEST(MM, NN, KK); + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-packed alpha=1 beta=0", T, TEST_PRECISIONS) { + std::int64_t M = 16, N = 32, K = 8; + std::int64_t ldA = 20, ldB = 9, ldC = 24; + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}, {1, ldA}}, + {{K, N}, {1, ldB}}, {{M, N}, {1, ldC}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed alpha=1 beta=1", T, TEST_PRECISIONS) { + std::int64_t M = 6, N = 33, K = 8; + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); + test::test_blas_a3(op, 1, 1); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed alpha=-1 beta=2", T, TEST_PRECISIONS) { + std::int64_t M = 8, N = 16, K = 16; + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); + test::test_blas_a3(op, -1, 2); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-packed alpha=1 beta=0 transA transB", T, + TEST_PRECISIONS) { + std::int64_t M = 16, N = 32, K = 8; + std::int64_t ldA = 10, ldB = 32, ldC = 24; + + auto op = test::gemm(transpose::T, transpose::T, {{K, M}, {1, ldA}}, + {{N, K}, {1, ldB}}, {{M, N}, {1, ldC}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-static", T, TEST_PRECISIONS) { + std::int64_t M = 63, N = 43, K = 23; + + auto op = test::gemm(transpose::N, transpose::N, + {{M, K}, {1, M}, {dynamic, dynamic}, {1, dynamic}}, + {{K, N}, {1, K}, {dynamic, dynamic}, {1, dynamic}}, + {{M, N}, {1, M}, {dynamic, dynamic}, {1, dynamic}}); + test::test_blas_a3(op, 1, 1); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed complex alpha=1 beta=0", T, TEST_PRECISIONS) { + auto KK = std::vector{53}; + auto MM = std::vector{21, 42}; + auto NN = std::vector{7, 11}; + + std::int64_t M, N, K; + DOCTEST_TENSOR3_TEST(MM, NN, KK); + + using CT = std::complex; + auto op = + test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed complex alpha=(-1,-2) beta=(2,3)", T, + TEST_PRECISIONS) { + std::int64_t M = 8, N = 16, K = 16; + + using CT = std::complex; + auto op = + test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); + test::test_blas_a3(op, {-1.0, -2.0}, {2.0, 3.0}); +} diff --git a/test/linalg_ops.cpp b/test/linalg_ops.cpp new file mode 100644 index 00000000..ddbec03d --- /dev/null +++ b/test/linalg_ops.cpp @@ -0,0 +1,67 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "linalg_ops.hpp" + +#include +#include + +namespace tinytc::test { + +auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array { + if (A.dim() != 2 || B.dim() != 2 || C.dim() != 2) { + throw std::runtime_error("expected matrices"); + } + const int A_kmode = tA == transpose::T ? 1 : 0; + const int B_nmode = tB == transpose::T ? 0 : 1; + const auto M = C.shape(0); + const auto N = C.shape(1); + const auto K = A.shape(1 - A_kmode); + if (M != A.shape(A_kmode) || K != B.shape(1 - B_nmode) || N != B.shape(B_nmode)) { + throw std::runtime_error("incompatible matmul"); + } + return {M, N, K}; +} + +auto make_gemm_prog(char const *name, transpose tA, transpose tB, tensor_layout const &layoutA, + tensor_layout const &layoutB, tensor_layout const &layoutC, + scalar_type alpha_ty, scalar_type A_ty, scalar_type B_ty, scalar_type beta_ty, + scalar_type C_ty) -> prog { + auto ctx = make_compiler_context(); + + auto const alphat = get_scalar(ctx, alpha_ty); + auto const at = get_scalar(ctx, A_ty); + auto const bt = get_scalar(ctx, B_ty); + auto const betat = get_scalar(ctx, beta_ty); + auto const ct = get_scalar(ctx, C_ty); + + auto p = make_prog(ctx); + + auto At = + get_memref(at, layoutA.static_shape(), layoutA.static_stride(), address_space::global); + auto Bt = + get_memref(bt, layoutB.static_shape(), layoutB.static_stride(), address_space::global); + auto Ct = + get_memref(ct, layoutC.static_shape(), layoutC.static_stride(), address_space::global); + + auto f = make_func(name, {alphat, At, Bt, betat, Ct}); + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("B"); + params[3].set_name("beta"); + params[4].set_name("C"); + + auto bb = region_builder{fn_body}; + + bb.add(make_gemm(tA, tB, false, params[0], params[1], params[2], params[3], params[4])); + + p.add_function(std::move(f)); + + return p; +} + +} // namespace tinytc::test diff --git a/test/linalg_ops.hpp b/test/linalg_ops.hpp new file mode 100644 index 00000000..e4612341 --- /dev/null +++ b/test/linalg_ops.hpp @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LINALG_OPS_20241023_HPP +#define LINALG_OPS_20241023_HPP + +#include "linalg_types.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::test { + +template +concept op_blas_a3 = requires(T op, typename T::alpha_type alpha, typename T::beta_type beta, + typename T::A_type const *A_ref, typename T::B_type const *B_ref, + typename T::C_type *C_ref) { + typename T::alpha_type; + typename T::A_type; + typename T::B_type; + typename T::beta_type; + typename T::C_type; + T::kernel_name; + { op.lA() } -> std::same_as; + { op.lB() } -> std::same_as; + { op.lC() } -> std::same_as; + { op.make_prog() } -> std::same_as; + op.reference_impl(alpha, A_ref, B_ref, beta, C_ref); +}; + +auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array; + +auto make_gemm_prog(char const *name, transpose tA, transpose tB, tensor_layout const &layoutA, + tensor_layout const &layoutB, tensor_layout const &layoutC, + scalar_type alpha_ty, scalar_type A_ty, scalar_type B_ty, scalar_type beta_ty, + scalar_type C_ty) -> prog; + +template class gemm { + public: + using alpha_type = AlphaT; + using A_type = AT; + using B_type = BT; + using beta_type = BetaT; + using C_type = CT; + static constexpr char const *kernel_name = "gemm"; + + gemm(transpose tA, transpose tB, tensor_layout layoutA, tensor_layout layoutB, + tensor_layout layoutC) + : tA_(tA), tB_(tB), lA_{std::move(layoutA)}, lB_{std::move(layoutB)}, + lC_{std::move(layoutC)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + auto lC() const -> tensor_layout const & { return lC_; } + + auto make_prog() const -> prog { + return make_gemm_prog(kernel_name, tA_, tB_, lA_, lB_, lC_, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v); + } + void reference_impl(AlphaT alpha, AT const *A, BT const *B, BetaT beta, CT *C) { + const auto [M, N, K] = gemm_mnk(tA_, tB_, lA_, lB_, lC_); + auto const make_index = [](transpose t, std::int64_t m, std::int64_t n) { + auto idx = std::array{m, n}; + if (t == transpose::T) { + std::swap(idx[0], idx[1]); + } + return idx; + }; + for (std::int64_t n = 0; n < N; ++n) { + for (std::int64_t m = 0; m < M; ++m) { + CT c_acc = CT{0}; + for (std::int64_t k = 0; k < K; ++k) { + c_acc += A[lA_.linear_index(make_index(tA_, m, k))] * + B[lB_.linear_index(make_index(tB_, k, n))]; + } + auto &c = C[lC_.linear_index({m, n})]; + c = alpha * c_acc + beta * c; + } + } + } + + private: + transpose tA_, tB_; + tensor_layout lA_, lB_, lC_; +}; + +} // namespace tinytc::test + +#endif // LINALG_OPS_20241023_HPP diff --git a/test/linalg_runner.hpp b/test/linalg_runner.hpp new file mode 100644 index 00000000..2caf9c4a --- /dev/null +++ b/test/linalg_runner.hpp @@ -0,0 +1,114 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LINALG_RUNNER_20241023_HPP +#define LINALG_RUNNER_20241023_HPP + +#include "linalg_ops.hpp" +#include "linalg_types.hpp" +#include "runtime_concept.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include +#include + +namespace tinytc::test { + +template struct is_complex : public std::false_type {}; +template struct is_complex> : public std::true_type {}; +template inline constexpr bool is_complex_v = is_complex::value; +template +constexpr bool requires_dp_v = std::is_same_v || std::is_same_v>; + +template +void test_blas_a3(T op, typename T::alpha_type alpha, typename T::beta_type beta) { + auto gpu_rt = std::make_shared(); + if constexpr (requires_dp_v || requires_dp_v || + requires_dp_v || requires_dp_v || + requires_dp_v) { + if (!gpu_rt->supports_fp64()) { + WARN_MESSAGE(false, "Double precision tests need double precision device support"); + return; + } + } + + auto const make_test_data = [](std::size_t size) { + auto data = std::vector(size); + for (std::size_t i = 0; i < data.size(); ++i) { + constexpr std::size_t prime = 101; + if constexpr (is_complex_v) { + data[i] = ScalarT{static_cast((2 * i) % prime), + static_cast((2 * i + 1) % prime)}; + } else { + data[i] = static_cast(i % prime); + } + } + return data; + }; + auto const compare_data = [](std::vector const &A, + std::vector const &B) { + REQUIRE(A.size() == B.size()); + for (std::size_t i = 0; i < A.size(); ++i) { + constexpr auto eps = + 10.0 * std::numeric_limits::epsilon(); + REQUIRE(std::abs(A[i] - B[i]) == doctest::Approx(0.0).epsilon(eps)); + } + }; + + auto A_ref = make_test_data.template operator()(op.lA().size()); + auto B_ref = make_test_data.template operator()(op.lB().size()); + auto C_ref = std::vector(op.lC().size()); + + op.reference_impl(alpha, A_ref.data(), B_ref.data(), beta, C_ref.data()); + + auto A = gpu_rt->create_buffer(A_ref.size() * sizeof(typename T::A_type)); + auto B = gpu_rt->create_buffer(B_ref.size() * sizeof(typename T::B_type)); + auto C = gpu_rt->create_buffer(C_ref.size() * sizeof(typename T::C_type)); + gpu_rt->memcpy_h2d(A, A_ref.data(), A_ref.size() * sizeof(typename T::A_type)); + gpu_rt->memcpy_h2d(B, B_ref.data(), B_ref.size() * sizeof(typename T::B_type)); + gpu_rt->fill_buffer(C, 0, C_ref.size() * sizeof(typename T::C_type)); + + auto bundle = gpu_rt->get_kernel_bundle(op.make_prog()); + auto kernel = gpu_rt->get_kernel(bundle, T::kernel_name); + + auto const set_dope_vector = [&](tensor_layout const &layout, std::uint32_t &arg_index) { + for (std::size_t i = 0; i < layout.shape().size(); ++i) { + if (layout.static_shape(i) == dynamic) { + std::int64_t s = layout.shape(i); + gpu_rt->set_arg(kernel, arg_index++, sizeof(s), &s); + } + } + for (std::size_t i = 0; i < layout.stride().size(); ++i) { + if (layout.static_stride(i) == dynamic) { + std::int64_t s = layout.stride(i); + gpu_rt->set_arg(kernel, arg_index++, sizeof(s), &s); + } + } + }; + + std::uint32_t i = 0; + gpu_rt->set_arg(kernel, i++, sizeof(typename T::alpha_type), &alpha); + gpu_rt->set_mem_arg(kernel, i++, A, auto_mem_type_v); + set_dope_vector(op.lA(), i); + gpu_rt->set_mem_arg(kernel, i++, B, auto_mem_type_v); + set_dope_vector(op.lB(), i); + gpu_rt->set_arg(kernel, i++, sizeof(typename T::beta_type), &beta); + gpu_rt->set_mem_arg(kernel, i++, C, auto_mem_type_v); + set_dope_vector(op.lC(), i); + gpu_rt->submit(kernel); + gpu_rt->synchronize(); + + auto C_host = std::vector(C_ref.size()); + gpu_rt->memcpy_d2h(C_host.data(), C, C_host.size() * sizeof(typename T::C_type)); + + compare_data(C_host, C_ref); + + gpu_rt->free_buffer(A); + gpu_rt->free_buffer(B); + gpu_rt->free_buffer(C); +} + +} // namespace tinytc::test + +#endif // LINALG_RUNNER_20241023_HPP diff --git a/test/linalg_types.cpp b/test/linalg_types.cpp new file mode 100644 index 00000000..36172f4c --- /dev/null +++ b/test/linalg_types.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "linalg_types.hpp" + +namespace tinytc::test { + +tensor_layout::tensor_layout(array_view shape, array_view stride, + array_view static_shape, + array_view static_stride) + : shape_(shape), stride_(stride), static_shape_(static_shape), static_stride_(static_stride) { + if (!shape_.empty()) { + if (stride_.empty() && !shape_.empty()) { + stride_.reserve(shape_.size()); + stride_.push_back(1); + for (auto &s : shape_) { + stride_.push_back(stride_.back() * s); + } + stride_.pop_back(); + } + if (static_shape_.empty()) { + static_shape_ = shape_; + } + } + if (static_stride_.empty()) { + static_stride_ = stride_; + } + + if (stride_.size() != shape_.size()) { + throw std::runtime_error("Invalid stride"); + } + if (static_shape_.size() != shape_.size()) { + throw std::runtime_error("Invalid static shape"); + } + if (static_stride_.size() != stride_.size()) { + throw std::runtime_error("Invalid static stride"); + } +} + +auto tensor_layout::linear_index(array_view idx) const -> std::int64_t { + if (static_cast(idx.size()) != dim()) { + throw std::runtime_error("index order mismatch"); + } + std::int64_t l = 0; + for (std::size_t i = 0; i < idx.size(); ++i) { + l += idx[i] * stride_[i]; + } + return l; +} + +} // namespace tinytc::test + diff --git a/test/linalg_types.hpp b/test/linalg_types.hpp new file mode 100644 index 00000000..bf39c943 --- /dev/null +++ b/test/linalg_types.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LINALG_TYPES_20241023_HPP +#define LINALG_TYPES_20241023_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc::test { + +class tensor_layout { + public: + tensor_layout(array_view shape, array_view stride = {}, + array_view static_shape = {}, + array_view static_stride = {}); + + inline auto dim() const -> std::int64_t { return shape_.size(); } + inline auto size() const -> std::int64_t { return stride_.back() * shape_.back(); } + inline auto shape() const -> array_view { return {shape_}; } + inline auto shape(std::size_t i) const { return shape_[i]; } + inline auto stride() const -> array_view { return {stride_}; } + inline auto stride(std::size_t i) const { return stride_[i]; } + inline auto static_shape() const -> array_view { return {static_shape_}; } + inline auto static_shape(std::size_t i) const { return static_shape_[i]; } + inline auto static_stride() const -> array_view { return {static_stride_}; } + inline auto static_stride(std::size_t i) const { return static_stride_[i]; } + + auto linear_index(array_view idx) const -> std::int64_t; + + private: + std::vector shape_, stride_, static_shape_, static_stride_; +}; + +} // namespace tinytc::test + +#endif // LINALG_TYPES_20241023_HPP diff --git a/test/runtime_concept.hpp b/test/runtime_concept.hpp new file mode 100644 index 00000000..e4b53378 --- /dev/null +++ b/test/runtime_concept.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef RUNTIME_CONCEPT_20241023_HPP +#define RUNTIME_CONCEPT_20241023_HPP + +#include "tinytc/tinytc.hpp" + +#include + +namespace tinytc::test { + +template +concept test_runtime_gpu = + requires(T rt, std::size_t bytes, typename T::mem_t buf, typename T::const_mem_t const_buf, + void *dst, void const *src, int value, tinytc::recipe const &rec, tinytc::prog p, + typename T::kernel_bundle_t const &bundle, char const *name, + typename T::kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value, ::tinytc::mem_type type, std::int64_t howmany) { + typename T::device_t; + typename T::context_t; + typename T::command_list_t; + typename T::recipe_handler_t; + typename T::kernel_bundle_t; + typename T::kernel_t; + typename T::mem_t; + typename T::const_mem_t; + { rt.create_buffer(bytes) } -> std::same_as; + rt.free_buffer(buf); + rt.fill_buffer(buf, value, bytes); + rt.memcpy_h2d(buf, src, bytes); + rt.memcpy_d2h(dst, const_buf, bytes); + { rt.get_core_info() } -> std::same_as; + { rt.get_device() } -> std::same_as; + { rt.get_context() } -> std::same_as; + { rt.get_command_list() } -> std::same_as; + { rt.get_recipe_handler(rec) } -> std::same_as; + { rt.get_kernel_bundle(p) } -> std::same_as; + { rt.get_kernel(bundle, name) } -> std::same_as; + rt.set_arg(kernel, arg_index, arg_size, arg_value); + rt.set_mem_arg(kernel, arg_index, arg_value, type); + rt.submit(kernel, howmany); + { rt.supports_fp64() } -> std::same_as; + rt.synchronize(); + }; + +} // namespace tinytc::test + +#endif // RUNTIME_CONCEPT_20241023_HPP diff --git a/test/smm.hpp b/test/smm.hpp deleted file mode 100644 index 83ca2f55..00000000 --- a/test/smm.hpp +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SMM_20240314_HPP -#define SMM_20240314_HPP - -#include "tensor3.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define DOCTEST_TENSOR4_TEST(MM, NN, KK, HH) \ - do { \ - for (auto hh : HH) { \ - for (auto kk : KK) { \ - for (auto nn : NN) { \ - for (auto mm : MM) { \ - DOCTEST_SUBCASE((std::to_string(mm) + "x" + std::to_string(nn) + "x" + \ - std::to_string(kk) + "*" + std::to_string(hh)) \ - .c_str()) { \ - howmany = hh; \ - K = kk; \ - N = nn; \ - M = mm; \ - } \ - } \ - } \ - } \ - } \ - } while (false) - -#define TEST_PRECISIONS float, double - -template struct is_complex : public std::false_type {}; -template struct is_complex> : public std::true_type {}; -template inline constexpr bool is_complex_v = is_complex::value; - -template -void small_gemm_batched_ref(tinytc::transpose transA, tinytc::transpose transB, T alpha, - tensor3 const &A, tensor3 const &B, T beta, tensor3 &C) { - bool compatible = A.shape(2) == B.shape(2) && B.shape(2) == C.shape(2); - auto Arows = A.shape(0); - auto Acols = A.shape(1); - if (transA == tinytc::transpose::T) { - std::swap(Arows, Acols); - } - auto Brows = B.shape(0); - auto Bcols = B.shape(1); - if (transB == tinytc::transpose::T) { - std::swap(Brows, Bcols); - } - compatible = compatible && Arows == C.shape(0) && Bcols == C.shape(1) && Acols == Brows; - if (!compatible) { - throw std::runtime_error("incompatible matmul"); - } - for (std::uint32_t j = 0; j < C.shape(2); ++j) { - for (std::uint32_t n = 0; n < C.shape(1); ++n) { - for (std::uint32_t m = 0; m < C.shape(0); ++m) { - T c_acc = T(0.0); - for (std::uint32_t k = 0; k < Acols; ++k) { - auto const a = transA == tinytc::transpose::T ? A(k, m, j) : A(m, k, j); - auto const b = transB == tinytc::transpose::T ? B(n, k, j) : B(k, n, j); - c_acc += a * b; - } - C(m, n, j) = alpha * c_acc + beta * C(m, n, j); - } - } - } -} - -template -concept test_runtime_gpu = - requires(T rt, std::size_t bytes, typename T::mem_t buf, typename T::const_mem_t const_buf, - void *dst, void const *src, int value, tinytc::recipe const &rec) { - typename T::device_t; - typename T::context_t; - typename T::command_list_t; - typename T::recipe_handler_t; - typename T::mem_t; - typename T::const_mem_t; - { rt.create_buffer(bytes) } -> std::same_as; - rt.free_buffer(buf); - rt.fill_buffer(buf, value, bytes); - rt.memcpy_h2d(buf, src, bytes); - rt.memcpy_d2h(dst, const_buf, bytes); - { rt.get_core_info() } -> std::same_as; - { rt.get_device() } -> std::same_as; - { rt.get_context() } -> std::same_as; - { rt.get_command_list() } -> std::same_as; - { rt.get_recipe_handler(rec) } -> std::same_as; - { rt.supports_fp64() } -> std::same_as; - rt.synchronize(); - }; - -template -void check_small_gemm_batched(tinytc::transpose transA, tinytc::transpose transB, std::uint32_t M, - std::uint32_t N, std::uint32_t K, std::uint32_t ldA, - std::uint32_t strideA, std::uint32_t ldB, std::uint32_t strideB, - std::uint32_t ldC, std::uint32_t strideC, T alpha, T beta, - std::uint32_t howmany) { - auto const selA = [&](std::uint32_t N1, std::uint32_t N2) { - return transA == tinytc::transpose::T ? N2 : N1; - }; - auto const selB = [&](std::uint32_t N1, std::uint32_t N2) { - return transB == tinytc::transpose::T ? N2 : N1; - }; - - auto gpu_rt = std::make_shared(); - if constexpr (std::is_same_v || std::is_same_v>) { - if (!gpu_rt->supports_fp64()) { - WARN_MESSAGE(false, "Double precision tests need double precision device support"); - return; - } - } - - auto const fill = [](tensor3 &x) { - T *data = x.data(); - std::size_t n = x.size(); - for (std::size_t i = 0; i < n; ++i) { - constexpr std::size_t prime = 101; - if constexpr (is_complex_v) { - data[i] = T{static_cast((2 * i) % prime), - static_cast((2 * i + 1) % prime)}; - } else { - data[i] = static_cast(i % prime); - } - } - }; - - auto A_ref = tensor3({selA(M, K), selA(K, M), howmany}, {1, ldA, strideA}); - auto B_ref = tensor3({selB(K, N), selB(N, K), howmany}, {1, ldB, strideB}); - auto C_ref = tensor3({M, N, howmany}, {1, ldC, strideC}); - fill(A_ref); - fill(B_ref); - C_ref.set_zero(); - - small_gemm_batched_ref(transA, transB, alpha, A_ref, B_ref, beta, C_ref); - - auto A = gpu_rt->create_buffer(A_ref.size() * sizeof(T)); - auto B = gpu_rt->create_buffer(B_ref.size() * sizeof(T)); - auto C = gpu_rt->create_buffer(C_ref.size() * sizeof(T)); - gpu_rt->memcpy_h2d(A, A_ref.data(), A_ref.size() * sizeof(T)); - gpu_rt->memcpy_h2d(B, B_ref.data(), B_ref.size() * sizeof(T)); - gpu_rt->fill_buffer(C, 0, C_ref.size() * sizeof(T)); - - auto info = gpu_rt->get_core_info(); - - auto g = gpu_rt->get_recipe_handler( - tinytc::make_small_gemm_batched(info, tinytc::to_scalar_type_v, transA, transB, M, N, K, - ldA, strideA, ldB, strideB, ldC, strideC)); - tinytc::small_gemm_batched::set_args(g, howmany, alpha, A, B, beta, C); - g.submit(gpu_rt->get_command_list()); - gpu_rt->synchronize(); - - auto C_host = tensor3({M, N, howmany}, {1, ldC, strideC}); - gpu_rt->memcpy_d2h(C_host.data(), C, C_host.size() * sizeof(T)); - - compare(C_host, C_ref); - - gpu_rt->free_buffer(A); - gpu_rt->free_buffer(B); - gpu_rt->free_buffer(C); -} - -#endif // SMM_20240314_HPP diff --git a/test/sycl/CMakeLists.txt b/test/sycl/CMakeLists.txt index fc0fdaee..bf7e70d5 100644 --- a/test/sycl/CMakeLists.txt +++ b/test/sycl/CMakeLists.txt @@ -10,7 +10,7 @@ add_library(test-sycl-lib STATIC test_runtime.cpp) target_link_libraries(test-sycl-lib PUBLIC test-lib tinytc_sycl SYCL::SYCL) set_cxx_common_options(test-sycl-lib) -add_executable(test-sycl-smm smm.cpp) -target_link_libraries(test-sycl-smm PRIVATE test-sycl-lib) -doctest_discover_tests(test-sycl-smm) -set_cxx_common_options(test-sycl-smm) +add_executable(test-sycl-linalg linalg.cpp) +target_link_libraries(test-sycl-linalg PRIVATE test-sycl-lib test-lib-linalg) +doctest_discover_tests(test-sycl-linalg) +set_cxx_common_options(test-sycl-linalg) diff --git a/test/sycl/linalg.cpp b/test/sycl/linalg.cpp new file mode 100644 index 00000000..3b3d2a00 --- /dev/null +++ b/test/sycl/linalg.cpp @@ -0,0 +1,9 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "test_runtime.hpp" + +#define RUNTIME_CLASS sycl_test_runtime +#define RUNTIME_NAME "sycl" +#include "../linalg.hpp" +#undef RUNTIME_CLASS diff --git a/test/sycl/smm.cpp b/test/sycl/smm.cpp deleted file mode 100644 index 3c247a34..00000000 --- a/test/sycl/smm.cpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "../smm.hpp" -#include "../tensor3.hpp" -#include "test_runtime.hpp" - -#include "doctest/doctest.h" - -using namespace tinytc; - -TEST_CASE_TEMPLATE("sycl packed alpha=1 beta=0", T, TEST_PRECISIONS) { - auto KK = std::vector{1, 9}; - auto MM = std::vector{1, 13, 33, 65}; - auto NN = std::vector{1, 5, 37}; - auto HH = std::vector{1, 100}; - - std::uint32_t M, N, K, howmany; - DOCTEST_TENSOR4_TEST(MM, NN, KK, HH); - - check_small_gemm_batched(transpose::N, transpose::N, M, N, K, M, M * K, K, - K * N, M, M * N, 1.0, 0.0, howmany); -} - -TEST_CASE_TEMPLATE("sycl non-packed alpha=1 beta=0", T, TEST_PRECISIONS) { - std::uint32_t M = 16, N = 32, K = 8, howmany = 10; - std::uint32_t ldA = 20, ldB = 9, ldC = 24; - - check_small_gemm_batched(transpose::N, transpose::N, M, N, K, ldA, - ldA * ldB, ldB, ldB * 2 * N, ldC, ldC * 3 * N, - 1.0, 0.0, howmany); -} - -TEST_CASE_TEMPLATE("sycl packed alpha=1 beta=1", T, TEST_PRECISIONS) { - std::uint32_t M = 6, N = 33, K = 8, howmany = 5; - - check_small_gemm_batched(transpose::N, transpose::N, M, N, K, M, M * K, K, - K * N, M, M * N, 1.0, 1.0, howmany); -} - -TEST_CASE_TEMPLATE("sycl packed alpha=-1 beta=2", T, TEST_PRECISIONS) { - std::uint32_t M = 8, N = 16, K = 16, howmany = 5; - - check_small_gemm_batched(transpose::N, transpose::N, M, N, K, M, M * K, K, - K * N, M, M * N, -1.0, 2.0, howmany); -} - -TEST_CASE_TEMPLATE("sycl non-packed alpha=1 beta=0 transA transB", T, TEST_PRECISIONS) { - std::uint32_t M = 16, N = 32, K = 8, howmany = 10; - std::uint32_t ldA = 10, ldB = 32, ldC = 24; - - check_small_gemm_batched(transpose::T, transpose::T, M, N, K, ldA, - ldA * ldB, ldB, ldB * 2 * N, ldC, ldC * 3 * N, - 1.0, 0.0, howmany); -} diff --git a/test/sycl/test_runtime.cpp b/test/sycl/test_runtime.cpp index 64f5578f..44f028c5 100644 --- a/test/sycl/test_runtime.cpp +++ b/test/sycl/test_runtime.cpp @@ -3,6 +3,21 @@ #include "test_runtime.hpp" +sycl_test_runtime::sycl_test_runtime() { + switch (q_.get_backend()) { + case sycl::backend::ext_oneapi_level_zero: + arg_handler_ = std::make_unique(); + break; + case sycl::backend::opencl: + arg_handler_ = std::make_unique( + q_.get_device().get_platform()); + break; + default: + throw ::tinytc::status::unsupported_backend; + break; + }; +} + void sycl_test_runtime::memcpy(void *dst, const void *src, std::size_t bytes) { q_.memcpy(dst, src, bytes).wait(); } @@ -30,6 +45,24 @@ auto sycl_test_runtime::get_command_list() -> command_list_t { return q_; } auto sycl_test_runtime::get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t { return tinytc::make_recipe_handler(q_, rec); } +auto sycl_test_runtime::get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t { + return ::tinytc::make_kernel_bundle(q_.get_context(), q_.get_device(), std::move(p)); +} +auto sycl_test_runtime::get_kernel(kernel_bundle_t const &bundle, char const *name) -> kernel_t { + return ::tinytc::make_kernel(bundle, name); +} +void sycl_test_runtime::set_arg(kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value) { + arg_handler_->set_arg(kernel, arg_index, arg_size, arg_value); +} +void sycl_test_runtime::set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, + const void *arg_value, tinytc::mem_type type) { + arg_handler_->set_mem_arg(kernel, arg_index, arg_value, static_cast(type)); +} +void sycl_test_runtime::submit(kernel_t &kernel, std::int64_t howmany) { + auto exe_range = ::tinytc::get_execution_range(kernel, howmany); + q_.submit([&](sycl::handler &h) { h.parallel_for(exe_range, kernel); }); +} void sycl_test_runtime::synchronize() { q_.wait(); } bool sycl_test_runtime::supports_fp64() { return q_.get_device().has(sycl::aspect::fp64); } diff --git a/test/sycl/test_runtime.hpp b/test/sycl/test_runtime.hpp index be8bd7ca..6d795d65 100644 --- a/test/sycl/test_runtime.hpp +++ b/test/sycl/test_runtime.hpp @@ -4,6 +4,8 @@ #ifndef SYCL_TEST_RUNTIME_20240314_HPP #define SYCL_TEST_RUNTIME_20240314_HPP +#include "sycl/argument_handler.hpp" + #include #include @@ -16,9 +18,13 @@ class sycl_test_runtime { using context_t = sycl::context; using command_list_t = sycl::queue; using recipe_handler_t = tinytc::sycl_recipe_handler; + using kernel_bundle_t = sycl::kernel_bundle; + using kernel_t = sycl::kernel; using mem_t = void *; using const_mem_t = const void *; + sycl_test_runtime(); + auto create_buffer(std::size_t bytes) const -> mem_t; void free_buffer(mem_t buf) const; void fill_buffer(mem_t buf, int value, std::size_t bytes); @@ -30,6 +36,13 @@ class sycl_test_runtime { auto get_context() -> context_t; auto get_command_list() -> command_list_t; auto get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t; + auto get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t; + auto get_kernel(kernel_bundle_t const &bundle, char const *name) -> kernel_t; + void set_arg(kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value); + void set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, const void *arg_value, + tinytc::mem_type type); + void submit(kernel_t &kernel, std::int64_t howmany = 1); void synchronize(); bool supports_fp64(); @@ -38,6 +51,7 @@ class sycl_test_runtime { void memcpy(void *dst, void const *src, std::size_t bytes); command_list_t q_; + std::unique_ptr arg_handler_; }; #endif // SYCL_TEST_RUNTIME_20240314_HPP diff --git a/test/tensor3.hpp b/test/tensor3.hpp deleted file mode 100644 index 64e12331..00000000 --- a/test/tensor3.hpp +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef TENSOR3_20240314_HPP -#define TENSOR3_20240314_HPP - -#include "doctest/doctest.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -template class tensor3 { - public: - using real_t = T; - - tensor3(std::array const &shape, std::array const &stride) - : shape_(shape), stride_(stride), data_(size()) {} - - std::size_t size() const { return stride_.back() * shape_.back(); } - auto shape(std::size_t i) const { return shape_[i]; } - auto stride(std::size_t i) const { return stride_[i]; } - - T const &operator()(std::uint32_t i, std::uint32_t j, std::uint32_t k) const { - return data_[i * stride_[0] + j * stride_[1] + k * stride_[2]]; - } - T &operator()(std::uint32_t i, std::uint32_t j, std::uint32_t k) { - return data_[i * stride_[0] + j * stride_[1] + k * stride_[2]]; - } - - T *data() { return data_.data(); } - T const *data() const { return data_.data(); } - - void set_zero() { std::fill(data_.begin(), data_.end(), T(0)); } - - private: - std::array shape_, stride_; - std::vector data_; -}; - -template bool compare(tensor3 const &A, tensor3 const &B) { - bool compatible = - A.shape(0) == B.shape(0) && A.shape(1) == B.shape(1) && A.shape(2) == B.shape(2); - if (!compatible) { - throw std::runtime_error("incompatible compare"); - } - for (std::uint32_t k = 0; k < A.shape(2); ++k) { - for (std::uint32_t j = 0; j < A.shape(1); ++j) { - for (std::uint32_t i = 0; i < A.shape(0); ++i) { - constexpr auto eps = 10.0 * std::numeric_limits::epsilon(); - REQUIRE(std::abs(A(i, j, k) - B(i, j, k)) == doctest::Approx(0.0).epsilon(eps)); - } - } - } - return true; -} - -#endif // TENSOR3_20240314_HPP diff --git a/test/ze/CMakeLists.txt b/test/ze/CMakeLists.txt index 009d73a2..32bfa11e 100644 --- a/test/ze/CMakeLists.txt +++ b/test/ze/CMakeLists.txt @@ -16,8 +16,8 @@ target_link_libraries(test-ze-device PRIVATE test-ze-lib) doctest_discover_tests(test-ze-device) set_cxx_common_options(test-ze-device) -add_executable(test-ze-smm smm.cpp) -target_link_libraries(test-ze-smm PRIVATE test-ze-lib) -doctest_discover_tests(test-ze-smm) -set_cxx_common_options(test-ze-smm) +add_executable(test-ze-linalg linalg.cpp) +target_link_libraries(test-ze-linalg PRIVATE test-ze-lib test-lib-linalg) +doctest_discover_tests(test-ze-linalg) +set_cxx_common_options(test-ze-linalg) diff --git a/test/ze/linalg.cpp b/test/ze/linalg.cpp new file mode 100644 index 00000000..e1c39cd2 --- /dev/null +++ b/test/ze/linalg.cpp @@ -0,0 +1,9 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "test_runtime.hpp" + +#define RUNTIME_CLASS level_zero_test_runtime +#define RUNTIME_NAME "level zero" +#include "../linalg.hpp" +#undef RUNTIME_CLASS diff --git a/test/ze/smm.cpp b/test/ze/smm.cpp deleted file mode 100644 index 0a945ebc..00000000 --- a/test/ze/smm.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "../smm.hpp" -#include "test_runtime.hpp" - -#include "doctest/doctest.h" - -#include -#include - -using namespace tinytc; - -TEST_CASE_TEMPLATE("level zero packed alpha=1 beta=0", T, TEST_PRECISIONS) { - auto KK = std::vector{56}; - auto MM = std::vector{20, 53}; - auto NN = std::vector{5, 23}; - auto HH = std::vector{1, 101}; - - std::uint32_t M, N, K, howmany; - DOCTEST_TENSOR4_TEST(MM, NN, KK, HH); - - check_small_gemm_batched( - transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, 1.0, 0.0, howmany); -} - -TEST_CASE_TEMPLATE("level zero packed complex alpha=1 beta=0", T, TEST_PRECISIONS) { - auto KK = std::vector{53}; - auto MM = std::vector{21, 42}; - auto NN = std::vector{7, 11}; - auto HH = std::vector{1, 101}; - - std::uint32_t M, N, K, howmany; - DOCTEST_TENSOR4_TEST(MM, NN, KK, HH); - - check_small_gemm_batched, level_zero_test_runtime>( - transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, 1.0, 0.0, howmany); -} - -TEST_CASE_TEMPLATE("level zero packed complex alpha=(-1,-2) beta=(2,3)", T, TEST_PRECISIONS) { - std::uint32_t M = 8, N = 16, K = 16, howmany = 5; - - check_small_gemm_batched, level_zero_test_runtime>( - transpose::N, transpose::N, M, N, K, M, M * K, K, K * N, M, M * N, {-1.0, -2.0}, {2.0, 3.0}, - howmany); -} diff --git a/test/ze/test_runtime.cpp b/test/ze/test_runtime.cpp index a2671048..13a43e76 100644 --- a/test/ze/test_runtime.cpp +++ b/test/ze/test_runtime.cpp @@ -66,6 +66,26 @@ auto level_zero_test_runtime::get_command_list() -> command_list_t { return list auto level_zero_test_runtime::get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t { return tinytc::make_recipe_handler(ctx_, dev_, rec); } +auto level_zero_test_runtime::get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t { + return ::tinytc::make_kernel_bundle(ctx_, dev_, std::move(p)); +} +auto level_zero_test_runtime::get_kernel(kernel_bundle_t const &bundle, + char const *name) -> kernel_t { + return ::tinytc::make_kernel(bundle.get(), name); +} +void level_zero_test_runtime::set_arg(kernel_t &kernel, std::uint32_t arg_index, + std::size_t arg_size, const void *arg_value) { + ZE_CHECK_STATUS(zeKernelSetArgumentValue(kernel.get(), arg_index, arg_size, arg_value)); +} +void level_zero_test_runtime::set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, + const void *arg_value, tinytc::mem_type) { + set_arg(kernel, arg_index, sizeof(arg_value), &arg_value); +} +void level_zero_test_runtime::submit(kernel_t &kernel, std::int64_t howmany) { + auto group_count = ::tinytc::get_group_count(howmany); + ZE_CHECK_STATUS( + zeCommandListAppendLaunchKernel(list_, kernel.get(), &group_count, nullptr, 0, nullptr)); +} void level_zero_test_runtime::synchronize() { ZE_CHECK_STATUS(zeCommandListHostSynchronize(list_, UINT64_MAX)); } diff --git a/test/ze/test_runtime.hpp b/test/ze/test_runtime.hpp index 3ffb7aa4..23cd7ca5 100644 --- a/test/ze/test_runtime.hpp +++ b/test/ze/test_runtime.hpp @@ -16,6 +16,8 @@ class level_zero_test_runtime { using context_t = ze_context_handle_t; using command_list_t = ze_command_list_handle_t; using recipe_handler_t = tinytc::level_zero_recipe_handler; + using kernel_bundle_t = tinytc::unique_handle; + using kernel_t = tinytc::unique_handle; using mem_t = void *; using const_mem_t = const void *; @@ -38,6 +40,13 @@ class level_zero_test_runtime { auto get_context() -> context_t; auto get_command_list() -> command_list_t; auto get_recipe_handler(tinytc::recipe const &rec) -> recipe_handler_t; + auto get_kernel_bundle(tinytc::prog p) -> kernel_bundle_t; + auto get_kernel(kernel_bundle_t const &bundle, char const *name) -> kernel_t; + void set_arg(kernel_t &kernel, std::uint32_t arg_index, std::size_t arg_size, + const void *arg_value); + void set_mem_arg(kernel_t &kernel, std::uint32_t arg_index, const void *arg_value, + tinytc::mem_type type); + void submit(kernel_t &kernel, std::int64_t howmany = 1); void synchronize(); bool supports_fp64(); From 926be88a2e62e5162f31e73f2fd9e0a848d81763 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 23 Oct 2024 17:35:10 +0200 Subject: [PATCH 069/297] More tests Signed-off-by: Carsten Uphoff --- src/pass/convert_to_opencl.cpp | 2 +- test/cl/CMakeLists.txt | 9 +-- test/doctest_util.hpp | 7 +++ test/linalg.hpp | 21 +++++++ test/linalg_ops.cpp | 35 +++++++++-- test/linalg_ops.hpp | 107 +++++++++++++++++++++++++++++---- 6 files changed, 160 insertions(+), 21 deletions(-) diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index d5679654..3389659b 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1141,7 +1141,7 @@ std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { n < std::move(trip_count), ++n) .body([&](clir::block_builder &bb) { auto b = bb.declare_assign(to_clir_ty(bt->element_ty()), "b", - B + (block + n) * bdv.stride(0)); + B[(block + n) * bdv.stride(0)]); auto Cb = bb.declare_assign(this->operator()(*ct), "Cb", C + (block + n) * cdv.stride(1)); auto m = bb.declare_assign(clir::generic_uint(), "m", diff --git a/test/cl/CMakeLists.txt b/test/cl/CMakeLists.txt index ebcd7b00..9fb43525 100644 --- a/test/cl/CMakeLists.txt +++ b/test/cl/CMakeLists.txt @@ -17,7 +17,8 @@ target_link_libraries(test-cl-device PRIVATE test-cl-lib) doctest_discover_tests(test-cl-device) set_cxx_common_options(test-cl-lib) -add_executable(test-cl-smm smm.cpp) -target_link_libraries(test-cl-smm PRIVATE test-cl-lib) -doctest_discover_tests(test-cl-smm) -set_cxx_common_options(test-cl-lib) +add_executable(test-cl-linalg linalg.cpp) +target_link_libraries(test-cl-linalg PRIVATE test-cl-lib test-lib-linalg) +doctest_discover_tests(test-cl-linalg) +set_cxx_common_options(test-cl-linalg) + diff --git a/test/doctest_util.hpp b/test/doctest_util.hpp index c9edf215..46b5baa9 100644 --- a/test/doctest_util.hpp +++ b/test/doctest_util.hpp @@ -4,6 +4,13 @@ #ifndef DOCTEST_UTIL_20241023_HPP #define DOCTEST_UTIL_20241023_HPP +#define DOCTEST_TENSOR1_TEST(MM) \ + do { \ + for (auto mm : MM) { \ + DOCTEST_SUBCASE((std::to_string(mm)).c_str()) { M = mm; } \ + } \ + } while (false) + #define DOCTEST_TENSOR2_TEST(MM, NN) \ do { \ for (auto nn : NN) { \ diff --git a/test/linalg.hpp b/test/linalg.hpp index 19e75629..029f3f98 100644 --- a/test/linalg.hpp +++ b/test/linalg.hpp @@ -92,3 +92,24 @@ TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed complex alpha=(-1,-2) beta=(2,3)", test::gemm(transpose::N, transpose::N, {{M, K}}, {{K, N}}, {{M, N}}); test::test_blas_a3(op, {-1.0, -2.0}, {2.0, 3.0}); } + +TEST_CASE_TEMPLATE(RUNTIME_NAME " ger packed alpha=1 beta=0", T, TEST_PRECISIONS) { + auto MM = std::vector{10, 32, 45}; + auto NN = std::vector{1, 16, 17, 48}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::ger({{M}}, {{N}}, {{M, N}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " hadamard packed alpha=1 beta=0", T, TEST_PRECISIONS) { + auto MM = std::vector{10, 32, 45}; + + std::int64_t M; + DOCTEST_TENSOR1_TEST(MM); + + auto op = test::hadamard({{M}}, {{M}}, {{M}}); + test::test_blas_a3(op, 1, 0); +} diff --git a/test/linalg_ops.cpp b/test/linalg_ops.cpp index ddbec03d..12348840 100644 --- a/test/linalg_ops.cpp +++ b/test/linalg_ops.cpp @@ -24,10 +24,35 @@ auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout return {M, N, K}; } -auto make_gemm_prog(char const *name, transpose tA, transpose tB, tensor_layout const &layoutA, - tensor_layout const &layoutB, tensor_layout const &layoutC, - scalar_type alpha_ty, scalar_type A_ty, scalar_type B_ty, scalar_type beta_ty, - scalar_type C_ty) -> prog { +auto ger_mn(tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array { + if (A.dim() != 1 || B.dim() != 1 || C.dim() != 2) { + throw std::runtime_error("expected vectors and matrix"); + } + const auto M = C.shape(0); + const auto N = C.shape(1); + if (M != A.shape(0) || N != B.shape(0)) { + throw std::runtime_error("incompatible ger"); + } + return {M, N}; +} + +auto hadamard_m(tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::int64_t { + if (A.dim() != 1 || B.dim() != 1 || C.dim() != 1) { + throw std::runtime_error("expected vectors"); + } + const auto M = C.shape(0); + if (M != A.shape(0) || M != B.shape(0)) { + throw std::runtime_error("incompatible hadamard"); + } + return M; +} + +auto make_blas_a3_prog(char const *name, tensor_layout const &layoutA, tensor_layout const &layoutB, + tensor_layout const &layoutC, scalar_type alpha_ty, scalar_type A_ty, + scalar_type B_ty, scalar_type beta_ty, scalar_type C_ty, + std::function)> make_op) -> prog { auto ctx = make_compiler_context(); auto const alphat = get_scalar(ctx, alpha_ty); @@ -57,7 +82,7 @@ auto make_gemm_prog(char const *name, transpose tA, transpose tB, tensor_layout auto bb = region_builder{fn_body}; - bb.add(make_gemm(tA, tB, false, params[0], params[1], params[2], params[3], params[4])); + make_op(bb, params); p.add_function(std::move(f)); diff --git a/test/linalg_ops.hpp b/test/linalg_ops.hpp index e4612341..2d859d69 100644 --- a/test/linalg_ops.hpp +++ b/test/linalg_ops.hpp @@ -11,10 +11,23 @@ #include #include #include +#include #include namespace tinytc::test { +auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array; +auto ger_mn(tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array; +auto hadamard_m(tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::int64_t; + +auto make_blas_a3_prog(char const *name, tensor_layout const &layoutA, tensor_layout const &layoutB, + tensor_layout const &layoutC, scalar_type alpha_ty, scalar_type A_ty, + scalar_type B_ty, scalar_type beta_ty, scalar_type C_ty, + std::function)> make_op) -> prog; + template concept op_blas_a3 = requires(T op, typename T::alpha_type alpha, typename T::beta_type beta, typename T::A_type const *A_ref, typename T::B_type const *B_ref, @@ -32,14 +45,6 @@ concept op_blas_a3 = requires(T op, typename T::alpha_type alpha, typename T::be op.reference_impl(alpha, A_ref, B_ref, beta, C_ref); }; -auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout const &B, - tensor_layout const &C) -> std::array; - -auto make_gemm_prog(char const *name, transpose tA, transpose tB, tensor_layout const &layoutA, - tensor_layout const &layoutB, tensor_layout const &layoutC, - scalar_type alpha_ty, scalar_type A_ty, scalar_type B_ty, scalar_type beta_ty, - scalar_type C_ty) -> prog; - template class gemm { public: using alpha_type = AlphaT; @@ -59,9 +64,13 @@ template tensor_layout const & { return lC_; } auto make_prog() const -> prog { - return make_gemm_prog(kernel_name, tA_, tB_, lA_, lB_, lC_, to_scalar_type_v, - to_scalar_type_v, to_scalar_type_v, to_scalar_type_v, - to_scalar_type_v); + return make_blas_a3_prog(kernel_name, lA_, lB_, lC_, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_gemm(tA_, tB_, false, params[0], params[1], + params[2], params[3], params[4])); + }); } void reference_impl(AlphaT alpha, AT const *A, BT const *B, BetaT beta, CT *C) { const auto [M, N, K] = gemm_mnk(tA_, tB_, lA_, lB_, lC_); @@ -90,6 +99,82 @@ template class ger { + public: + using alpha_type = AlphaT; + using A_type = AT; + using B_type = BT; + using beta_type = BetaT; + using C_type = CT; + static constexpr char const *kernel_name = "ger"; + + ger(tensor_layout layoutA, tensor_layout layoutB, tensor_layout layoutC) + : lA_{std::move(layoutA)}, lB_{std::move(layoutB)}, lC_{std::move(layoutC)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + auto lC() const -> tensor_layout const & { return lC_; } + + auto make_prog() const -> prog { + return make_blas_a3_prog( + kernel_name, lA_, lB_, lC_, to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_ger(false, params[0], params[1], params[2], params[3], params[4])); + }); + } + void reference_impl(AlphaT alpha, AT const *A, BT const *B, BetaT beta, CT *C) { + const auto [M, N] = ger_mn(lA_, lB_, lC_); + for (std::int64_t n = 0; n < N; ++n) { + for (std::int64_t m = 0; m < M; ++m) { + auto ab = A[lA_.linear_index({m})] * B[lB_.linear_index({n})]; + auto &c = C[lC_.linear_index({m, n})]; + c = alpha * ab + beta * c; + } + } + } + + private: + tensor_layout lA_, lB_, lC_; +}; + +template class hadamard { + public: + using alpha_type = AlphaT; + using A_type = AT; + using B_type = BT; + using beta_type = BetaT; + using C_type = CT; + static constexpr char const *kernel_name = "hadamard"; + + hadamard(tensor_layout layoutA, tensor_layout layoutB, tensor_layout layoutC) + : lA_{std::move(layoutA)}, lB_{std::move(layoutB)}, lC_{std::move(layoutC)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + auto lC() const -> tensor_layout const & { return lC_; } + + auto make_prog() const -> prog { + return make_blas_a3_prog( + kernel_name, lA_, lB_, lC_, to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_hadamard(false, params[0], params[1], params[2], params[3], params[4])); + }); + } + void reference_impl(AlphaT alpha, AT const *A, BT const *B, BetaT beta, CT *C) { + const auto M = hadamard_m(lA_, lB_, lC_); + for (std::int64_t m = 0; m < M; ++m) { + auto ab = A[lA_.linear_index({m})] * B[lB_.linear_index({m})]; + auto &c = C[lC_.linear_index({m})]; + c = alpha * ab + beta * c; + } + } + + private: + tensor_layout lA_, lB_, lC_; +}; + } // namespace tinytc::test #endif // LINALG_OPS_20241023_HPP From 317bf3bd3b3b4ef09ff4f9ff214da877e857d005 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 24 Oct 2024 17:01:23 +0200 Subject: [PATCH 070/297] Bugfix Signed-off-by: Carsten Uphoff --- src/node/inst_node.hpp | 2 +- src/node/region_node.cpp | 8 +++++++- src/node/value_node.cpp | 6 ++++++ src/node/value_node.hpp | 1 + src/pass/constant_folding.hpp | 13 +++++++++++-- 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 463e2c60..9f058298 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -186,7 +186,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent return child_regions_end_ - child_regions_begin_; } - inline constexpr auto kind() const -> tinytc::inst_execution_kind { + inline auto kind() const -> tinytc::inst_execution_kind { switch (type_id()) { case tinytc::IK::alloca: case tinytc::IK::barrier: diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index 31266172..a5b42aad 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -29,7 +29,13 @@ tinytc_region::tinytc_region(array_view param_types, locatio set_params(std::move(param_types), lc); } -tinytc_region::~tinytc_region() {} +tinytc_region::~tinytc_region() { + // Erase instructions in reverse order such that we delete value use before value definition + auto prev_it = insts_.end(); + while (prev_it != insts_.begin()) { + prev_it = insts_.erase(--prev_it); + } +} void tinytc_region::set_params(array_view param_types, location const &lc) { params_.resize(param_types.size()); diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp index a859d0dd..35941bad 100644 --- a/src/node/value_node.cpp +++ b/src/node/value_node.cpp @@ -3,11 +3,17 @@ #include "node/value_node.hpp" +#include + using namespace tinytc; tinytc_value::tinytc_value(tinytc_data_type_t ty, tinytc_inst_t def_inst, location const &lc) : ty_{std::move(ty)}, loc_{lc}, def_inst_{def_inst} {} +tinytc_value::~tinytc_value() { + assert(!has_uses() && "Destructor called for value that still has uses"); +} + auto tinytc_value::use_begin() -> use_iterator { return {first_use_}; } auto tinytc_value::use_end() -> use_iterator { return {nullptr}; } auto tinytc_value::uses() -> iterator_range_wrapper { diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index bb6348fe..c1b8f9ce 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -25,6 +25,7 @@ struct tinytc_value final { public: tinytc_value(tinytc_data_type_t ty = nullptr, tinytc_inst_t def_inst_ = nullptr, tinytc::location const &lc = {}); + ~tinytc_value(); tinytc_value(tinytc_value const &) = delete; tinytc_value(tinytc_value &&) = default; diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp index baee354c..be478bcc 100644 --- a/src/pass/constant_folding.hpp +++ b/src/pass/constant_folding.hpp @@ -58,7 +58,11 @@ struct compute_unary_op { T val = 0; switch (operation) { case arithmetic_unary::abs: - val = a < 0 ? -a : a; + if constexpr (std::is_same_v) { + val = a; + } else { + val = a < 0 ? -a : a; + } break; case arithmetic_unary::neg: val = -a; @@ -163,7 +167,11 @@ struct compute_binary_op { val = a - b; break; case arithmetic::mul: - val = a * b; + if constexpr (std::is_same_v) { + val = a && b; + } else { + val = a * b; + } break; case arithmetic::div: val = a / b; @@ -282,6 +290,7 @@ struct compute_binop_identities { return &operand; } } + break; case arithmetic::and_: if (a == T{0}) { return make_constant(T{0}, operand.ty(), loc); From e414ad55bbaa49b8d52af55dd5110d45711593ef Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 24 Oct 2024 18:23:48 +0200 Subject: [PATCH 071/297] More bugfix Signed-off-by: Carsten Uphoff --- src/pass/convert_to_opencl.cpp | 35 ++++++++++++++++++------------ src/pass/dead_code_elimination.cpp | 4 ++-- src/pass/lower_linalg.cpp | 28 +++++++++++++++--------- test/linalg.hpp | 18 +++++++++++++++ test/opt/dead-code-elimination.ir | 22 +++++++++++++++++++ 5 files changed, 81 insertions(+), 26 deletions(-) diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 3389659b..cd0d603b 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -547,10 +547,14 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo auto rt = get_coopmatrix_type(c.result(0)); auto &dv = get_dope_vector(c.operand()); - const bool check_rows = c.checked() == checked_flag::rows || c.checked() == checked_flag::both; - const bool check_cols = c.checked() == checked_flag::cols || c.checked() == checked_flag::both; const int rmode = rt->distributed_mode(); const int omode = c.t() == transpose::T ? 1 - rmode : rmode; + const bool check_m = c.checked() == checked_flag::both || + (rmode == 0 && c.checked() == checked_flag::rows) || + (rmode == 1 && c.checked() == checked_flag::cols); + const bool check_k = c.checked() == checked_flag::both || + (rmode == 1 && c.checked() == checked_flag::rows) || + (rmode == 0 && c.checked() == checked_flag::cols); const bool enable_sub_group_reads = core_cfg_.block_read_write_supported && c.t() == transpose::N && ot->stride(omode) == 1; @@ -565,7 +569,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo declaration_assignment(visit(*this, *c.operand().ty()), pointer, val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); clir::var rem[2] = {}; - if (check_rows || check_cols) { + if (check_m || check_k) { clinst.emplace_back( declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); clinst.emplace_back( @@ -575,7 +579,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); for (std::int64_t block = 0; block < num_blocks; ++block) { auto row_in_bounds = clir::var{}; - if (check_rows) { + if (check_m) { auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; clinst.emplace_back(declaration_assignment(to_clir_ty(scalar_type::i1), row_in_bounds, m >= -pv[omode] && m < rem[omode])); @@ -589,11 +593,11 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo }; auto const remainder = rt->shape(rmode) - core_cfg_.subgroup_size * block; const bool needs_mask = remainder < core_cfg_.subgroup_size; - if (enable_sub_group_reads && !needs_mask && !check_rows) { + if (enable_sub_group_reads && !needs_mask && !check_m) { auto rhs = sub_group_block_read_helper( pointer + block * core_cfg_.subgroup_size + k * ot->stride(1), ot->element_ty(), to_clir_address_space(ot->addrspace())); - if (check_cols) { + if (check_k) { rhs = ternary_conditional(col_cond(), std::move(rhs), 0); } clinst.emplace_back(store(std::move(rhs))); @@ -602,10 +606,10 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo block * core_cfg_.subgroup_size) + k * ot->stride(1 - omode)]; clir::expr cond = {}; - if (check_rows) { + if (check_m) { cond = row_in_bounds; } - if (check_cols) { + if (check_k) { cond = cond ? cond && col_cond() : col_cond(); } if (needs_mask) { @@ -760,11 +764,14 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st auto &dv = get_dope_vector(c.operand()); auto valv = val(c.val()); - const bool check_rows = c.checked() == checked_flag::rows || c.checked() == checked_flag::both; - const bool check_cols = c.checked() == checked_flag::cols || c.checked() == checked_flag::both; - const int vmode = vt->distributed_mode(); const int omode = vmode; + const bool check_m = c.checked() == checked_flag::both || + (vmode == 0 && c.checked() == checked_flag::rows) || + (vmode == 1 && c.checked() == checked_flag::cols); + const bool check_k = c.checked() == checked_flag::both || + (vmode == 1 && c.checked() == checked_flag::rows) || + (vmode == 0 && c.checked() == checked_flag::cols); auto clinst = std::vector{}; auto const len = vt->length(core_cfg_.subgroup_size); @@ -776,7 +783,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st declaration_assignment(visit(*this, *c.operand().ty()), pointer, val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); clir::var rem[2] = {}; - if (check_rows || check_cols) { + if (check_m || check_k) { clinst.emplace_back( declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); clinst.emplace_back( @@ -819,7 +826,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st k * ot->stride(1 - omode); auto rhs = valv[k + block * vt->shape(1 - vmode)]; clir::expr cond = {}; - if (check_cols) { + if (check_k) { cond = k >= -pv[1 - omode] && k < rem[1 - omode]; } if (needs_mask) { @@ -837,7 +844,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st } } - if (check_rows) { + if (check_m) { auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; clinst.emplace_back(clir::if_selection_builder(m >= -pv[omode] && m < rem[omode]) .then([&](clir::block_builder &bb) { diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp index 17775b8e..d0d51bd6 100644 --- a/src/pass/dead_code_elimination.cpp +++ b/src/pass/dead_code_elimination.cpp @@ -42,7 +42,7 @@ auto dead_code_analysis::operator()(inst_node &in) -> bool { auto dead_code_analysis::operator()(if_inst &in) -> bool { constant_inst *cond_const = dyn_cast(in.condition().defining_inst()); - if (cond_const) { + if (in.num_results() == 0 && cond_const) { // If-instruction is dead if condition is constant and false return std::holds_alternative(cond_const->value()) && std::get(cond_const->value()) == 0; @@ -54,7 +54,7 @@ auto dead_code_analysis::operator()(if_inst &in) -> bool { auto dead_code_analysis::operator()(loop_inst &in) -> bool { constant_inst *from_const = dyn_cast(in.from().defining_inst()); constant_inst *to_const = dyn_cast(in.to().defining_inst()); - if (from_const && to_const) { + if (in.num_results() == 0 && from_const && to_const) { // For-instruction is dead if from >= to return std::holds_alternative(from_const->value()) && std::holds_alternative(to_const->value()) && diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 74c8a909..47aff94c 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -93,17 +93,25 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto c_zero = bb.add(make_constant_zero(index_ty, loc)); auto c_k_block_size = bb.add(make_constant(k_block_size, index_ty, loc)); - auto tmp = bb.add(make_arith(arithmetic::div, K, c_k_block_size, loc)); - auto K0 = bb.add(make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); + auto tmp = instant_constant_fold_add(bb, make_arith(arithmetic::div, K, c_k_block_size, loc)); + auto K0 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); c_init = compute_c(bb, k_block_size, c_zero, K0, c_init); - auto needs_remainder = bb.add(make_cmp(cmp_condition::lt, K0, K, loc)); - bb.if_condition( - needs_remainder, - [&](region_builder &bb) { - auto c_next = compute_c(bb, 1, K0, K, c_init); - bb.add(make_yield(c_next, loc)); - }, - {coopmatrix_c_ty}, loc); + auto needs_remainder = instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, loc)); + auto r = get_int_constant(needs_remainder); + if (r) { + if (*r != 0) { + c_init = compute_c(bb, 1, K0, K, c_init); + } + } else { + auto remainder = bb.if_condition( + needs_remainder, + [&](region_builder &bb) { + auto c_next = compute_c(bb, 1, K0, K, c_init); + bb.add(make_yield(c_next, loc)); + }, + {coopmatrix_c_ty}, loc); + c_init = remainder[0]; + } auto alpha_ab = mixed_precision_coopmatrix_scale(bb, alpha, c_init, loc); if (atomic) { diff --git a/test/linalg.hpp b/test/linalg.hpp index 029f3f98..a19e1193 100644 --- a/test/linalg.hpp +++ b/test/linalg.hpp @@ -59,6 +59,24 @@ TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-packed alpha=1 beta=0 transA transB", test::test_blas_a3(op, 1, 0); } +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-static M", T, TEST_PRECISIONS) { + std::int64_t M = 63, N = 43, K = 23; + + auto op = test::gemm( + transpose::N, transpose::N, {{M, K}, {1, M}, {dynamic, K}, {1, dynamic}}, {{K, N}, {1, K}}, + {{M, N}, {1, M}, {dynamic, N}, {1, dynamic}}); + test::test_blas_a3(op, 1, 1); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-static N", T, TEST_PRECISIONS) { + std::int64_t M = 63, N = 43, K = 23; + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}, {1, M}}, + {{K, N}, {1, K}, {K, dynamic}, {1, K}}, + {{M, N}, {1, M}, {M, dynamic}, {1, M}}); + test::test_blas_a3(op, 1, 1); +} + TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm non-static", T, TEST_PRECISIONS) { std::int64_t M = 63, N = 43, K = 23; diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index 5f63d6d1..a7e66f66 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -21,6 +21,28 @@ func @dead_if(%a: memref) { ; CHECK-NEXT: } } +func @dead_if_with_yield(%a: memref) { + %c0 = constant 0 -> i1 + %0 = if %c0 -> (f64) { + %c42 = constant 42.0 -> f64 + yield %c42 : f64 + } else { + %c43 = constant 43.0 -> f64 + yield %c43 : f64 + } + store %0, %a[] : memref +; Cannot eliminate if that returns results currently +; CHECK-LABEL: func @dead_if_with_yield({{.*}} +; CHECK: %0 = if %c0 { +; CHECK-NEXT: %c42 = constant 0x1.5p+5 -> f64 +; CHECK-NEXT: yield %c42 : f64 +; CHECK-NEXT: } else { +; CHECK-NEXT: %c43 = constant 0x1.58p+5 -> f64 +; CHECK-NEXT: yield %c43 : f64 +; CHECK-NEXT: } +; CHECK-NEXT: store %0, %a[] : memref +} + func @dead_loop(%a: memref) { %c2 = constant 2 -> index for %0=%c2,%c2 { From 4fd33d8a1e9626e994723e5de9d47e8f096aa147 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 24 Oct 2024 20:26:30 +0200 Subject: [PATCH 072/297] More bugfixes Signed-off-by: Carsten Uphoff --- src/pass/convert_to_opencl.cpp | 26 ++--- test/codegen/axpby1.ir | 4 +- test/codegen/coopmatrix_load.ir | 176 ++++++++++++++--------------- test/codegen/coopmatrix_mul_add.ir | 32 +++--- test/codegen/coopmatrix_store.ir | 24 ++-- 5 files changed, 131 insertions(+), 131 deletions(-) diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index cd0d603b..48ece310 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -545,7 +545,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo auto lhs_ty = visit(*this, *c.result(0).ty()); auto ot = get_memref_type(c.operand()); auto rt = get_coopmatrix_type(c.result(0)); - auto &dv = get_dope_vector(c.operand()); + auto &odv = get_dope_vector(c.operand()); const int rmode = rt->distributed_mode(); const int omode = c.t() == transpose::T ? 1 - rmode : rmode; @@ -567,13 +567,13 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo auto pointer = clir::var{}; clinst.emplace_back( declaration_assignment(visit(*this, *c.operand().ty()), pointer, - val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); + val(c.operand()) + pv[0] * odv.stride(0) + pv[1] * odv.stride(1))); clir::var rem[2] = {}; if (check_m || check_k) { clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); + declaration_assignment(to_clir_ty(scalar_type::index), rem[0], odv.shape(0) - pv[0])); clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[1], dv.shape(1) - pv[1])); + declaration_assignment(to_clir_ty(scalar_type::index), rem[1], odv.shape(1) - pv[1])); } const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); @@ -595,16 +595,16 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo const bool needs_mask = remainder < core_cfg_.subgroup_size; if (enable_sub_group_reads && !needs_mask && !check_m) { auto rhs = sub_group_block_read_helper( - pointer + block * core_cfg_.subgroup_size + k * ot->stride(1), ot->element_ty(), + pointer + block * core_cfg_.subgroup_size + k * odv.stride(1), ot->element_ty(), to_clir_address_space(ot->addrspace())); if (check_k) { rhs = ternary_conditional(col_cond(), std::move(rhs), 0); } clinst.emplace_back(store(std::move(rhs))); } else { - auto rhs = pointer[ot->stride(omode) * (clir::get_sub_group_local_id() + + auto rhs = pointer[odv.stride(omode) * (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + - k * ot->stride(1 - omode)]; + k * odv.stride(1 - omode)]; clir::expr cond = {}; if (check_m) { cond = row_in_bounds; @@ -761,7 +761,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_sc std::vector convert_to_opencl_pass::operator()(cooperative_matrix_store_inst const &c) { auto ot = get_memref_type(c.operand()); auto vt = get_coopmatrix_type(c.val()); - auto &dv = get_dope_vector(c.operand()); + auto &odv = get_dope_vector(c.operand()); auto valv = val(c.val()); const int vmode = vt->distributed_mode(); @@ -781,13 +781,13 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st auto pointer = clir::var{}; clinst.emplace_back( declaration_assignment(visit(*this, *c.operand().ty()), pointer, - val(c.operand()) + pv[0] * dv.stride(0) + pv[1] * dv.stride(1))); + val(c.operand()) + pv[0] * odv.stride(0) + pv[1] * odv.stride(1))); clir::var rem[2] = {}; if (check_m || check_k) { clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[0], dv.shape(0) - pv[0])); + declaration_assignment(to_clir_ty(scalar_type::index), rem[0], odv.shape(0) - pv[0])); clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[1], dv.shape(1) - pv[1])); + declaration_assignment(to_clir_ty(scalar_type::index), rem[1], odv.shape(1) - pv[1])); } auto atomic_pointer = @@ -821,9 +821,9 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st auto const remainder = vt->shape(vmode) - core_cfg_.subgroup_size * block; const bool needs_mask = remainder < core_cfg_.subgroup_size; - auto offset = ot->stride(omode) * + auto offset = odv.stride(omode) * (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + - k * ot->stride(1 - omode); + k * odv.stride(1 - omode); auto rhs = valv[k + block * vt->shape(1 - vmode)]; clir::expr cond = {}; if (check_k) { diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir index 4ff3517e..9ebd146c 100644 --- a/test/codegen/axpby1.ir +++ b/test/codegen/axpby1.ir @@ -7,9 +7,9 @@ func @axpby0(%alpha: f32, %A: memref, %B: memref) { axpby.n %alpha, %A, %z, %B : f32, memref, f32, memref } -func @axpby1(%alpha: f64, %A: memref>, %B: memref) { +func @axpby1(%alpha: f32, %A: memref>, %B: memref) { %z = constant 0.0 -> f32 - axpby.n %alpha, %A, %z, %B : f64, memref>, f32, memref + axpby.n %alpha, %A, %z, %B : f32, memref>, f32, memref } func @axpby2(%alpha: f32, %A: memref, %B: memref) { diff --git a/test/codegen/coopmatrix_load.ir b/test/codegen/coopmatrix_load.ir index 4feddee4..d0b51885 100644 --- a/test/codegen/coopmatrix_load.ir +++ b/test/codegen/coopmatrix_load.ir @@ -7,14 +7,14 @@ func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_ ; CHECK-LABEL: void coopmatrix_a_load_n({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0))); -; CHECK-NEXT: x1[1] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 64))); -; CHECK-NEXT: x1[2] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 128))); -; CHECK-NEXT: x1[3] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 192))); -; CHECK-NEXT: x1[4] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 256))); -; CHECK-NEXT: x1[5] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 320))); -; CHECK-NEXT: x1[6] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 384))); -; CHECK-NEXT: x1[7] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 448))); +; CHECK-NEXT: x1[0] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0 * 64))); +; CHECK-NEXT: x1[1] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 1 * 64))); +; CHECK-NEXT: x1[2] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 2 * 64))); +; CHECK-NEXT: x1[3] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 3 * 64))); +; CHECK-NEXT: x1[4] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 4 * 64))); +; CHECK-NEXT: x1[5] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 5 * 64))); +; CHECK-NEXT: x1[6] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 6 * 64))); +; CHECK-NEXT: x1[7] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 7 * 64))); } func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -25,11 +25,11 @@ func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: ind ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; +; CHECK-NEXT: x1[0] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; +; CHECK-NEXT: x1[1] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; ; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; -; CHECK-NEXT: x1[2] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; -; CHECK-NEXT: x1[3] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; +; CHECK-NEXT: x1[2] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 0 * 64] : 0; +; CHECK-NEXT: x1[3] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 1 * 64] : 0; } func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -39,10 +39,10 @@ func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: ind ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: x1[0] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0))) : 0; -; CHECK-NEXT: x1[1] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 64))) : 0; -; CHECK-NEXT: x1[2] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 0))) : 0; -; CHECK-NEXT: x1[3] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 64))) : 0; +; CHECK-NEXT: x1[0] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0 * 64))) : 0; +; CHECK-NEXT: x1[1] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 1 * 64))) : 0; +; CHECK-NEXT: x1[2] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 0 * 64))) : 0; +; CHECK-NEXT: x1[3] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 1 * 64))) : 0; } func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -53,23 +53,23 @@ func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) s ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 2 * 64] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 3 * 64] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 4 * 64] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 5 * 64] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 6 * 64] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 7 * 64] : 0; ; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; -; CHECK-NEXT: x1[8] = x6 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 0] : 0; -; CHECK-NEXT: x1[9] = x6 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 64] : 0; -; CHECK-NEXT: x1[10] = x6 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 128] : 0; -; CHECK-NEXT: x1[11] = x6 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 192] : 0; -; CHECK-NEXT: x1[12] = x6 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 256] : 0; -; CHECK-NEXT: x1[13] = x6 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 320] : 0; -; CHECK-NEXT: x1[14] = x6 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 384] : 0; -; CHECK-NEXT: x1[15] = x6 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 448] : 0; +; CHECK-NEXT: x1[8] = x6 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 0 * 64] : 0; +; CHECK-NEXT: x1[9] = x6 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 1 * 64] : 0; +; CHECK-NEXT: x1[10] = x6 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 2 * 64] : 0; +; CHECK-NEXT: x1[11] = x6 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 3 * 64] : 0; +; CHECK-NEXT: x1[12] = x6 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 4 * 64] : 0; +; CHECK-NEXT: x1[13] = x6 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 5 * 64] : 0; +; CHECK-NEXT: x1[14] = x6 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 6 * 64] : 0; +; CHECK-NEXT: x1[15] = x6 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 7 * 64] : 0; } func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -77,14 +77,14 @@ func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_ ; CHECK-LABEL: void coopmatrix_a_load_t({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0]; -; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1]; -; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2]; -; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3]; -; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4]; -; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5]; -; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6]; -; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7]; +; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0 * 1]; +; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1 * 1]; +; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2 * 1]; +; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3 * 1]; +; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4 * 1]; +; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5 * 1]; +; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6 * 1]; +; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7 * 1]; } func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -95,14 +95,14 @@ func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) s ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0 * 1] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1 * 1] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2 * 1] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3 * 1] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4 * 1] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5 * 1] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6 * 1] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7 * 1] : 0; } func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -110,14 +110,14 @@ func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_ ; CHECK-LABEL: void coopmatrix_b_load_n({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0]; -; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1]; -; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2]; -; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3]; -; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4]; -; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5]; -; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6]; -; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7]; +; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0 * 1]; +; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1 * 1]; +; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2 * 1]; +; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3 * 1]; +; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4 * 1]; +; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5 * 1]; +; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6 * 1]; +; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7 * 1]; } func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -128,23 +128,23 @@ func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) s ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0 * 1] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1 * 1] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2 * 1] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3 * 1] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4 * 1] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5 * 1] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6 * 1] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7 * 1] : 0; ; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -y && get_sub_group_local_id() + 16 < x4; -; CHECK-NEXT: x1[8] = x6 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 0] : 0; -; CHECK-NEXT: x1[9] = x6 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 1] : 0; -; CHECK-NEXT: x1[10] = x6 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 2] : 0; -; CHECK-NEXT: x1[11] = x6 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 3] : 0; -; CHECK-NEXT: x1[12] = x6 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 4] : 0; -; CHECK-NEXT: x1[13] = x6 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 5] : 0; -; CHECK-NEXT: x1[14] = x6 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 6] : 0; -; CHECK-NEXT: x1[15] = x6 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 7] : 0; +; CHECK-NEXT: x1[8] = x6 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 0 * 1] : 0; +; CHECK-NEXT: x1[9] = x6 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 1 * 1] : 0; +; CHECK-NEXT: x1[10] = x6 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 2 * 1] : 0; +; CHECK-NEXT: x1[11] = x6 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 3 * 1] : 0; +; CHECK-NEXT: x1[12] = x6 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 4 * 1] : 0; +; CHECK-NEXT: x1[13] = x6 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 5 * 1] : 0; +; CHECK-NEXT: x1[14] = x6 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 6 * 1] : 0; +; CHECK-NEXT: x1[15] = x6 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 7 * 1] : 0; } func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -152,14 +152,14 @@ func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_ ; CHECK-LABEL: void coopmatrix_b_load_t({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[1 * (get_sub_group_local_id() + 0) + 0]; -; CHECK-NEXT: x1[1] = x2[1 * (get_sub_group_local_id() + 0) + 64]; -; CHECK-NEXT: x1[2] = x2[1 * (get_sub_group_local_id() + 0) + 128]; -; CHECK-NEXT: x1[3] = x2[1 * (get_sub_group_local_id() + 0) + 192]; -; CHECK-NEXT: x1[4] = x2[1 * (get_sub_group_local_id() + 0) + 256]; -; CHECK-NEXT: x1[5] = x2[1 * (get_sub_group_local_id() + 0) + 320]; -; CHECK-NEXT: x1[6] = x2[1 * (get_sub_group_local_id() + 0) + 384]; -; CHECK-NEXT: x1[7] = x2[1 * (get_sub_group_local_id() + 0) + 448]; +; CHECK-NEXT: x1[0] = x2[1 * (get_sub_group_local_id() + 0) + 0 * 64]; +; CHECK-NEXT: x1[1] = x2[1 * (get_sub_group_local_id() + 0) + 1 * 64]; +; CHECK-NEXT: x1[2] = x2[1 * (get_sub_group_local_id() + 0) + 2 * 64]; +; CHECK-NEXT: x1[3] = x2[1 * (get_sub_group_local_id() + 0) + 3 * 64]; +; CHECK-NEXT: x1[4] = x2[1 * (get_sub_group_local_id() + 0) + 4 * 64]; +; CHECK-NEXT: x1[5] = x2[1 * (get_sub_group_local_id() + 0) + 5 * 64]; +; CHECK-NEXT: x1[6] = x2[1 * (get_sub_group_local_id() + 0) + 6 * 64]; +; CHECK-NEXT: x1[7] = x2[1 * (get_sub_group_local_id() + 0) + 7 * 64]; } func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { @@ -170,12 +170,12 @@ func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) s ; CHECK-NEXT: long x3 = 64 - x; ; CHECK-NEXT: long x4 = 48 - y; ; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 64] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 128] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 192] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 256] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 320] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 384] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 448] : 0; +; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; +; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; +; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 2 * 64] : 0; +; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 3 * 64] : 0; +; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 4 * 64] : 0; +; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 5 * 64] : 0; +; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 6 * 64] : 0; +; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 7 * 64] : 0; } diff --git a/test/codegen/coopmatrix_mul_add.ir b/test/codegen/coopmatrix_mul_add.ir index 62e0e181..aec95f2e 100644 --- a/test/codegen/coopmatrix_mul_add.ir +++ b/test/codegen/coopmatrix_mul_add.ir @@ -24,16 +24,16 @@ func @coopmatrix_mul_add_ff() subgroup_size(16) { func @coopmatrix_mul_add_cf() subgroup_size(16) { %a = constant [1.0, 0.0] -> coopmatrix %b = constant 1.0 -> coopmatrix - %c = constant 1.0 -> coopmatrix + %c = constant [1.0, 0.0] -> coopmatrix %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + coopmatrix -> coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_cf({{.*}} ; CHECK: float2 c_next[4]; -; CHECK-NEXT: c_next[0] = (float2) (c[0], 0) + a[0] * sub_group_broadcast(b[0], 0); -; CHECK-NEXT: c_next[1] = (float2) (c[1], 0) + a[0] * sub_group_broadcast(b[0], 1); -; CHECK-NEXT: c_next[2] = (float2) (c[2], 0) + a[0] * sub_group_broadcast(b[0], 2); -; CHECK-NEXT: c_next[3] = (float2) (c[3], 0) + a[0] * sub_group_broadcast(b[0], 3); +; CHECK-NEXT: c_next[0] = c[0] + a[0] * sub_group_broadcast(b[0], 0); +; CHECK-NEXT: c_next[1] = c[1] + a[0] * sub_group_broadcast(b[0], 1); +; CHECK-NEXT: c_next[2] = c[2] + a[0] * sub_group_broadcast(b[0], 2); +; CHECK-NEXT: c_next[3] = c[3] + a[0] * sub_group_broadcast(b[0], 3); ; CHECK-NEXT: c_next[0] = c_next[0] + a[1] * sub_group_broadcast(b[1], 0); ; CHECK-NEXT: c_next[1] = c_next[1] + a[1] * sub_group_broadcast(b[1], 1); ; CHECK-NEXT: c_next[2] = c_next[2] + a[1] * sub_group_broadcast(b[1], 2); @@ -43,20 +43,20 @@ func @coopmatrix_mul_add_cf() subgroup_size(16) { func @coopmatrix_mul_add_fc() subgroup_size(16) { %a = constant 1.0 -> coopmatrix %b = constant [1.0, 0.0] -> coopmatrix - %c = constant 1.0 -> coopmatrix + %c = constant [1.0, 0.0] -> coopmatrix %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + coopmatrix -> coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_fc({{.*}} ; CHECK: float2 c_next[4]; -; CHECK-NEXT: c_next[0].x = ((float2) (c[0], 0)).x + a[0] * sub_group_broadcast(b[0].x, 0); -; CHECK-NEXT: c_next[0].y = ((float2) (c[0], 0)).y + a[0] * sub_group_broadcast(b[0].y, 0); -; CHECK-NEXT: c_next[1].x = ((float2) (c[1], 0)).x + a[0] * sub_group_broadcast(b[0].x, 1); -; CHECK-NEXT: c_next[1].y = ((float2) (c[1], 0)).y + a[0] * sub_group_broadcast(b[0].y, 1); -; CHECK-NEXT: c_next[2].x = ((float2) (c[2], 0)).x + a[0] * sub_group_broadcast(b[0].x, 2); -; CHECK-NEXT: c_next[2].y = ((float2) (c[2], 0)).y + a[0] * sub_group_broadcast(b[0].y, 2); -; CHECK-NEXT: c_next[3].x = ((float2) (c[3], 0)).x + a[0] * sub_group_broadcast(b[0].x, 3); -; CHECK-NEXT: c_next[3].y = ((float2) (c[3], 0)).y + a[0] * sub_group_broadcast(b[0].y, 3); +; CHECK-NEXT: c_next[0].x = c[0].x + a[0] * sub_group_broadcast(b[0].x, 0); +; CHECK-NEXT: c_next[0].y = c[0].y + a[0] * sub_group_broadcast(b[0].y, 0); +; CHECK-NEXT: c_next[1].x = c[1].x + a[0] * sub_group_broadcast(b[0].x, 1); +; CHECK-NEXT: c_next[1].y = c[1].y + a[0] * sub_group_broadcast(b[0].y, 1); +; CHECK-NEXT: c_next[2].x = c[2].x + a[0] * sub_group_broadcast(b[0].x, 2); +; CHECK-NEXT: c_next[2].y = c[2].y + a[0] * sub_group_broadcast(b[0].y, 2); +; CHECK-NEXT: c_next[3].x = c[3].x + a[0] * sub_group_broadcast(b[0].x, 3); +; CHECK-NEXT: c_next[3].y = c[3].y + a[0] * sub_group_broadcast(b[0].y, 3); ; CHECK-NEXT: c_next[0].x = c_next[0].x + a[1] * sub_group_broadcast(b[1].x, 0); ; CHECK-NEXT: c_next[0].y = c_next[0].y + a[1] * sub_group_broadcast(b[1].y, 0); ; CHECK-NEXT: c_next[1].x = c_next[1].x + a[1] * sub_group_broadcast(b[1].x, 1); diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir index 60b4694b..e49bb80e 100644 --- a/test/codegen/coopmatrix_store.ir +++ b/test/codegen/coopmatrix_store.ir @@ -7,8 +7,8 @@ func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup cooperative_matrix_store %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_n({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; } func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -19,8 +19,8 @@ func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: in ; CHECK-NEXT: long x2 = 64 - x; ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; ; CHECK-NEXT: } } @@ -32,10 +32,10 @@ func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: in ; CHECK-NEXT: long x2 = 64 - x; ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; ; CHECK-NEXT: } } @@ -48,10 +48,10 @@ func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0] = c0[0]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 64] = c0[1]; +; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; ; CHECK-NEXT: } ; CHECK-NEXT: } } @@ -61,8 +61,8 @@ func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) cooperative_matrix_store.atomic_add %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_atomic_add({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0), c0[0], memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 64), c0[1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); } func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -74,10 +74,10 @@ func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 64), c0[1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); ; CHECK-NEXT: } ; CHECK-NEXT: } } From 70a1da8b927501097db22f0b87388f7d53a218d4 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 24 Oct 2024 11:42:53 -0700 Subject: [PATCH 073/297] argparser bugfix Signed-off-by: Carsten Uphoff --- tools/argparser/argparser.hpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tools/argparser/argparser.hpp b/tools/argparser/argparser.hpp index e655cda5..02666c37 100644 --- a/tools/argparser/argparser.hpp +++ b/tools/argparser/argparser.hpp @@ -42,6 +42,12 @@ enum class parser_status { auto to_string(parser_status status) -> char const *; template struct default_converter; +template <> struct default_converter { + auto operator()(char const *str, char &val) const -> parser_status { + val = str[0]; + return parser_status::success; + } +}; template struct default_converter { auto operator()(char const *str, T &val) const -> parser_status { long v = strtol(str, nullptr, 0); @@ -168,9 +174,10 @@ class arg_parser { } template - auto set_short_opt(char opt, T *ptr, char const *help = nullptr, - std::optional::value_type> default_argument = - std::nullopt) -> par_model & { + auto + set_short_opt(char opt, T *ptr, char const *help = nullptr, + std::optional::value_type> default_argument = std::nullopt) + -> par_model & { auto model = std::make_unique>(ptr, std::move(default_argument)); auto model_ptr = model.get(); set_short_opt(opt, {help, std::move(model)}); @@ -182,9 +189,10 @@ class arg_parser { } template - auto set_long_opt(char const *opt, T *ptr, char const *help = nullptr, - std::optional::value_type> default_argument = - std::nullopt) -> par_model & { + auto + set_long_opt(char const *opt, T *ptr, char const *help = nullptr, + std::optional::value_type> default_argument = std::nullopt) + -> par_model & { auto model = std::make_unique>(ptr, std::move(default_argument)); auto model_ptr = model.get(); set_long_opt({opt, help, std::move(model)}); @@ -202,8 +210,8 @@ class arg_parser { } template - auto add_positional_arg(char const *opt, std::vector *ptr, - char const *help = nullptr) -> par_model & { + auto add_positional_arg(char const *opt, std::vector *ptr, char const *help = nullptr) + -> par_model & { auto model = std::make_unique>>(ptr, std::make_optional(T{})); auto model_ptr = model.get(); add_positional_arg({opt, help, std::move(model)}); From 74055d66453de404259c1df563d19a27de73b5ca Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 24 Oct 2024 20:57:19 +0200 Subject: [PATCH 074/297] Constant folding for matrix scale Signed-off-by: Carsten Uphoff --- src/node/inst_node.hpp | 2 ++ src/pass/constant_folding.cpp | 20 ++++++++++++++++++++ src/pass/constant_folding.hpp | 1 + test/linalg.hpp | 13 +++++++++++++ test/linalg_ops.hpp | 4 ++-- 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 9f058298..5065f558 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -552,7 +552,9 @@ class cooperative_matrix_scale_inst : public standard_inst<2, 1, 0> { } enum op_number { op_a = 0, op_b = 1 }; cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, location const &lc = {}); + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } inline auto b() const -> tinytc_value const & { return op(op_b); } }; diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp index 3e97b72b..6602b943 100644 --- a/src/pass/constant_folding.cpp +++ b/src/pass/constant_folding.cpp @@ -244,6 +244,26 @@ auto constant_folding::operator()(compare_inst &in) -> fold_result { return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); } +auto constant_folding::operator()(cooperative_matrix_scale_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + if (a_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, arithmetic::mul, op_b, true, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); + } + return tinytc_value_t{}; +} + auto constant_folding::operator()(size_inst &in) -> fold_result { auto ct = get_memref_type(in.operand()); diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp index be478bcc..6eb25f5f 100644 --- a/src/pass/constant_folding.hpp +++ b/src/pass/constant_folding.hpp @@ -31,6 +31,7 @@ class constant_folding { auto operator()(inst_node &) -> fold_result; auto operator()(arith_inst &) -> fold_result; auto operator()(arith_unary_inst &) -> fold_result; + auto operator()(cooperative_matrix_scale_inst &) -> fold_result; auto operator()(cast_inst &) -> fold_result; auto operator()(compare_inst &) -> fold_result; auto operator()(size_inst &in) -> fold_result; diff --git a/test/linalg.hpp b/test/linalg.hpp index a19e1193..631fb4ac 100644 --- a/test/linalg.hpp +++ b/test/linalg.hpp @@ -111,6 +111,19 @@ TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed complex alpha=(-1,-2) beta=(2,3)", test::test_blas_a3(op, {-1.0, -2.0}, {2.0, 3.0}); } +TEST_CASE(RUNTIME_NAME " gemm packed mixed precision") { + auto KK = std::vector{53}; + auto MM = std::vector{21, 42}; + auto NN = std::vector{7, 11}; + + std::int64_t M, N, K; + DOCTEST_TENSOR3_TEST(MM, NN, KK); + + auto op = test::gemm(transpose::N, transpose::N, {{M, K}}, + {{K, N}}, {{M, N}}); + test::test_blas_a3(op, 1, 0); +} + TEST_CASE_TEMPLATE(RUNTIME_NAME " ger packed alpha=1 beta=0", T, TEST_PRECISIONS) { auto MM = std::vector{10, 32, 45}; auto NN = std::vector{1, 16, 17, 48}; diff --git a/test/linalg_ops.hpp b/test/linalg_ops.hpp index 2d859d69..206871f0 100644 --- a/test/linalg_ops.hpp +++ b/test/linalg_ops.hpp @@ -85,8 +85,8 @@ template Date: Fri, 25 Oct 2024 11:05:02 +0200 Subject: [PATCH 075/297] Split atomic store for real / imag part Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 49 +++++++++++++-- src/codegen_tools.hpp | 3 + src/pass/convert_to_opencl.cpp | 69 ++++++--------------- test/{linalg_ops.cpp => linalg_blas_a3.cpp} | 0 test/{linalg_ops.hpp => linalg_blas_a3.hpp} | 0 5 files changed, 66 insertions(+), 55 deletions(-) rename test/{linalg_ops.cpp => linalg_blas_a3.cpp} (100%) rename test/{linalg_ops.hpp => linalg_blas_a3.hpp} (100%) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 8f066d4a..7a32425f 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -172,6 +172,46 @@ void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, clir::addr } } +auto atomic_store_helper_new(store_flag flag, memref_data_type const *ty, expr pointer, + expr value) -> std::vector { + const auto make_atomic_store = [&](auto fun, expr pointer, expr value) -> std::vector { + constexpr auto mem_order = clir::memory_order::relaxed; + constexpr auto mem_scope = clir::memory_scope::work_group; + constexpr auto qualifier = clir::type_qualifier::volatile_t; + + const auto sty = ty->element_ty(); + const auto addrspace = to_clir_address_space(ty->addrspace()); + if (is_complex_type(sty)) { + const auto atomic_pointer_ty = + pointer_to(to_clir_atomic_ty(element_type(sty)), addrspace, qualifier); + return {expression_statement(call_builtin( + fun, {cast(atomic_pointer_ty, address_of(dereference(pointer).s(0))), value, + mem_order, mem_scope})), + expression_statement(call_builtin( + fun, {cast(atomic_pointer_ty, address_of(dereference(pointer).s(1))), value, + mem_order, mem_scope}))}; + } else { + const auto atomic_pointer_ty = pointer_to(to_clir_atomic_ty(sty, addrspace, qualifier)); + return { + expression_statement(call_builtin(fun, {cast(atomic_pointer_ty, std::move(pointer)), + std::move(value), mem_order, mem_scope}))}; + } + }; + + switch (flag) { + case store_flag::regular: + return { + expression_statement(assignment(dereference(std::move(pointer)), std::move(value)))}; + case store_flag::atomic: + return make_atomic_store(clir::builtin_function::atomic_store_explicit, std::move(pointer), + std::move(value)); + case store_flag::atomic_add: + return make_atomic_store(clir::builtin_function::atomic_fetch_add_explicit, + std::move(pointer), std::move(value)); + } + return {}; +} + void dispatch_constant_dynamic(expr e, std::function const &const_case, std::function const &dyn_case) { visit( @@ -508,7 +548,8 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc // Here we compute // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size - // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * num_tiles + // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * + // num_tiles auto c_block_size = bb.add(make_constant(block_size, index_ty)); auto blocks0 = instant_constant_fold_add(bb, make_arith(arithmetic::sub, loop_trip_count, c1)); auto blocks1 = @@ -522,9 +563,9 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc auto rem = instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, blocks)); auto sg_id_index = instant_constant_fold_add(bb, make_cast(sg_id, index_ty)); - // The following if makes it easy to eliminate the remainder handler in optimization if rem == 0 - // is known at compile time. Without the if, we would need to prove that block_start_1 is - // non-negative to eliminate the for-loop. + // The following if makes it easy to eliminate the remainder handler in optimization if rem + // == 0 is known at compile time. Without the if, we would need to prove that block_start_1 + // is non-negative to eliminate the for-loop. auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { auto block_start_1 = diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 8ede8a62..f8a87770 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -5,6 +5,7 @@ #define CODEGEN_TOOLS_20240229_HPP #include "device_info.hpp" +#include "node/data_type_node.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" @@ -39,6 +40,8 @@ void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scala void atomic_store_helper(clir::block_builder &bb, clir::expr dst, scalar_type ty, clir::address_space as, clir::expr value, scalar_type beta_ty, clir::expr beta); +auto atomic_store_helper_new(store_flag flag, memref_data_type const *ty, clir::expr pointer, + clir::expr value) -> std::vector; void dispatch_constant_dynamic(clir::expr e, std::function const &const_case, std::function const &dyn_case); diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 48ece310..0b8f93fb 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -778,9 +778,9 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st clinst.reserve(len + 4); clir::expr pv[] = {val(c.pos0()), val(c.pos1())}; - auto pointer = clir::var{}; + auto base_pointer = clir::var{}; clinst.emplace_back( - declaration_assignment(visit(*this, *c.operand().ty()), pointer, + declaration_assignment(visit(*this, *c.operand().ty()), base_pointer, val(c.operand()) + pv[0] * odv.stride(0) + pv[1] * odv.stride(1))); clir::var rem[2] = {}; if (check_m || check_k) { @@ -790,10 +790,6 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st declaration_assignment(to_clir_ty(scalar_type::index), rem[1], odv.shape(1) - pv[1])); } - auto atomic_pointer = - cast(pointer_to(to_clir_atomic_ty(ot->element_ty(), to_clir_address_space(ot->addrspace()), - clir::type_qualifier::volatile_t)), - pointer); const std::int64_t num_blocks = vt->num_blocks(core_cfg_.subgroup_size); auto const num_k = vt->shape(1 - vmode); auto store_block = std::vector{}; @@ -801,29 +797,13 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st for (std::int64_t block = 0; block < num_blocks; ++block) { store_block.clear(); for (std::int64_t k = 0; k < num_k; ++k) { - auto const store = [&](clir::expr offset, clir::expr rhs) -> clir::expr { - switch (c.flag()) { - case store_flag::regular: - return assignment(pointer[std::move(offset)], std::move(rhs)); - case store_flag::atomic: - return call_builtin(clir::builtin_function::atomic_store_explicit, - {atomic_pointer + std::move(offset), std::move(rhs), - clir::memory_order::relaxed, - clir::memory_scope::work_group}); - case store_flag::atomic_add: - return call_builtin(clir::builtin_function::atomic_fetch_add_explicit, - {atomic_pointer + std::move(offset), std::move(rhs), - clir::memory_order::relaxed, - clir::memory_scope::work_group}); - }; - return {}; - }; auto const remainder = vt->shape(vmode) - core_cfg_.subgroup_size * block; const bool needs_mask = remainder < core_cfg_.subgroup_size; - auto offset = odv.stride(omode) * - (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + - k * odv.stride(1 - omode); + auto pointer = base_pointer + + odv.stride(omode) * + (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + + k * odv.stride(1 - omode); auto rhs = valv[k + block * vt->shape(1 - vmode)]; clir::expr cond = {}; if (check_k) { @@ -833,14 +813,22 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_st auto mask_cond = clir::get_sub_group_local_id() < remainder; cond = cond ? cond && mask_cond : mask_cond; } - auto st = clir::expression_statement(store(std::move(offset), std::move(rhs))); + if (cond) { store_block.emplace_back( clir::if_selection_builder(cond) - .then([&](clir::block_builder &bb) { bb.add(std::move(st)); }) + .then([&](clir::block_builder &bb) { + for (auto &s : atomic_store_helper_new(c.flag(), ot, std::move(pointer), + std::move(rhs))) { + bb.add(std::move(s)); + } + }) .get_product()); } else { - store_block.emplace_back(std::move(st)); + for (auto &s : + atomic_store_helper_new(c.flag(), ot, std::move(pointer), std::move(rhs))) { + store_block.emplace_back(std::move(s)); + } } } @@ -1419,28 +1407,7 @@ std::vector convert_to_opencl_pass::operator()(store_inst const &s) } auto rhs = val(s.val()); - auto st = clir::expr{}; - auto atomic_pointer_ty = - pointer_to(to_clir_atomic_ty(ot->element_ty(), to_clir_address_space(ot->addrspace()), - clir::type_qualifier::volatile_t)); - switch (s.flag()) { - case store_flag::regular: - st = assignment(dereference(std::move(lhs)), std::move(rhs)); - break; - case store_flag::atomic: - lhs = cast(std::move(atomic_pointer_ty), std::move(lhs)); - st = call_builtin(clir::builtin_function::atomic_store_explicit, - {std::move(lhs), std::move(rhs), clir::memory_order::relaxed, - clir::memory_scope::work_group}); - break; - case store_flag::atomic_add: - lhs = cast(std::move(atomic_pointer_ty), std::move(lhs)); - st = call_builtin(clir::builtin_function::atomic_fetch_add_explicit, - {std::move(lhs), std::move(rhs), clir::memory_order::relaxed, - clir::memory_scope::work_group}); - break; - } - return {expression_statement(std::move(st))}; + return atomic_store_helper_new(s.flag(), ot, std::move(lhs), std::move(rhs)); } std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) { diff --git a/test/linalg_ops.cpp b/test/linalg_blas_a3.cpp similarity index 100% rename from test/linalg_ops.cpp rename to test/linalg_blas_a3.cpp diff --git a/test/linalg_ops.hpp b/test/linalg_blas_a3.hpp similarity index 100% rename from test/linalg_ops.hpp rename to test/linalg_blas_a3.hpp From 66bf2f76a193d976fded79685025f1d93c45ef29 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 11:05:16 +0200 Subject: [PATCH 076/297] GEMV lowering Signed-off-by: Carsten Uphoff --- src/pass/lower_linalg.cpp | 45 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 47aff94c..9e6e1849 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -138,6 +138,7 @@ class linalg_generator { auto operator()(axpby_inst &in) -> inst; auto operator()(ger_inst &in) -> inst; auto operator()(gemm_inst &in) -> inst; + auto operator()(gemv_inst &in) -> inst; auto operator()(hadamard_inst &in) -> inst; auto operator()(sum_inst &in) -> inst; @@ -205,7 +206,7 @@ auto linalg_generator::operator()(axpby_inst &in) -> inst { auto zero = bb.add(make_constant(0, index_ty)); bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { auto nn = bb.add(make_arith(arithmetic::add, block, n, in.loc())); - auto static_offset_list = std::array{dynamic, 0}; + auto static_offset_list = std::array{0, dynamic}; auto static_size_list = std::array{dynamic, 0}; auto Bb = bb.add(make_subview(&in.B(), static_offset_list, static_size_list, {nn}, {c_shape0}, in.loc())); @@ -331,6 +332,48 @@ auto linalg_generator::operator()(gemm_inst &in) -> inst { return parallel; } +auto linalg_generator::operator()(gemv_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto ct = get_memref_type(in.C()); + + auto ctx = compiler_context{in.alpha().context(), true}; + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); + auto K = instant_constant_fold_add( + bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + + tile_loop_by_sgs_standard( + bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, + [&](region_builder &bb, value mm) { + auto c_zero = bb.add(make_constant(0, index_ty)); + auto c_step = bb.add(make_constant(1, index_ty)); + auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); + auto c_acc = bb.for_loop( + c_zero, K, c_step, {c_init}, {ct->element_data_ty()}, index_ty, + [&](region_builder &bb, array_view p) { + auto a_idx = std::array{mm, p[0]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto a = bb.add(make_load(&in.A(), a_idx, in.loc())); + auto b = bb.add(make_load(&in.B(), {p[0]}, in.loc())); + auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); + auto ab_c = mixed_precision_arithmetic(bb, arithmetic::add, p[1], ab, in.loc()); + bb.add(make_yield({ab_c}, in.loc())); + }); + blas_update(bb, in.atomic(), &in.alpha(), c_acc[0], &in.beta(), &in.C(), {mm}, + in.loc()); + }); + + return parallel; +} + auto linalg_generator::operator()(hadamard_inst &in) -> inst { auto parallel = make_parallel(in.loc()); tinytc_region_t body = ¶llel->child_region(0); From 17af0fe2e50b543dfd72de0db0ab05dffae4f24d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 11:06:38 +0200 Subject: [PATCH 077/297] More tests Signed-off-by: Carsten Uphoff --- CMakeLists.txt | 1 + test/CMakeLists.txt | 5 +- test/doctest_util.hpp | 4 ++ test/linalg.hpp | 102 ++++++++++++++++++++++++++++- test/linalg_blas_a2.cpp | 47 ++++++++++++++ test/linalg_blas_a2.hpp | 138 ++++++++++++++++++++++++++++++++++++++++ test/linalg_blas_a3.cpp | 16 ++++- test/linalg_blas_a3.hpp | 79 ++++++++++++++--------- test/linalg_runner.hpp | 131 +++++++++++++++++++++++++------------- test/linalg_types.hpp | 45 ++++++++++++- 10 files changed, 491 insertions(+), 77 deletions(-) create mode 100644 test/linalg_blas_a2.cpp create mode 100644 test/linalg_blas_a2.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ca7f1f1e..67d7350b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ cmake_dependent_option(BUILD_LEVEL_ZERO cmake_dependent_option(BUILD_OPENCL "Build support for OpenCL run-time; required when SYCL is enabled" ON "NOT BUILD_SYCL" ON) +option(BUILD_DOUBLE_PRECISION_TESTS "Build double precision unit tests" ON) include(CTest) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 85e7844c..344dd152 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -26,8 +26,11 @@ target_link_libraries(test-visitor PRIVATE test-lib) doctest_discover_tests(test-visitor) set_cxx_common_options(test-visitor) -add_library(test-lib-linalg STATIC linalg_ops.cpp linalg_types.cpp) +add_library(test-lib-linalg STATIC linalg_blas_a2.cpp linalg_blas_a3.cpp linalg_types.cpp) target_link_libraries(test-lib-linalg PUBLIC tinytc) +if (BUILD_DOUBLE_PRECISION_TESTS) + target_compile_definitions(test-lib-linalg PUBLIC ENABLE_DOUBLE_PRECISION) +endif() set_cxx_common_options(test-lib-linalg) configure_file(lit.site.cfg.py.in lit.site.cfg.py @ONLY) diff --git a/test/doctest_util.hpp b/test/doctest_util.hpp index 46b5baa9..8aea1094 100644 --- a/test/doctest_util.hpp +++ b/test/doctest_util.hpp @@ -60,6 +60,10 @@ } \ } while (false) +#ifdef ENABLE_DOUBLE_PRECISION #define TEST_PRECISIONS float, double +#else +#define TEST_PRECISIONS float +#endif #endif // DOCTEST_UTIL_20241023_HPP diff --git a/test/linalg.hpp b/test/linalg.hpp index 631fb4ac..2812382d 100644 --- a/test/linalg.hpp +++ b/test/linalg.hpp @@ -2,7 +2,8 @@ // SPDX-License-Identifier: BSD-3-Clause #include "doctest_util.hpp" -#include "linalg_ops.hpp" +#include "linalg_blas_a2.hpp" +#include "linalg_blas_a3.hpp" #include "linalg_runner.hpp" #include "linalg_types.hpp" @@ -14,6 +15,43 @@ using runtime_class = RUNTIME_CLASS; using namespace tinytc; +TEST_CASE_TEMPLATE(RUNTIME_NAME " axpby 0d", T, TEST_PRECISIONS) { + auto op = test::axpby(transpose::N, {{}}, {{}}); + test::test_blas_a2(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " axpby 1d", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + + std::int64_t M; + DOCTEST_TENSOR1_TEST(MM); + + auto op = test::axpby(transpose::N, {{M}}, {{M}}); + test::test_blas_a2(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " axpby 2d", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + auto NN = std::vector{5, 17}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::axpby(transpose::N, {{M, N}}, {{M, N}}); + test::test_blas_a2(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " axpby 2d trans", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + auto NN = std::vector{5, 17}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::axpby(transpose::T, {{N, M}}, {{M, N}}); + test::test_blas_a2(op, 1, 0); +} + TEST_CASE_TEMPLATE(RUNTIME_NAME " gemm packed alpha=1 beta=0", T, TEST_PRECISIONS) { auto KK = std::vector{56}; auto MM = std::vector{20, 32, 53}; @@ -124,6 +162,36 @@ TEST_CASE(RUNTIME_NAME " gemm packed mixed precision") { test::test_blas_a3(op, 1, 0); } +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemv packed alpha=1 beta=0", T, TEST_PRECISIONS) { + auto NN = std::vector{21}; + auto MM = std::vector{16, 23}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::gemv(transpose::N, {{M, N}}, {{N}}, {{M}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemv packed trans alpha=1 beta=0", T, TEST_PRECISIONS) { + std::int64_t M = 19, N = 32; + + auto op = test::gemv(transpose::T, {{N, M}}, {{N}}, {{M}}); + test::test_blas_a3(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " gemv packed complex alpha=1 beta=0", T, TEST_PRECISIONS) { + auto NN = std::vector{5}; + auto MM = std::vector{8, 37}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + using CT = std::complex; + auto op = test::gemv(transpose::N, {{M, N}}, {{N}}, {{M}}); + test::test_blas_a3(op, 1, 0); +} + TEST_CASE_TEMPLATE(RUNTIME_NAME " ger packed alpha=1 beta=0", T, TEST_PRECISIONS) { auto MM = std::vector{10, 32, 45}; auto NN = std::vector{1, 16, 17, 48}; @@ -144,3 +212,35 @@ TEST_CASE_TEMPLATE(RUNTIME_NAME " hadamard packed alpha=1 beta=0", T, TEST_PRECI auto op = test::hadamard({{M}}, {{M}}, {{M}}); test::test_blas_a3(op, 1, 0); } + +TEST_CASE_TEMPLATE(RUNTIME_NAME " sum 1d", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + + std::int64_t M; + DOCTEST_TENSOR1_TEST(MM); + + auto op = test::sum(transpose::N, {{M}}, {{}}); + test::test_blas_a2(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " sum 2d", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + auto NN = std::vector{5, 17}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::sum(transpose::N, {{M, N}}, {{M}}); + test::test_blas_a2(op, 1, 0); +} + +TEST_CASE_TEMPLATE(RUNTIME_NAME " sum 2d trans", T, TEST_PRECISIONS) { + auto MM = std::vector{18, 16, 32}; + auto NN = std::vector{5, 17}; + + std::int64_t M, N; + DOCTEST_TENSOR2_TEST(MM, NN); + + auto op = test::sum(transpose::T, {{N, M}}, {{M}}); + test::test_blas_a2(op, 1, 0); +} diff --git a/test/linalg_blas_a2.cpp b/test/linalg_blas_a2.cpp new file mode 100644 index 00000000..13d727b5 --- /dev/null +++ b/test/linalg_blas_a2.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "linalg_blas_a2.hpp" + +#include +#include + +namespace tinytc::test { + +auto make_blas_a2_prog(char const *name, tensor_layout const &layoutA, tensor_layout const &layoutB, + scalar_type alpha_ty, scalar_type A_ty, scalar_type beta_ty, + scalar_type B_ty, + std::function)> make_op) -> prog { + auto ctx = make_compiler_context(); + + auto const alphat = get_scalar(ctx, alpha_ty); + auto const at = get_scalar(ctx, A_ty); + auto const betat = get_scalar(ctx, beta_ty); + auto const bt = get_scalar(ctx, B_ty); + + auto p = make_prog(ctx); + + auto At = + get_memref(at, layoutA.static_shape(), layoutA.static_stride(), address_space::global); + auto Bt = + get_memref(bt, layoutB.static_shape(), layoutB.static_stride(), address_space::global); + + auto f = make_func(name, {alphat, At, betat, Bt}); + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("beta"); + params[3].set_name("B"); + + auto bb = region_builder{fn_body}; + + make_op(bb, params); + + p.add_function(std::move(f)); + + return p; +} + +} // namespace tinytc::test diff --git a/test/linalg_blas_a2.hpp b/test/linalg_blas_a2.hpp new file mode 100644 index 00000000..31e4ee03 --- /dev/null +++ b/test/linalg_blas_a2.hpp @@ -0,0 +1,138 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LINALG_BLAS_A2_20241025_HPP +#define LINALG_BLAS_A2_20241025_HPP + +#include "linalg_types.hpp" +#include "tinytc/tinytc.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::test { + +auto make_blas_a2_prog(char const *name, tensor_layout const &layoutA, tensor_layout const &layoutB, + scalar_type alpha_ty, scalar_type A_ty, scalar_type beta_ty, + scalar_type B_ty, + std::function)> make_op) -> prog; + +template class axpby { + public: + using alpha_type = AlphaT; + using A_type = AT; + using beta_type = BetaT; + using B_type = BT; + static constexpr char const *kernel_name = "axpby"; + + axpby(transpose tA, tensor_layout layoutA, tensor_layout layoutB) + : tA_(tA), lA_{std::move(layoutA)}, lB_{std::move(layoutB)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + + auto make_prog() const -> prog { + return make_blas_a2_prog( + kernel_name, lA_, lB_, to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_axpby(tA_, false, params[0], params[1], params[2], params[3])); + }); + } + void reference_impl(AlphaT alpha, AT const *A, BetaT beta, BT *B) { + if (lA_.dim() == 0 && lB_.dim() == 0) { + *B = alpha * (*A) + beta * (*B); + } else if (lA_.dim() == 1 && lB_.dim() == 1) { + const auto M = lB_.shape(0); + if (M != lA_.shape(0)) { + throw std::runtime_error("incompatible axpby"); + } + for (std::int64_t m = 0; m < M; ++m) { + auto &b = B[lB_.linear_index({m})]; + b = alpha * A[lA_.linear_index({m})] + beta * b; + } + } else if (lA_.dim() == 2 && lB_.dim() == 2) { + const int A_mmode = tA_ == transpose::T ? 1 : 0; + const auto M = lB_.shape(0); + const auto N = lB_.shape(1); + if (M != lA_.shape(A_mmode) || N != lA_.shape(1 - A_mmode)) { + throw std::runtime_error("incompatible axpby"); + } + for (std::int64_t n = 0; n < N; ++n) { + for (std::int64_t m = 0; m < M; ++m) { + auto &b = B[lB_.linear_index({m, n})]; + b = alpha * A[lA_.linear_index(make_index_2d(tA_, m, n))] + beta * b; + } + } + } else { + throw std::runtime_error("invald axpby dimension combination"); + } + } + + private: + transpose tA_; + tensor_layout lA_, lB_; +}; + +template class sum { + public: + using alpha_type = AlphaT; + using A_type = AT; + using beta_type = BetaT; + using B_type = BT; + static constexpr char const *kernel_name = "sum"; + + sum(transpose tA, tensor_layout layoutA, tensor_layout layoutB) + : tA_(tA), lA_{std::move(layoutA)}, lB_{std::move(layoutB)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + + auto make_prog() const -> prog { + return make_blas_a2_prog( + kernel_name, lA_, lB_, to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_sum(tA_, false, params[0], params[1], params[2], params[3])); + }); + } + void reference_impl(AlphaT alpha, AT const *A, BetaT beta, BT *B) { + if (lA_.dim() == 1 && lB_.dim() == 0) { + const auto M = lA_.shape(0); + AT a_acc = AT{0}; + for (std::int64_t m = 0; m < M; ++m) { + a_acc += A[lA_.linear_index({m})]; + } + *B = alpha * a_acc + beta * (*B); + } else if (lA_.dim() == 2 && lB_.dim() == 1) { + const int A_nmode = tA_ == transpose::T ? 0 : 1; + const auto M = lB_.shape(0); + const auto N = lA_.shape(A_nmode); + if (M != lA_.shape(1 - A_nmode)) { + throw std::runtime_error("incompatible sum"); + } + for (std::int64_t m = 0; m < M; ++m) { + auto &b = B[lB_.linear_index({m})]; + AT a_acc = AT{0}; + for (std::int64_t n = 0; n < N; ++n) { + a_acc += A[lA_.linear_index(make_index_2d(tA_, m, n))]; + } + b = alpha * a_acc + beta * b; + } + } else { + throw std::runtime_error("invald sum dimension combination"); + } + } + + private: + transpose tA_; + tensor_layout lA_, lB_; +}; + +} // namespace tinytc::test + +#endif // LINALG_BLAS_A2_20241025_HPP diff --git a/test/linalg_blas_a3.cpp b/test/linalg_blas_a3.cpp index 12348840..e639407c 100644 --- a/test/linalg_blas_a3.cpp +++ b/test/linalg_blas_a3.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "linalg_ops.hpp" +#include "linalg_blas_a3.hpp" #include #include @@ -24,6 +24,20 @@ auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout return {M, N, K}; } +auto gemv_mk(transpose tA, tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array { + if (A.dim() != 2 || B.dim() != 1 || C.dim() != 1) { + throw std::runtime_error("expected vectors and matrix"); + } + const int A_kmode = tA == transpose::T ? 1 : 0; + const auto M = C.shape(0); + const auto K = A.shape(1 - A_kmode); + if (M != A.shape(A_kmode) || K != B.shape(0)) { + throw std::runtime_error("incompatible matvec"); + } + return {M, K}; +} + auto ger_mn(tensor_layout const &A, tensor_layout const &B, tensor_layout const &C) -> std::array { if (A.dim() != 1 || B.dim() != 1 || C.dim() != 2) { diff --git a/test/linalg_blas_a3.hpp b/test/linalg_blas_a3.hpp index 206871f0..6f2e5feb 100644 --- a/test/linalg_blas_a3.hpp +++ b/test/linalg_blas_a3.hpp @@ -1,8 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#ifndef LINALG_OPS_20241023_HPP -#define LINALG_OPS_20241023_HPP +#ifndef LINALG_BLAS_A3_20241025_HPP +#define LINALG_BLAS_A3_20241025_HPP #include "linalg_types.hpp" #include "tinytc/tinytc.hpp" @@ -18,6 +18,8 @@ namespace tinytc::test { auto gemm_mnk(transpose tA, transpose tB, tensor_layout const &A, tensor_layout const &B, tensor_layout const &C) -> std::array; +auto gemv_mk(transpose tA, tensor_layout const &A, tensor_layout const &B, + tensor_layout const &C) -> std::array; auto ger_mn(tensor_layout const &A, tensor_layout const &B, tensor_layout const &C) -> std::array; auto hadamard_m(tensor_layout const &A, tensor_layout const &B, @@ -28,23 +30,6 @@ auto make_blas_a3_prog(char const *name, tensor_layout const &layoutA, tensor_la scalar_type B_ty, scalar_type beta_ty, scalar_type C_ty, std::function)> make_op) -> prog; -template -concept op_blas_a3 = requires(T op, typename T::alpha_type alpha, typename T::beta_type beta, - typename T::A_type const *A_ref, typename T::B_type const *B_ref, - typename T::C_type *C_ref) { - typename T::alpha_type; - typename T::A_type; - typename T::B_type; - typename T::beta_type; - typename T::C_type; - T::kernel_name; - { op.lA() } -> std::same_as; - { op.lB() } -> std::same_as; - { op.lC() } -> std::same_as; - { op.make_prog() } -> std::same_as; - op.reference_impl(alpha, A_ref, B_ref, beta, C_ref); -}; - template class gemm { public: using alpha_type = AlphaT; @@ -74,19 +59,12 @@ template {m, n}; - if (t == transpose::T) { - std::swap(idx[0], idx[1]); - } - return idx; - }; for (std::int64_t n = 0; n < N; ++n) { for (std::int64_t m = 0; m < M; ++m) { CT c_acc = CT{0}; for (std::int64_t k = 0; k < K; ++k) { - c_acc = c_acc + A[lA_.linear_index(make_index(tA_, m, k))] * - B[lB_.linear_index(make_index(tB_, k, n))]; + c_acc = c_acc + A[lA_.linear_index(make_index_2d(tA_, m, k))] * + B[lB_.linear_index(make_index_2d(tB_, k, n))]; } auto &c = C[lC_.linear_index({m, n})]; c = alpha * c_acc + beta * c; @@ -99,6 +77,49 @@ template class gemv { + public: + using alpha_type = AlphaT; + using A_type = AT; + using B_type = BT; + using beta_type = BetaT; + using C_type = CT; + static constexpr char const *kernel_name = "gemv"; + + gemv(transpose tA, tensor_layout layoutA, tensor_layout layoutB, tensor_layout layoutC) + : tA_(tA), lA_{std::move(layoutA)}, lB_{std::move(layoutB)}, lC_{std::move(layoutC)} {} + + auto lA() const -> tensor_layout const & { return lA_; } + auto lB() const -> tensor_layout const & { return lB_; } + auto lC() const -> tensor_layout const & { return lC_; } + + auto make_prog() const -> prog { + return make_blas_a3_prog(kernel_name, lA_, lB_, lC_, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + to_scalar_type_v, to_scalar_type_v, + [&](region_builder &bb, array_view params) { + bb.add(make_gemv(tA_, false, params[0], params[1], params[2], + params[3], params[4])); + }); + } + void reference_impl(AlphaT alpha, AT const *A, BT const *B, BetaT beta, CT *C) { + const auto [M, K] = gemv_mk(tA_, lA_, lB_, lC_); + for (std::int64_t m = 0; m < M; ++m) { + CT c_acc = CT{0}; + for (std::int64_t k = 0; k < K; ++k) { + c_acc = c_acc + + A[lA_.linear_index(make_index_2d(tA_, m, k))] * B[lB_.linear_index({k})]; + } + auto &c = C[lC_.linear_index({m})]; + c = alpha * c_acc + beta * c; + } + } + + private: + transpose tA_; + tensor_layout lA_, lB_, lC_; +}; + template class ger { public: using alpha_type = AlphaT; @@ -177,4 +198,4 @@ template inline constexpr bool is_complex_v = is_complex::value; template constexpr bool requires_dp_v = std::is_same_v || std::is_same_v>; +template auto make_test_data(std::size_t size) -> std::vector { + auto data = std::vector(size); + for (std::size_t i = 0; i < data.size(); ++i) { + constexpr std::size_t prime = 101; + if constexpr (is_complex_v) { + data[i] = T{static_cast((2 * i) % prime), + static_cast((2 * i + 1) % prime)}; + } else { + data[i] = static_cast(i % prime); + } + } + return data; +} + +template auto compare_data(std::vector const &A, std::vector const &B) { + REQUIRE(A.size() == B.size()); + for (std::size_t i = 0; i < A.size(); ++i) { + constexpr auto eps = 10.0 * std::numeric_limits::epsilon(); + REQUIRE(std::abs(A[i] - B[i]) == doctest::Approx(0.0).epsilon(eps)); + } +} + +template +void set_dope_vector(R &rt, typename R::kernel_t &kernel, tensor_layout const &layout, + std::uint32_t &arg_index) { + for (std::size_t i = 0; i < layout.shape().size(); ++i) { + if (layout.static_shape(i) == dynamic) { + std::int64_t s = layout.shape(i); + rt.set_arg(kernel, arg_index++, sizeof(s), &s); + } + } + for (std::size_t i = 0; i < layout.stride().size(); ++i) { + if (layout.static_stride(i) == dynamic) { + std::int64_t s = layout.stride(i); + rt.set_arg(kernel, arg_index++, sizeof(s), &s); + } + } +}; + +template +void test_blas_a2(T op, typename T::alpha_type alpha, typename T::beta_type beta) { + auto gpu_rt = std::make_shared(); + if constexpr (requires_dp_v || requires_dp_v || + requires_dp_v || requires_dp_v) { + if (!gpu_rt->supports_fp64()) { + WARN_MESSAGE(false, "Double precision tests need double precision device support"); + return; + } + } + + auto A_ref = make_test_data(op.lA().size()); + auto B_ref = std::vector(op.lB().size()); + + op.reference_impl(alpha, A_ref.data(), beta, B_ref.data()); + + auto A = gpu_rt->create_buffer(A_ref.size() * sizeof(typename T::A_type)); + auto B = gpu_rt->create_buffer(B_ref.size() * sizeof(typename T::B_type)); + gpu_rt->memcpy_h2d(A, A_ref.data(), A_ref.size() * sizeof(typename T::A_type)); + gpu_rt->fill_buffer(B, 0, B_ref.size() * sizeof(typename T::B_type)); + + auto bundle = gpu_rt->get_kernel_bundle(op.make_prog()); + auto kernel = gpu_rt->get_kernel(bundle, T::kernel_name); + + std::uint32_t i = 0; + gpu_rt->set_arg(kernel, i++, sizeof(typename T::alpha_type), &alpha); + gpu_rt->set_mem_arg(kernel, i++, A, auto_mem_type_v); + set_dope_vector(*gpu_rt, kernel, op.lA(), i); + gpu_rt->set_arg(kernel, i++, sizeof(typename T::beta_type), &beta); + gpu_rt->set_mem_arg(kernel, i++, B, auto_mem_type_v); + set_dope_vector(*gpu_rt, kernel, op.lB(), i); + gpu_rt->submit(kernel); + gpu_rt->synchronize(); + + auto B_host = std::vector(B_ref.size()); + gpu_rt->memcpy_d2h(B_host.data(), B, B_host.size() * sizeof(typename T::B_type)); + + compare_data(B_host, B_ref); + + gpu_rt->free_buffer(A); + gpu_rt->free_buffer(B); +} + template void test_blas_a3(T op, typename T::alpha_type alpha, typename T::beta_type beta) { auto gpu_rt = std::make_shared(); @@ -33,31 +114,8 @@ void test_blas_a3(T op, typename T::alpha_type alpha, typename T::beta_type beta } } - auto const make_test_data = [](std::size_t size) { - auto data = std::vector(size); - for (std::size_t i = 0; i < data.size(); ++i) { - constexpr std::size_t prime = 101; - if constexpr (is_complex_v) { - data[i] = ScalarT{static_cast((2 * i) % prime), - static_cast((2 * i + 1) % prime)}; - } else { - data[i] = static_cast(i % prime); - } - } - return data; - }; - auto const compare_data = [](std::vector const &A, - std::vector const &B) { - REQUIRE(A.size() == B.size()); - for (std::size_t i = 0; i < A.size(); ++i) { - constexpr auto eps = - 10.0 * std::numeric_limits::epsilon(); - REQUIRE(std::abs(A[i] - B[i]) == doctest::Approx(0.0).epsilon(eps)); - } - }; - - auto A_ref = make_test_data.template operator()(op.lA().size()); - auto B_ref = make_test_data.template operator()(op.lB().size()); + auto A_ref = make_test_data(op.lA().size()); + auto B_ref = make_test_data(op.lB().size()); auto C_ref = std::vector(op.lC().size()); op.reference_impl(alpha, A_ref.data(), B_ref.data(), beta, C_ref.data()); @@ -72,30 +130,15 @@ void test_blas_a3(T op, typename T::alpha_type alpha, typename T::beta_type beta auto bundle = gpu_rt->get_kernel_bundle(op.make_prog()); auto kernel = gpu_rt->get_kernel(bundle, T::kernel_name); - auto const set_dope_vector = [&](tensor_layout const &layout, std::uint32_t &arg_index) { - for (std::size_t i = 0; i < layout.shape().size(); ++i) { - if (layout.static_shape(i) == dynamic) { - std::int64_t s = layout.shape(i); - gpu_rt->set_arg(kernel, arg_index++, sizeof(s), &s); - } - } - for (std::size_t i = 0; i < layout.stride().size(); ++i) { - if (layout.static_stride(i) == dynamic) { - std::int64_t s = layout.stride(i); - gpu_rt->set_arg(kernel, arg_index++, sizeof(s), &s); - } - } - }; - std::uint32_t i = 0; gpu_rt->set_arg(kernel, i++, sizeof(typename T::alpha_type), &alpha); gpu_rt->set_mem_arg(kernel, i++, A, auto_mem_type_v); - set_dope_vector(op.lA(), i); + set_dope_vector(*gpu_rt, kernel, op.lA(), i); gpu_rt->set_mem_arg(kernel, i++, B, auto_mem_type_v); - set_dope_vector(op.lB(), i); + set_dope_vector(*gpu_rt, kernel, op.lB(), i); gpu_rt->set_arg(kernel, i++, sizeof(typename T::beta_type), &beta); gpu_rt->set_mem_arg(kernel, i++, C, auto_mem_type_v); - set_dope_vector(op.lC(), i); + set_dope_vector(*gpu_rt, kernel, op.lC(), i); gpu_rt->submit(kernel); gpu_rt->synchronize(); diff --git a/test/linalg_types.hpp b/test/linalg_types.hpp index bf39c943..af8dc803 100644 --- a/test/linalg_types.hpp +++ b/test/linalg_types.hpp @@ -6,6 +6,8 @@ #include "tinytc/tinytc.hpp" +#include +#include #include #include @@ -18,7 +20,9 @@ class tensor_layout { array_view static_stride = {}); inline auto dim() const -> std::int64_t { return shape_.size(); } - inline auto size() const -> std::int64_t { return stride_.back() * shape_.back(); } + inline auto size() const -> std::int64_t { + return dim() > 0 ? stride_.back() * shape_.back() : 1; + } inline auto shape() const -> array_view { return {shape_}; } inline auto shape(std::size_t i) const { return shape_[i]; } inline auto stride() const -> array_view { return {stride_}; } @@ -34,6 +38,45 @@ class tensor_layout { std::vector shape_, stride_, static_shape_, static_stride_; }; +template +concept op_blas_a2 = requires(T op, typename T::alpha_type alpha, typename T::beta_type beta, + typename T::A_type const *A_ref, typename T::B_type *B_ref) { + typename T::alpha_type; + typename T::A_type; + typename T::beta_type; + typename T::B_type; + T::kernel_name; + { op.lA() } -> std::same_as; + { op.lB() } -> std::same_as; + { op.make_prog() } -> std::same_as; + op.reference_impl(alpha, A_ref, beta, B_ref); +}; + +template +concept op_blas_a3 = requires(T op, typename T::alpha_type alpha, typename T::beta_type beta, + typename T::A_type const *A_ref, typename T::B_type const *B_ref, + typename T::C_type *C_ref) { + typename T::alpha_type; + typename T::A_type; + typename T::B_type; + typename T::beta_type; + typename T::C_type; + T::kernel_name; + { op.lA() } -> std::same_as; + { op.lB() } -> std::same_as; + { op.lC() } -> std::same_as; + { op.make_prog() } -> std::same_as; + op.reference_impl(alpha, A_ref, B_ref, beta, C_ref); +}; + +inline auto make_index_2d(transpose t, std::int64_t m, std::int64_t n) { + auto idx = std::array{m, n}; + if (t == transpose::T) { + std::swap(idx[0], idx[1]); + } + return idx; +}; + } // namespace tinytc::test #endif // LINALG_TYPES_20241023_HPP From 2b71c9f1c0b1d3489b227937448bf2f0b5bd2a26 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 12:06:36 +0200 Subject: [PATCH 078/297] Add reduce_add work group collective and implement 1d sum Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.yaml | 3 ++ docs/api/builder_cxxapi.yaml | 3 ++ docs/manual/tensor-ir.rst | 24 ++++++++++++ include/tinytc/tinytc.h | 21 ++++++++++ include/tinytc/tinytc.hpp | 21 ++++++++++ include/tinytc/types.h | 5 +++ include/tinytc/types.hpp | 3 ++ src/inst.cpp | 22 +++++++++++ src/node/inst_node.cpp | 64 +++++++++++++++++++++++++++++++ src/node/inst_node.hpp | 66 ++++++++------------------------ src/parser/lexer.re | 25 ++++++------ src/parser/parser_impl.yy | 19 +++++++++ src/pass/convert_to_opencl.cpp | 21 ++++++++++ src/pass/convert_to_opencl.hpp | 1 + src/pass/dump_ir.cpp | 8 ++++ src/pass/dump_ir.hpp | 1 + src/pass/lower_linalg.cpp | 31 ++++++++++++++- test/codegen/coopmatrix_store.ir | 24 ++++++------ test/codegen/work_group.ir | 16 ++++++++ 19 files changed, 303 insertions(+), 75 deletions(-) create mode 100644 test/codegen/work_group.ir diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 5aaa570c..95f35ff7 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -12,6 +12,7 @@ Builder C-API: - tinytc_scalar_type_t - tinytc_store_flag_t - tinytc_transpose_t + - tinytc_work_group_operation_t define: - TINYTC_DYNAMIC function: @@ -25,6 +26,7 @@ Builder C-API: - tinytc_scalar_type_to_string - tinytc_store_flag_to_string - tinytc_transpose_to_string + - tinytc_work_group_operation_to_string struct: - tinytc_position - tinytc_location @@ -94,6 +96,7 @@ Builder C-API: - tinytc_subgroup_size_inst_create - tinytc_subview_inst_create - tinytc_sum_inst_create + - tinytc_work_group_inst_create - tinytc_yield_inst_create - tinytc_inst_get_regions - tinytc_inst_get_values diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 0797a739..34db5874 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -11,6 +11,7 @@ Builder C++-API: - tinytc::scalar_type - tinytc::store_flag - tinytc::transpose + - tinytc::work_group_operation function: - tinytc::is_dynamic_value - tinytc::to_string(address_space) @@ -22,6 +23,7 @@ Builder C++-API: - tinytc::to_string(scalar_type) - tinytc::to_string(store_flag) - tinytc::to_string(transpose) + - tinytc::to_string(work_group_operation) - tinytc::size class: - tinytc::builder_error @@ -86,6 +88,7 @@ Builder C++-API: - tinytc::make_subgroup_size - tinytc::make_subview - tinytc::make_sum + - tinytc::make_work_group - tinytc::make_yield class: - tinytc::inst diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 5c989c71..39965758 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -1369,6 +1369,30 @@ Arguments The first operand must have the same scalar type as the memref type. The indices must be of ``index`` type. +Work group collectives +...................... + +.. code:: abnf + + value-instruction =/ "work_group" work-group-op local-identifier ":" scalar-type + work-group-op = ".reduce_add" + +Overview +~~~~~~~~ + +Collective operations across a work-group. + +============= ================================================================ +Work group op Description +============= ================================================================ +.reduce_add Compute work group sum of value +============= ================================================================ + +Restrictions +~~~~~~~~~~~~ + +The work group collective must be encountered by all work-items. + Yield ..... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 37e86284..36da020d 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -180,6 +180,8 @@ TINYTC_EXPORT char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u); TINYTC_EXPORT char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag); //! Convert transpose to string TINYTC_EXPORT char const *tinytc_transpose_to_string(tinytc_transpose_t t); +//! Convert work group operation to string +TINYTC_EXPORT char const *tinytc_work_group_operation_to_string(tinytc_work_group_operation_t op); /** * @brief Create arithmetic instruction (binary) @@ -861,6 +863,25 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); +/** + * @brief Create work group instruction + * + * @code + * %value = work_group work_group_op %operand : type(%operand) + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param operation [in] Work group operation + * @param operand [in] operand + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_work_group_inst_create(tinytc_inst_t *instr, + tinytc_work_group_operation_t operation, + tinytc_value_t operand, + const tinytc_location_t *loc); + /** * @brief Create yield instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 816a4536..4a2f8ee3 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -748,6 +748,17 @@ inline char const *to_string(transpose t) { return ::tinytc_transpose_to_string(static_cast(t)); } +/** + * @brief Convert work group operation to string + * + * @param op Operation + * + * @return C-string + */ +inline char const *to_string(work_group_operation op) { + return ::tinytc_work_group_operation_to_string(static_cast(op)); +} + namespace internal { template <> struct unique_handle_traits { static void destroy(tinytc_inst_t handle) { return tinytc_inst_destroy(handle); } @@ -1531,6 +1542,16 @@ inline inst make_if(value condition, array_view return_type_list = {} return inst(instr); } +inline inst make_work_group(work_group_operation operation, value operand, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_work_group_inst_create(&instr, static_cast(operation), + operand, &loc), + loc); + return inst(instr); +} + /** * @brief Make yield instruction * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 49e6b15c..410621ca 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -287,6 +287,11 @@ typedef enum { tinytc_transpose_T = 1 ///< Transpose } tinytc_transpose_t; +//! Work group collectives +typedef enum { + tinytc_work_group_operation_reduce_add = 0 ///< Reduction (add) +} tinytc_work_group_operation_t; + //! Address space typedef enum { tinytc_address_space_global = 0x1, ///< Global memory diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index d1bf425e..9c02e5fb 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -267,6 +267,9 @@ enum class transpose { T = tinytc_transpose_T ///< transpose }; +//! Work group collectives +enum class work_group_operation { reduce_add = tinytc_work_group_operation_reduce_add }; + //! Address space enum class address_space { global = tinytc_address_space_global, ///< Global memory diff --git a/src/inst.cpp b/src/inst.cpp index 1712ac77..295f2064 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -134,6 +134,14 @@ char const *tinytc_transpose_to_string(tinytc_transpose_t t) { return "unknown"; } +char const *tinytc_work_group_operation_to_string(tinytc_work_group_operation_t op) { + switch (op) { + case tinytc_work_group_operation_reduce_add: + return "reduce_add"; + } + return "unknown"; +} + tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, tinytc_value_t a, tinytc_value_t b, const tinytc_location_t *loc) { @@ -639,6 +647,20 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi }); } +tinytc_status_t tinytc_work_group_inst_create(tinytc_inst_t *instr, + tinytc_work_group_operation_t operation, + tinytc_value_t operand, + const tinytc_location_t *loc) { + if (instr == nullptr || operand == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(enum_cast(operation), + operand, get_optional(loc)) + .release(); + }); +} + tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, const tinytc_value_t *yield_list, const tinytc_location_t *loc) { diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 1b677ad7..cd1de1c5 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -25,6 +25,57 @@ auto tinytc_inst::context() const -> tinytc_compiler_context_t { return nullptr; } +auto tinytc_inst::kind() const -> tinytc::inst_execution_kind { + switch (type_id()) { + case tinytc::IK::alloca: + case tinytc::IK::barrier: + case tinytc::IK::lifetime_stop: + case tinytc::IK::foreach_loop: + case tinytc::IK::parallel: + case tinytc::IK::blas_a2: + case tinytc::IK::axpby_blas_a2: + case tinytc::IK::sum_blas_a2: + case tinytc::IK::last_blas_a2: + case tinytc::IK::blas_a3: + case tinytc::IK::gemm_blas_a3: + case tinytc::IK::gemv_blas_a3: + case tinytc::IK::ger_blas_a3: + case tinytc::IK::hadamard_blas_a3: + case tinytc::IK::last_blas_a3: + return tinytc::inst_execution_kind::collective; + case tinytc::IK::arith: + case tinytc::IK::arith_unary: + case tinytc::IK::cast: + case tinytc::IK::compare: + case tinytc::IK::constant: + case tinytc::IK::cooperative_matrix_load: + case tinytc::IK::cooperative_matrix_mul_add: + case tinytc::IK::cooperative_matrix_scale: + case tinytc::IK::cooperative_matrix_store: + case tinytc::IK::expand: + case tinytc::IK::fuse: + case tinytc::IK::load: + case tinytc::IK::group_id: + case tinytc::IK::group_size: + case tinytc::IK::if_: + case tinytc::IK::num_subgroups: + case tinytc::IK::size: + case tinytc::IK::subgroup_size: + case tinytc::IK::subview: + case tinytc::IK::store: + case tinytc::IK::work_group: + case tinytc::IK::yield: + case tinytc::IK::loop: + case tinytc::IK::for_loop: + case tinytc::IK::last_loop: + return tinytc::inst_execution_kind::mixed; + case tinytc::IK::subgroup_id: + case tinytc::IK::subgroup_local_id: + return tinytc::inst_execution_kind::spmd; + }; + throw tinytc::internal_compiler_error(); +} + namespace tinytc { scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { @@ -900,6 +951,19 @@ sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinyt } } +work_group_inst::work_group_inst(work_group_operation operation, tinytc_value_t operand0, + location const &lc) + : standard_inst{IK::work_group}, operation_(operation) { + loc(lc); + op(0, operand0); + + if (!isa(*(operand().ty()))) { + throw compilation_error(loc(), status::ir_expected_scalar); + } + + result(0) = value_node{operand().ty(), this, lc}; +} + yield_inst::yield_inst(array_view vals, location const &lc) : standard_inst{IK::yield, static_cast(vals.size())} { loc(lc); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 5065f558..b62d0e38 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -64,6 +64,7 @@ enum class IK { subgroup_size, subview, store, + work_group, yield, // blas a2 blas_a2, @@ -94,7 +95,7 @@ using inst_nodes = class if_inst, class num_subgroups_inst, class parallel_inst, class size_inst, class subview_inst, class store_inst, class subgroup_id_inst, class subgroup_local_id_inst, class subgroup_size_inst, class sum_inst, - class yield_inst>; + class work_group_inst, class yield_inst>; using result_range = iterator_range_wrapper; using const_result_range = iterator_range_wrapper; @@ -186,55 +187,7 @@ struct tinytc_inst : tinytc::ilist_node_with_parent return child_regions_end_ - child_regions_begin_; } - inline auto kind() const -> tinytc::inst_execution_kind { - switch (type_id()) { - case tinytc::IK::alloca: - case tinytc::IK::barrier: - case tinytc::IK::lifetime_stop: - case tinytc::IK::foreach_loop: - case tinytc::IK::parallel: - case tinytc::IK::blas_a2: - case tinytc::IK::axpby_blas_a2: - case tinytc::IK::sum_blas_a2: - case tinytc::IK::last_blas_a2: - case tinytc::IK::blas_a3: - case tinytc::IK::gemm_blas_a3: - case tinytc::IK::gemv_blas_a3: - case tinytc::IK::ger_blas_a3: - case tinytc::IK::hadamard_blas_a3: - case tinytc::IK::last_blas_a3: - return tinytc::inst_execution_kind::collective; - case tinytc::IK::arith: - case tinytc::IK::arith_unary: - case tinytc::IK::cast: - case tinytc::IK::compare: - case tinytc::IK::constant: - case tinytc::IK::cooperative_matrix_load: - case tinytc::IK::cooperative_matrix_mul_add: - case tinytc::IK::cooperative_matrix_scale: - case tinytc::IK::cooperative_matrix_store: - case tinytc::IK::expand: - case tinytc::IK::fuse: - case tinytc::IK::load: - case tinytc::IK::group_id: - case tinytc::IK::group_size: - case tinytc::IK::if_: - case tinytc::IK::num_subgroups: - case tinytc::IK::size: - case tinytc::IK::subgroup_size: - case tinytc::IK::subview: - case tinytc::IK::store: - case tinytc::IK::yield: - case tinytc::IK::loop: - case tinytc::IK::for_loop: - case tinytc::IK::last_loop: - return tinytc::inst_execution_kind::mixed; - case tinytc::IK::subgroup_id: - case tinytc::IK::subgroup_local_id: - return tinytc::inst_execution_kind::spmd; - }; - throw tinytc::internal_compiler_error(); - } + auto kind() const -> tinytc::inst_execution_kind; protected: inline auto set_op_range(tinytc::use *begin, tinytc::use *end) noexcept { @@ -833,6 +786,19 @@ class sum_inst : public blas_a2_inst { transpose tA_; }; +class work_group_inst : public standard_inst<1, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::work_group; } + work_group_inst(work_group_operation operation, tinytc_value_t operand, + location const &lc = {}); + + inline auto operation() const -> work_group_operation { return operation_; } + inline auto operand() const -> tinytc_value const & { return op(0); } + + private: + work_group_operation operation_; +}; + class yield_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::yield; } diff --git a/src/parser/lexer.re b/src/parser/lexer.re index e67c1e23..fdb80368 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -178,6 +178,7 @@ lex: "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } "store" { adv_loc(); return parser::make_STORE(loc_); } "sum" { adv_loc(); return parser::make_SUM(loc_); } + "work_group" { adv_loc(); return parser::make_WORK_GROUP(loc_); } "yield" { adv_loc(); return parser::make_YIELD(loc_); } // binary op @@ -201,18 +202,18 @@ lex: ".re" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::re, loc_); } // comparison condition - ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, - loc_); } - ".ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, - loc_); } - ".gt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::gt, - loc_); } - ".ge" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ge, - loc_); } - ".lt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::lt, - loc_); } - ".le" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::le, - loc_); } + ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, loc_); } + ".ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, loc_); } + ".gt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::gt, loc_); } + ".ge" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ge, loc_); } + ".lt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::lt, loc_); } + ".le" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::le, loc_); } + + // work group operation + ".reduce_add" { + adv_loc(); + return parser::make_WORK_GROUP_OPERATION(work_group_operation::reduce_add, loc_); + } whitespace { adv_loc(); goto lex; } comment { adv_loc(); goto lex; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index b83421d7..70938bf5 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -138,6 +138,7 @@ SUBVIEW "subview" STORE "store" SUM "sum" + WORK_GROUP "work_group" YIELD "yield" ; %token > LOCAL_IDENTIFIER @@ -149,6 +150,7 @@ %token ARITHMETIC %token ARITHMETIC_UNARY %token CMP_CONDITION +%token WORK_GROUP_OPERATION %token MATRIX_USE %token CHECKED @@ -232,6 +234,7 @@ %nterm >> slice_list %nterm > slice %nterm slice_size +%nterm work_group_inst %% prog: @@ -788,6 +791,7 @@ valued_inst: | subgroup_local_id_inst { $$ = std::move($1); } | subgroup_size_inst { $$ = std::move($1); } | subview_inst { $$ = std::move($1); } + | work_group_inst { $$ = std::move($1); } ; alloca_inst: @@ -1265,6 +1269,21 @@ slice_size: | COLON integer_constant_or_identifier { $$ = $2; } ; +work_group_inst: + WORK_GROUP WORK_GROUP_OPERATION[operation] var[a] COLON data_type[ty] { + check_type($a, $ty, @a, @ty); + try { + $$ = inst { + std::make_unique($operation, std::move($a), @work_group_inst) + .release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } +; + %% namespace tinytc { diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 0b8f93fb..db55b17d 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1495,6 +1495,27 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) return {bb.get_product()}; } +std::vector convert_to_opencl_pass::operator()(work_group_inst const &in) { + auto const make = [](work_group_operation operation, clir::expr operand, + scalar_type sty) -> clir::expr { + switch (operation) { + case work_group_operation::reduce_add: + if (is_complex_type(sty)) { + return init_vector(to_clir_ty(sty), {clir::work_group_reduce_add(operand.s(0)), + clir::work_group_reduce_add(operand.s(1))}); + } + return clir::work_group_reduce_add(operand); + } + return {}; + }; + + auto lhs = declare(in.result(0)); + auto lhs_ty = visit(*this, *in.result()->ty()); + auto sty = get_scalar_type(in.operand()); + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), + make(in.operation(), val(in.operand()), sty))}; +} + std::vector convert_to_opencl_pass::operator()(yield_inst const &in) { if (yielded_vars_.empty()) { throw compilation_error(in.loc(), status::ir_unexpected_yield); diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index f8c82c64..f109f44e 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -101,6 +101,7 @@ class convert_to_opencl_pass { std::vector operator()(subview_inst const &s); std::vector operator()(store_inst const &s); std::vector operator()(sum_inst const &s); + std::vector operator()(work_group_inst const &in); std::vector operator()(yield_inst const &in); auto run_on_program(program_node const &p) -> clir::prog; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index ab6723a7..85d359cb 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -507,6 +507,14 @@ void dump_ir_pass::operator()(sum_inst const &a) { dump_blas_a2(static_cast(a)); } +void dump_ir_pass::operator()(work_group_inst const &in) { + dump_val(in.result(0)); + *os_ << " = work_group." << to_string(in.operation()) << " "; + dump_val(in.operand()); + *os_ << " : "; + visit(*this, *in.operand().ty()); +} + void dump_ir_pass::operator()(yield_inst const &y) { *os_ << "yield "; if (y.num_operands() > 0) { diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index fa6275f5..4e6dcf18 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -63,6 +63,7 @@ class dump_ir_pass { void operator()(subview_inst const &s); void operator()(store_inst const &s); void operator()(sum_inst const &s); + void operator()(work_group_inst const &in); void operator()(yield_inst const &y); void run_on_function(function_node const &fn); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 9e6e1849..131e4194 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -401,6 +401,7 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { auto bb = region_builder{body}; auto ctx = compiler_context{in.alpha().context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); auto index_ty = get_scalar(ctx, scalar_type::index); auto bt = get_memref_type(in.B()); @@ -408,7 +409,35 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); if (bt->dim() == 0) { - // @todo + auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); + auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + auto m = bb.add(make_subgroup_local_id(ctx, in.loc())); + auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, in.loc())); + auto from1 = bb.add(make_arith(arithmetic::add, from0, m, in.loc())); + auto from_index = bb.add(make_cast(from1, index_ty, in.loc())); + + auto c_zero = bb.add(make_constant_zero(i32_ty, in.loc())); + auto is_from_0 = bb.add(make_cmp(cmp_condition::eq, from1, c_zero, in.loc())); + + auto c_trip_count = instant_constant_fold_add(bb, make_size(&in.A(), 0, in.loc())); + auto c_step = bb.add(make_constant( + core_cfg_.subgroup_size * tiling_.m_tiles() * tiling_.n_tiles(), index_ty, in.loc())); + auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), in.loc())); + + auto acc = bb.for_loop(from_index, c_trip_count, c_step, {c_init}, {bt->element_data_ty()}, + index_ty, [&](region_builder &bb, array_view args) { + auto a = bb.add(make_load(&in.A(), {args[0]}, in.loc())); + auto sum = mixed_precision_arithmetic(bb, arithmetic::add, + args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); + auto sum = bb.add(make_work_group(work_group_operation::reduce_add, acc[0], in.loc())); + bb.if_condition( + is_from_0, + [&](region_builder &bb) { + blas_update(bb, in.atomic(), &in.alpha(), sum, &in.beta(), &in.B(), {}, in.loc()); + }, + {}, in.loc()); } else if (bt->dim() == 1) { auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); auto c_trip_count = instant_constant_fold_add( diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir index e49bb80e..5dda1e26 100644 --- a/test/codegen/coopmatrix_store.ir +++ b/test/codegen/coopmatrix_store.ir @@ -7,8 +7,8 @@ func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup cooperative_matrix_store %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_n({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; } func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -19,8 +19,8 @@ func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: in ; CHECK-NEXT: long x2 = 64 - x; ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; ; CHECK-NEXT: } } @@ -32,10 +32,10 @@ func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: in ; CHECK-NEXT: long x2 = 64 - x; ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; ; CHECK-NEXT: } } @@ -48,10 +48,10 @@ func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 0 * 64] = c0[0]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: x1[1 * (get_sub_group_local_id() + 0) + 1 * 64] = c0[1]; +; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; ; CHECK-NEXT: } ; CHECK-NEXT: } } @@ -61,8 +61,8 @@ func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) cooperative_matrix_store.atomic_add %c0, %A[%x,%y] : coopmatrix, memref ; CHECK-LABEL: void coopmatrix_a_store_atomic_add({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); } func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { @@ -74,10 +74,10 @@ func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y ; CHECK-NEXT: long x3 = 48 - y; ; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { ; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); ; CHECK-NEXT: } ; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) x1 + (1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); ; CHECK-NEXT: } ; CHECK-NEXT: } } diff --git a/test/codegen/work_group.ir b/test/codegen/work_group.ir new file mode 100644 index 00000000..f182cdc1 --- /dev/null +++ b/test/codegen/work_group.ir @@ -0,0 +1,16 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s +func @t1() { + %0 = constant 1.0 -> f32 + %1 = work_group.reduce_add %0 : f32 +; CHECK-LABEL: void t1({{.*}} +; CHECK: float x1 = work_group_reduce_add(x); +} +func @t2() { + %0 = constant [1.0, 0.0] -> c32 + %1 = work_group.reduce_add %0 : c32 +; CHECK-LABEL: void t2({{.*}} +; CHECK: float2 x1 = (float2) (work_group_reduce_add(x.x), work_group_reduce_add(x.y)); +} From 144328f4fbc99344cd78913d4a9bab3b2859d7f8 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 12:19:00 +0200 Subject: [PATCH 079/297] Simplify for loop builder Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 6 ++---- include/tinytc/tinytc.hpp | 25 ++++++++----------------- src/inst.cpp | 9 ++++----- src/node/inst_node.cpp | 26 +++++++++----------------- src/node/inst_node.hpp | 16 ++++++++++------ src/parser/parser_impl.yy | 5 ++++- src/pass/lower_linalg.cpp | 11 +++++------ 7 files changed, 42 insertions(+), 56 deletions(-) diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 36da020d..8add0f9f 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -792,7 +792,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @brief Create for loop * * @code - * for %loop_var = %from, %to, %step init(initial_value_list) : loop_var_type { } + * for %loop_var = %from, %to, %step + * init(initial_value_list) -> (types(initial_value_list)) : loop_var_type { } * ; loop_var_type == type(%from) * ; loop_var_type == type(%to) * ; loop_var_type == type(%step) @@ -805,8 +806,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @param init_list_size [in] length of init_value_list and return_type_list * @param init_value_list [in][range(0, init_list_size)] array of initial values; can be * nullptr if init_value_list is 0 - * @param return_type_list [in][range(0, init_list_size)] return type array; can be nullptr - * if return_type_list_size is 0 * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * @@ -816,7 +815,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt tinytc_value_t to, tinytc_value_t step, uint32_t init_list_size, const tinytc_value_t *init_value_list, - const tinytc_data_type_t *return_type_list, tinytc_data_type_t loop_var_type, const tinytc_location_t *loc); diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 4a2f8ee3..c7c627da 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1481,26 +1481,20 @@ inline inst make_sum(transpose tA, bool atomic, value alpha, value A, value beta * @param to Loop variable bound * @param step Loop variable step; can be {} * @param initial_value_list Array of initial values; can be {} - * @param return_type_list Array of returned types; can be {} * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ inline inst make_for(value from, value to, value step, array_view initial_value_list, - array_view return_type_list, data_type loop_var_type, - location const &loc = {}) { + data_type loop_var_type, location const &loc = {}) { tinytc_inst_t instr; - if (initial_value_list.size() != return_type_list.size()) { - throw builder_error(status::ir_init_return_mismatch, loc); - } - auto len = return_type_list.size(); + auto len = initial_value_list.size(); if (len > std::numeric_limits::max()) { - throw std::out_of_range("return type list too long"); + throw std::out_of_range("initial value list too long"); } const tinytc_value_t *il = reinterpret_cast(initial_value_list.data()); - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, len, il, - return_type_list.data(), loop_var_type, &loc), + CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, len, il, loop_var_type, &loc), loc); return inst(instr); } @@ -1778,7 +1772,7 @@ class region_builder { template void for_loop(value from, value to, value step, data_type loop_var_ty, F &&f, location const &loc = {}) { - auto fi = ::tinytc::make_for(from, to, step, {}, {}, loop_var_ty, loc); + auto fi = ::tinytc::make_for(from, to, step, {}, loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); auto loop_var = value{}; @@ -1801,23 +1795,20 @@ class region_builder { * @param to Loop variable bound * @param step Loop variable step * @param initial_value_list Array of initial values; can be {} - * @param return_type_list Array of returned types; can be {} * @param loop_var_ty Type of loop variable * @param f Functor * @param loc Source code location */ template auto for_loop(value from, value to, value step, array_view initial_value_list, - array_view return_type_list, data_type loop_var_ty, F &&f, - location const &loc = {}) -> std::vector { - auto fi = ::tinytc::make_for(from, to, step, initial_value_list, return_type_list, - loop_var_ty, loc); + data_type loop_var_ty, F &&f, location const &loc = {}) -> std::vector { + auto fi = ::tinytc::make_for(from, to, step, initial_value_list, loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); auto num_params = reg.get_parameters({}); auto params = std::vector(num_params); reg.get_parameters(params); - if (!reg || num_params != 1 + return_type_list.size()) { + if (!reg || num_params != 1 + initial_value_list.size()) { throw status::internal_compiler_error; } auto results = add_multivalued(std::move(fi)); diff --git a/src/inst.cpp b/src/inst.cpp index 295f2064..9101168a 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -604,17 +604,16 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, uint32_t init_list_size, const tinytc_value_t *initial_value_list, - const tinytc_data_type_t *return_type_list, tinytc_data_type_t loop_var_type, const tinytc_location_t *loc) { if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr || - (init_list_size != 0 && (initial_value_list == nullptr || return_type_list == nullptr))) { + (init_list_size != 0 && initial_value_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique( - from, to, step, array_view{initial_value_list, init_list_size}, - array_view{return_type_list, init_list_size}, loop_var_type, get_optional(loc)) + *instr = std::make_unique(from, to, step, + array_view{initial_value_list, init_list_size}, + loop_var_type, get_optional(loc)) .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index cd1de1c5..5c936006 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -148,14 +148,10 @@ blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinyt } loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, - array_view init_values, - array_view return_types, tinytc_data_type_t loop_var_type, + array_view init_values, tinytc_data_type_t loop_var_type, location const &lc) : standard_inst{tid, (step0 ? 3 : 2) + static_cast(init_values.size()), - static_cast(return_types.size())} { - if (init_values.size() != return_types.size()) { - throw compilation_error(loc(), status::ir_init_return_mismatch); - } + static_cast(init_values.size())} { op(op_from, from0); op(op_to, to0); @@ -163,20 +159,17 @@ loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_va op(op_step, step0); } - body().set_num_params(1 + return_types.size()); + body().set_num_params(1 + init_values.size()); body().set_param(0, loop_var_type, lc); - for (std::size_t i = 0; i < return_types.size(); ++i) { - body().set_param(1 + i, return_types[i], lc); - result(i) = value_node{return_types[i], this, lc}; + for (std::size_t i = 0; i < init_values.size(); ++i) { + body().set_param(1 + i, init_values[i]->ty(), lc); + result(i) = value_node{init_values[i]->ty(), this, lc}; } for (std::size_t i = 0; i < init_values.size(); ++i) { - if (!isa(*return_types[i]) && - !isa(*return_types[i])) { + if (!isa(*init_values[i]->ty()) && + !isa(*init_values[i]->ty())) { throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); } - if (init_values[i]->ty() != return_types[i]) { - throw compilation_error(loc(), status::ir_init_return_mismatch); - } op(op_init() + i, init_values[i]); } loc(lc); @@ -791,8 +784,7 @@ ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, foreach_inst::foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, location const &loc) - : loop_inst{ - IK::foreach_loop, std::move(from), std::move(to), nullptr, {}, {}, loop_var_type, loc} { + : loop_inst{IK::foreach_loop, std::move(from), std::move(to), nullptr, {}, loop_var_type, loc} { child_region(0).kind(region_kind::spmd); } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index b62d0e38..0840c99d 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -334,8 +334,8 @@ class loop_inst : public standard_inst { } enum op_number { op_from = 0, op_to = 1, op_step = 2 }; loop_inst(IK tid, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - array_view init_values, array_view return_types, - tinytc_data_type_t loop_var_type, location const &loc = {}); + array_view init_values, tinytc_data_type_t loop_var_type, + location const &loc = {}); inline auto from() const -> tinytc_value const & { return op(op_from); } inline auto to() const -> tinytc_value const & { return op(op_to); } inline auto has_step() const -> bool { return op_init() == 3; } @@ -639,11 +639,15 @@ class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - array_view init_values, - array_view return_types, tinytc_data_type_t loop_var_type, + array_view init_values, tinytc_data_type_t loop_var_type, location const &loc = {}) - : loop_inst{IK::for_loop, std::move(from), std::move(to), std::move(step), - std::move(init_values), std::move(return_types), loop_var_type, loc} {} + : loop_inst{IK::for_loop, + std::move(from), + std::move(to), + std::move(step), + std::move(init_values), + loop_var_type, + loc} {} }; class foreach_inst : public loop_inst { diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 70938bf5..e56ce9f9 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -606,10 +606,13 @@ for_inst: if (lcv_init.size() != lcv_type.size()) { throw parser::syntax_error(@lcv, "Length of init value list must match scalar type list"); } + for (std::size_t i = 0; i < lcv_init.size(); ++i) { + check_type(lcv_init[i], lcv_type[i], @lcv, @lcv); + } location loc = @FOR; loc.end = @for_loop_var_type.end; auto inode = std::make_unique($from, $to, $optional_step, lcv_init, - lcv_type, $for_loop_var_type, loc); + $for_loop_var_type, loc); ctx.push_scope(); auto &loop_var = inode->loop_var(); ctx.val($loop_var, loop_var, @loop_var); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 131e4194..f006f132 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -54,8 +54,7 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi value c_init) -> value { auto c_step = bb.add(make_constant(k_block_size, index_ty, loc)); auto return_values = bb.for_loop( - K0, K1, c_step, {c_init}, {coopmatrix_c_ty}, index_ty, - [&](region_builder &bb, array_view p) { + K0, K1, c_step, {c_init}, index_ty, [&](region_builder &bb, array_view p) { const auto k = p[0]; value pos_a[2] = {m_block, k}; @@ -355,7 +354,7 @@ auto linalg_generator::operator()(gemv_inst &in) -> inst { auto c_step = bb.add(make_constant(1, index_ty)); auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); auto c_acc = bb.for_loop( - c_zero, K, c_step, {c_init}, {ct->element_data_ty()}, index_ty, + c_zero, K, c_step, {c_init}, index_ty, [&](region_builder &bb, array_view p) { auto a_idx = std::array{mm, p[0]}; if (in.tA() == transpose::T) { @@ -424,8 +423,8 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { core_cfg_.subgroup_size * tiling_.m_tiles() * tiling_.n_tiles(), index_ty, in.loc())); auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), in.loc())); - auto acc = bb.for_loop(from_index, c_trip_count, c_step, {c_init}, {bt->element_data_ty()}, - index_ty, [&](region_builder &bb, array_view args) { + auto acc = bb.for_loop(from_index, c_trip_count, c_step, {c_init}, index_ty, + [&](region_builder &bb, array_view args) { auto a = bb.add(make_load(&in.A(), {args[0]}, in.loc())); auto sum = mixed_precision_arithmetic(bb, arithmetic::add, args[1], a, in.loc()); @@ -448,7 +447,7 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { auto from = bb.add(make_constant(0, index_ty)); auto zero = bb.add(make_constant_zero(bt->element_data_ty())); auto acc = - bb.for_loop(from, c_trip_count, {}, {zero}, {bt->element_data_ty()}, index_ty, + bb.for_loop(from, c_trip_count, {}, {zero}, index_ty, [&](region_builder &bb, array_view args) { auto index_list = std::array{mm, args[0]}; if (in.tA() == transpose::T) { From 1f40226757f6f09b4fa730a228acfdef4af41032 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 12:37:39 +0200 Subject: [PATCH 080/297] Update atomic test Signed-off-by: Carsten Uphoff --- test/codegen/atomic.ir | 62 +++++++----------------------------------- 1 file changed, 10 insertions(+), 52 deletions(-) diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index 37f81fd7..e5ea5c93 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -12,58 +12,16 @@ func @atomic_store(%A: memref) { ; CHECK: atomic_fetch_add_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); } - -func @axpby_atomic_store(%alpha: f64, %A: memref, %B: memref) { - %zero = constant 0.0 -> f64 - axpby.n.atomic %alpha, %A, %zero, %B : f64, memref, f64, memref -; CHECK: global double* b = B + (blck + m) * 1; -; CHECK-NEXT: atomic_store_explicit((global volatile atomic_double*) b, alpha * A[(blck + m) * 1], memory_order_relaxed, memory_scope_work_group); -} - -func @axpby_atomic_add(%alpha: f32, %A: memref, %B: memref) { - %one = constant 1.0 -> f32 - axpby.n.atomic %alpha, %A, %one, %B : f32, memref, f32, memref -; CHECK: global float* b = Bb + (blck1 + m) * 1; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) b, alpha * Ab[(blck1 + m) * 1], memory_order_relaxed, memory_scope_work_group); -} - -func @gemm_atomic(%A: memref, %B: memref, %C: memref) { - %one = constant 1.0 -> f32 - gemm.n.n.atomic %one, %A, %B, %one, %C - : f32, memref, memref, f32, memref -; CHECK: atomic_fetch_add_explicit((global volatile atomic_float*) (Cb + get_sub_group_local_id()), c[n], memory_order_relaxed, memory_scope_work_group); -} - -func @ger_atomic(%A: memref, %B: memref, %C: memref) { - %one = constant 1.0 -> f32 - ger.atomic %one, %A, %B, %one, %C - : f32, memref, memref, f32, memref -; CHECK: global float* c = Cb + (blck1 + m) * 1; -; CHECK-NEXT: float ab = A[(blck1 + m) * 1] * b; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * ab, memory_order_relaxed, memory_scope_work_group); -} - -func @hadamard_atomic(%A: memref, %B: memref, %C: memref) { - %one = constant 1.0 -> f32 - hadamard.atomic %one, %A, %B, %one, %C - : f32, memref, memref, f32, memref -; CHECK: global float* c = C + (blck + m) * 1; -; CHECK-NEXT: float ab = A[(blck + m) * 1] * B[(blck + m) * 1]; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) c, 0x1p+0f * ab, memory_order_relaxed, memory_scope_work_group); +func @atomic_store_c64(%A: memref) { + %f0 = constant [0.0, 0.0] -> c64 + %i0 = constant 0 -> index + store.atomic %f0, %A[%i0] : memref + store.atomic_add %f0, %A[%i0] : memref +; CHECK-LABEL: void atomic_store_c64({{.*}} +; CHECK: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).x, f0, memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).y, f0, memory_order_relaxed, memory_scope_work_group); +; CHECK: atomic_fetch_add_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).x, f0, memory_order_relaxed, memory_scope_work_group); +; CHECK-NEXT: atomic_fetch_add_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).y, f0, memory_order_relaxed, memory_scope_work_group); } -func @sum_atomic(%A: memref, %B: memref) { - %one = constant 1.0 -> f32 - sum.n.atomic %one, %A, %one, %B : f32, memref, f32, memref -; CHECK: float sum = work_group_reduce_add(acc); -; CHECK-NEXT: if (get_sub_group_id() == 0 && get_sub_group_local_id() == 0) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) B, 0x1p+0f * sum, memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: } -} -func @sum_atomic_matrix(%A: memref, %B: memref) { - %one = constant 1.0 -> f32 - sum.n.atomic %one, %A, %one, %B : f32, memref, f32, memref -; CHECK: global float* b = B + (blck + m) * 1; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) b, 0x1p+0f * acc, memory_order_relaxed, memory_scope_work_group); -} From 9d5db5813e279fcffd85e8248773ec3a70011c49 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 17:06:05 +0200 Subject: [PATCH 081/297] Add core features to tools' argparser Signed-off-by: Carsten Uphoff --- tools/argparser/argparser_common.cpp | 38 ++++++++++++++++++++++++++++ tools/argparser/argparser_common.hpp | 3 +++ tools/offline_compiler/main.cpp | 6 +++++ tools/opt/main.cpp | 6 +++++ 4 files changed, 53 insertions(+) diff --git a/tools/argparser/argparser_common.cpp b/tools/argparser/argparser_common.cpp index f18e64f4..4be44ac9 100644 --- a/tools/argparser/argparser_common.cpp +++ b/tools/argparser/argparser_common.cpp @@ -51,4 +51,42 @@ void list_optimization_flags(std::ostream &os) { os << "unsafe-fp-math" << std::endl; } +void add_core_feature_flags(arg_parser &parser, tinytc_core_feature_flags_t &flags) { + auto const converter = [](char const *str, tinytc_core_feature_flags_t &val) { + bool clear = false; + constexpr char const disable_prefix[] = "no-"; + constexpr std::size_t disable_prefix_len = sizeof(disable_prefix) - 1; + if (std::strncmp(str, disable_prefix, disable_prefix_len) == 0) { + clear = true; + str = str + disable_prefix_len; + } + tinytc_core_feature_flags_t flag = 0; + switch (fnv1a(str, std::strlen(str))) { + case "large-register-file"_fnv1a: + flag = tinytc_core_feature_flag_large_register_file; + break; + default: + return parser_status::invalid_argument; + }; + if (clear) { + val &= ~flag; + } else { + val |= flag; + } + return parser_status::success; + }; + parser + .set_short_opt('F', &flags, + "Enable core feature flag; use \"no-\" prefix to clear feature flag") + .converter(converter); +} + +void list_core_feature_flags(std::ostream &os) { + os << "Core feature flags:" << std::endl; + for (int i = 0; i < arg_parser::optindent; ++i) { + os << ' '; + } + os << "large-register-file" << std::endl; +} + } // namespace tinytc::cmd diff --git a/tools/argparser/argparser_common.hpp b/tools/argparser/argparser_common.hpp index 5ebe42a4..7919bebe 100644 --- a/tools/argparser/argparser_common.hpp +++ b/tools/argparser/argparser_common.hpp @@ -24,6 +24,9 @@ void add_optflag_states(arg_parser &parser, optflag_states &flags); void set_optflags(compiler_context &ctx, optflag_states const &flags); void list_optimization_flags(std::ostream &os); +void add_core_feature_flags(arg_parser &parser, tinytc_core_feature_flags_t &flags); +void list_core_feature_flags(std::ostream &os); + } // namespace tinytc::cmd #endif // ARGPARSER_COMMON_20241010_HPP diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 2a8c5cd0..75b5af72 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -17,6 +17,7 @@ using namespace tinytc; int main(int argc, char **argv) { char const *filename = nullptr; auto info = core_info{}; + tinytc_core_feature_flags_t core_features = 0; std::int32_t opt_level = 2; auto flags = cmd::optflag_states{}; bool help = false; @@ -42,6 +43,7 @@ int main(int argc, char **argv) { parser.add_positional_arg("file-name", &filename, "Path to source code; leave empty to read from stdin"); cmd::add_optflag_states(parser, flags); + cmd::add_core_feature_flags(parser, core_features); parser.parse(argc, argv); } catch (status const &st) { @@ -57,6 +59,9 @@ int main(int argc, char **argv) { std::cout << std::endl; cmd::list_optimization_flags(std::cout); + std::cout << std::endl; + cmd::list_core_feature_flags(std::cout); + return 0; } @@ -65,6 +70,7 @@ int main(int argc, char **argv) { ctx = make_compiler_context(); ctx.set_optimization_level(opt_level); cmd::set_optflags(ctx, flags); + info.set_core_features(core_features); auto p = prog{}; if (!filename) { p = parse_stdin(ctx); diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index aa9d35bc..e858f7ba 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -18,6 +18,7 @@ int main(int argc, char **argv) { auto pass_names = std::vector{}; char const *filename = nullptr; auto info = core_info{}; + tinytc_core_feature_flags_t core_features = 0; std::int32_t opt_level = 2; auto flags = cmd::optflag_states{}; bool help = false; @@ -44,6 +45,7 @@ int main(int argc, char **argv) { parser.add_positional_arg("file-name", &filename, "Path to source code; leave empty to read from stdin"); cmd::add_optflag_states(parser, flags); + cmd::add_core_feature_flags(parser, core_features); parser.parse(argc, argv); } catch (status const &st) { @@ -71,6 +73,9 @@ int main(int argc, char **argv) { std::cout << std::endl; cmd::list_optimization_flags(std::cout); + std::cout << std::endl; + cmd::list_core_feature_flags(std::cout); + return 0; } @@ -83,6 +88,7 @@ int main(int argc, char **argv) { ctx = make_compiler_context(); ctx.set_optimization_level(opt_level); cmd::set_optflags(ctx, flags); + info.set_core_features(core_features); auto p = prog{}; if (!filename) { p = parse_stdin(ctx); From c7b1eedf9adfdcbf41a64eab85aeb754292e7d48 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 08:18:30 -0700 Subject: [PATCH 082/297] Update flop count for complex examples Signed-off-by: Carsten Uphoff --- examples/benchmark/main.cpp | 19 +++++++++---------- examples/tall_and_skinny/main.cpp | 12 +++++++++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 18ab4aba..a9dbef03 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -235,16 +235,15 @@ template void test(queue q, args &a) { }).wait(); }); - auto ops_per_mnk = 0; - switch (element_ty) { - case scalar_type::c32: - case scalar_type::c64: - ops_per_mnk = 8; - break; - default: - ops_per_mnk = 2; - break; - } + const auto ops_per_mnk = [&] { + switch (a.ty) { + case scalar_type::c32: + case scalar_type::c64: + return 8; + default: + return 2; + } + }(); auto gflops = a.internal_repetitions * ops_per_mnk * c.m * c.n * c.k * howmany / min_exec_time_ns; diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 30be5031..96d79943 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -135,10 +135,20 @@ template void test(queue q, args &a) { } double min_exec_time_ns = bench([&]() { tas.submit(q).wait(); }); + const auto ops_per_mnk = [&] { + switch (a.ty) { + case scalar_type::c32: + case scalar_type::c64: + return 8; + default: + return 2; + } + }(); + auto bw_C_factor = a.update ? 2 : 1; auto bw = sizeof(T) * (c.m * c.n * bw_C_factor + c.m * c.k + c.k * c.n) / min_exec_time_ns; - auto gflops = 2 * c.m * c.n * c.k / min_exec_time_ns; + auto gflops = ops_per_mnk * c.m * c.n * c.k / min_exec_time_ns; std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," << a.update << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops << std::endl; From 5b602cee27bc3351a2fb77077d525d79d88b47f6 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 25 Oct 2024 09:05:10 -0700 Subject: [PATCH 083/297] Update tall and skinny recipe Signed-off-by: Carsten Uphoff --- src/recipe/tall_and_skinny.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index dc796f45..77d307a8 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -139,8 +139,8 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( if (!is_dynamic_value(M) && M % M_block_size == 0) { static_gemm(bb); } else { - auto M_val = is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) - : bb.add(make_constant(M, index_ty)); + + auto M_val = bb.add(make_size(C, 0, my_loc())); auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, my_loc())); From ec25f09c6193d90fd7652af0b9b4850ea23efefd Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 4 Nov 2024 18:20:54 +0100 Subject: [PATCH 084/297] Start adding SPIR-V backend Signed-off-by: Carsten Uphoff --- docs/api/core_capi.yaml | 1 + docs/api/core_cxxapi.yaml | 1 + docs/manual/tensor-ir.rst | 56 +- include/tinytc/tinytc.h | 56 +- include/tinytc/tinytc.hpp | 14 + include/tinytc/types.h | 3 + include/tinytc/types.hpp | 1 + src/CMakeLists.txt | 4 + src/compiler.cpp | 84 +- src/compiler_context_cache.cpp | 1 + src/compiler_context_cache.hpp | 1 + src/error.cpp | 2 + src/node/data_type_node.cpp | 4 + src/node/data_type_node.hpp | 2 + src/node/inst_node.cpp | 10 +- src/pass/convert_to_spirv.cpp | 289 ++ src/pass/convert_to_spirv.hpp | 30 + src/passes.def | 2 +- src/spv/enums.hpp | 1425 ++++++++ src/spv/instructions.hpp | 5590 +++++++++++++++++++++++++++++++ src/spv/module.cpp | 16 + src/spv/module.hpp | 59 + src/spv/names.cpp | 2893 ++++++++++++++++ src/spv/names.hpp | 69 + src/spv/pass/dump_asm.cpp | 110 + src/spv/pass/dump_asm.hpp | 57 + src/spv/visit.hpp | 2915 ++++++++++++++++ tools/offline_compiler/main.cpp | 31 +- tools/spirvgen/filter.json | 11 + tools/spirvgen/spirvgen.py | 385 +++ 30 files changed, 14043 insertions(+), 79 deletions(-) create mode 100644 src/pass/convert_to_spirv.cpp create mode 100644 src/pass/convert_to_spirv.hpp create mode 100644 src/spv/enums.hpp create mode 100644 src/spv/instructions.hpp create mode 100644 src/spv/module.cpp create mode 100644 src/spv/module.hpp create mode 100644 src/spv/names.cpp create mode 100644 src/spv/names.hpp create mode 100644 src/spv/pass/dump_asm.cpp create mode 100644 src/spv/pass/dump_asm.hpp create mode 100644 src/spv/visit.hpp create mode 100644 tools/spirvgen/filter.json create mode 100755 tools/spirvgen/spirvgen.py diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index c57673bc..f2790094 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -46,6 +46,7 @@ Core C-API: - tinytc_run_function_pass - tinytc_list_function_passes - tinytc_prog_compile_to_opencl + - tinytc_prog_compile_to_spirv Compiler Context: function: - tinytc_compiler_context_create diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 64461f1b..b329c907 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -30,6 +30,7 @@ Core C++-API: - tinytc::run_function_pass - tinytc::list_function_passes - tinytc::compile_to_opencl + - tinytc::compile_to_spirv Compiler Context: function: - tinytc::make_compiler_context diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 39965758..fc107514 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -285,6 +285,8 @@ The supported matrix shapes may depend on data type, matrix use, and target hard An argument to any instruction that has coopmatrix type **must** be dynamically uniform. +Having i1 as component type of a coopmatrix is forbidden. + Instructions ============ @@ -644,21 +646,22 @@ Arithmetic on cooperative matrices is done component-wise. The following table shows the operations' description and the types that are allowed for the operation. The backslash "\\" is used to exclude types from the list of allowed types. - -==== ============================= ================================================================ -Op Allowed type Description -==== ============================= ================================================================ -.add scalar-type / coopmatrix-type Sum of operands -.sub scalar-type / coopmatrix-type Difference of operands -.mul scalar-type / coopmatrix-type Product of operands -.div scalar-type / coopmatrix-type Quotient of operands -.rem scalar-type \\ complex-type Remainder from the division of operands -.shl integer-type \\ i1 Left shift first operand by second operand -.shr integer-type \\ i1 Arithmetic right shift first operand by second operand -.and integer-type Bitwise and -.or integer-type Bitwise or -.xor integer-type Bitwise xor -==== ============================= ================================================================ +Boolean arithmetic is only allowed for .and, .or, and .xor. + +==== ============================= ========== ====================================================== +Op Allowed type i1 allowed Description +==== ============================= ========== ====================================================== +.add scalar-type / coopmatrix-type No Sum of operands +.sub scalar-type / coopmatrix-type No Difference of operands +.mul scalar-type / coopmatrix-type No Product of operands +.div scalar-type / coopmatrix-type No Quotient of operands +.rem scalar-type \\ complex-type No Remainder from the division of operands +.shl integer-type No Left shift first operand by second operand +.shr integer-type No Arithmetic right shift first operand by second operand +.and integer-type Yes Bitwise and +.or integer-type Yes Bitwise or +.xor integer-type Yes Bitwise xor +==== ============================= ========== ====================================================== Arithmetic (unary) .................. @@ -679,17 +682,18 @@ for ".abs", ".im", and ".re", and the returned value has the same type as the op for ".neg" and ".conj". The following table shows the operations' description and the types that are allowed for the operation. - -===== ============================= ============================= -Op Allowed type Description -===== ============================= ============================= -.abs scalar-type Compute absolute value -.neg scalar-type / coopmatrix-type Negation -.not integer-type Bitwise not -.conj complex-type Complex conjugate -.im complex-type Extract imaginary part -.re complex-type Extract real part -===== ============================= ============================= +Boolean arithmetic is only allowed for .neg. + +===== ============================= ========== ============================= +Op Allowed type i1 allowed Description +===== ============================= ========== ============================= +.abs scalar-type No Compute absolute value +.neg scalar-type / coopmatrix-type No Negation +.not integer-type Yes Bitwise not +.conj complex-type No Complex conjugate +.im complex-type No Extract imaginary part +.re complex-type No Extract real part +===== ============================= ========== ============================= Barrier ....... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 8add0f9f..3714318d 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1459,13 +1459,25 @@ TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, const_tinytc_core_info_t info); +/** + * @brief Compiler tensor language to SPIR-V + * + * @param bin [out] pointer to the binary object created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_binary_t *bin, tinytc_prog_t prg, + const_tinytc_core_info_t info); + /** * @brief Get source text * * @param src [in] source object * @param length [out] pointer to code length - * @param code [out] code contains a pointer to the source text; the pointer is only valid as long - * as the source object is alive + * @param code [out] code contains a pointer to the source text; the pointer is only valid as + * long as the source object is alive * * @return tinytc_status_success on success and error otherwise */ @@ -1511,8 +1523,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_get_core_features( * * @param src [in] source object * @param extensions_size [out] pointer to number of extensions - * @param extensions [out][range(0,extensions_size)] pointer to array of C-strings; array owned by - * source object + * @param extensions [out][range(0,extensions_size)] pointer to array of C-strings; array owned + * by source object * * @return tinytc_status_success on success and error otherwise */ @@ -1528,8 +1540,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t * @param format [in] Bundle format (SPIR-V or Native) * @param data_size [in] Size of data in bytes * @param data [in][range(0, data_size)] Binary data; data is copied - * @param core_features [in][optional] requested core features; must be 0 (default) or a combination - * of tinytc_core_feature_flag_t + * @param core_features [in][optional] requested core features; must be 0 (default) or a + * combination of tinytc_core_feature_flag_t * * @return tinytc_status_success on success and error otherwise */ @@ -1622,8 +1634,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); /** * @brief Returns a small batched GEMM recipe * - * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for @f$\beta\neq - * 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. + * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for + * @f$\beta\neq 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. * * The signature of the generated kernels gemm and gemm_beta0 is (if A and B are not transposed) * @@ -1691,8 +1703,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( /** * @brief Returns a tall and skinny recipe * - * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 called - * "gemm". M (= number of rows of A, C) and strides are dynamic. + * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 + * called "gemm". M (= number of rows of A, C) and strides are dynamic. * * The signature of the generated kernels gemm and gemm_beta0 is * @@ -1717,8 +1729,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( * @param ty [in] Scalar type of alpha, A, B, beta, C * @param N [in] Number of columns of B, C * @param K [in] Number columns of A, number of rows of B - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise @@ -1733,8 +1745,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( * Similar to tinytc_recipe_tall_and_skinny_create but with the additional specialization * constants M, ldA, ldB, and ldC. * The specializtion constants may be either set to a fixed value or to TINYTC_DYNAMIC. - * Note that if a specialization constant is set to a fixed value then the parameter with the same - * name in tinytc_recipe_tall_and_skinny_set_args is ignored. + * Note that if a specialization constant is set to a fixed value then the parameter with the + * same name in tinytc_recipe_tall_and_skinny_set_args is ignored. * * The generated kernels have the following signature: * @@ -1755,8 +1767,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( * @param ldA [in] Leading dimension of A; can be TINYTC_DYNAMIC * @param ldB [in] Leading dimension of B; can be TINYTC_DYNAMIC * @param ldC [in] Leading dimension of C; can be TINYTC_DYNAMIC - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return @@ -1808,8 +1820,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( * @brief Get prog object * * @param recipe [in] recipe object - * @param prg [out] pointer to prog object; reference count is increased so the user needs to call - * tinytc_prog_release to clean up + * @param prg [out] pointer to prog object; reference count is increased so the user needs to + * call tinytc_prog_release to clean up * * @return tinytc_status_success on success and error otherwise */ @@ -1820,8 +1832,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recip * @brief Get source object * * @param recipe [in] recipe object - * @param src [out] pointer to source object; reference count is increased so the user needs to call - * tinytc_source_release to clean up + * @param src [out] pointer to source object; reference count is increased so the user needs to + * call tinytc_source_release to clean up * * @return tinytc_status_success on success and error otherwise */ @@ -1852,8 +1864,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_retain(tinytc_recipe_t obj); * @brief Get recipe object * * @param handler [in] recipe handler object - * @param recipe [out] pointer to recipe object; reference count is increased so the user needs to - * call tinytc_recipe_release to clean up + * @param recipe [out] pointer to recipe object; reference count is increased so the user needs + * to call tinytc_recipe_release to clean up * * @return tinytc_status_success on success and error otherwise */ diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index c7c627da..8d4f331c 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -2252,6 +2252,20 @@ inline auto compile_to_opencl(prog prg, core_info const &info) -> source { return source{src}; } +/** + * @brief Compile program to SPIR-V + * + * @param prg Program + * @param info Core info + * + * @return Binary + */ +inline auto compile_to_spirv(prog prg, core_info const &info) -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv(&bin, prg.get(), info.get())); + return binary{bin}; +} + //////////////////////////// ////////// Recipe ////////// //////////////////////////// diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 410621ca..2305db1f 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -81,6 +81,9 @@ typedef enum { tinytc_status_ir_invalid_matrix_use = 0x122, ///< Invalid matrix use tinytc_status_ir_unsupported_coopmatrix_shape = 0x123, ///< Unsupported coopmatrix shape tinytc_status_ir_incompatible_scalar_types = 0x124, ///< Incompatible scalar types + // SPIR-V errors + tinytc_status_spirv_forbidden_forward_declaration = + 0x1000, ///< Forward declaration of id is forbidden // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 9c02e5fb..766a6fb8 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -89,6 +89,7 @@ enum class status { ir_invalid_matrix_use = tinytc_status_ir_invalid_matrix_use, ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, + spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c61333ad..89f863f3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,7 @@ set(SOURCES pass/constant_folding.cpp pass/constant_propagation.cpp pass/convert_to_opencl.cpp + pass/convert_to_spirv.cpp pass/dead_code_elimination.cpp pass/dump_cfg.cpp pass/dump_def_use.cpp @@ -62,6 +63,9 @@ set(SOURCES region.cpp required_extensions.cpp scalar_type.cpp + spv/module.cpp + spv/names.cpp + spv/pass/dump_asm.cpp source.cpp tiling.cpp value.cpp diff --git a/src/compiler.cpp b/src/compiler.cpp index 1c8ab4ef..5b1a359b 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -7,6 +7,7 @@ #include "pass/check_ir.hpp" #include "pass/constant_propagation.hpp" #include "pass/convert_to_opencl.hpp" +#include "pass/convert_to_spirv.hpp" #include "pass/dead_code_elimination.hpp" #include "pass/dump_cfg.hpp" #include "pass/dump_def_use.hpp" @@ -20,6 +21,7 @@ #include "reference_counted.hpp" #include "required_extensions.hpp" #include "source.hpp" +#include "spv/pass/dump_asm.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -36,6 +38,8 @@ using namespace tinytc; +namespace tinytc { + template struct optflag_setter { PassT &pass; tinytc_compiler_context_t ctx; @@ -45,6 +49,37 @@ template struct optflag_setter { } }; +void apply_default_optimization_pipeline(tinytc_prog_t prg, const_tinytc_core_info_t info) { + auto ctx = prg->context(); + const auto opt_level = ctx->opt_level(); + + // passes + auto cpp = constant_propagation_pass{}; + optflag_setter{cpp, ctx}(tinytc::optflag::unsafe_fp_math); + + run_function_pass(check_ir_pass{}, *prg); + + if (opt_level >= 1) { + // We run constant propagation + dead code elimination early to capture dead allocas + // (later on they are maybe "in use" due to the lifetime_stop instruction) + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } + + run_function_pass(insert_lifetime_stop_pass{}, *prg); + run_function_pass(set_stack_ptr_pass{}, *prg); + run_function_pass(insert_barrier_pass{}, *prg); + run_function_pass(work_group_size_pass{info}, *prg); + + run_function_pass(lower_linalg_pass{info}, *prg); + if (opt_level >= 1) { + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } +} + +} // namespace tinytc + extern "C" { tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, @@ -96,32 +131,7 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ } return exception_to_status_code( [&] { - auto ctx = prg->share_context(); - const auto opt_level = ctx->opt_level(); - - // passes - auto cpp = constant_propagation_pass{}; - optflag_setter{cpp, ctx.get()}(tinytc::optflag::unsafe_fp_math); - - run_function_pass(check_ir_pass{}, *prg); - - if (opt_level >= 1) { - // We run constant propagation + dead code elimination early to capture dead allocas - // (later on they are maybe "in use" due to the lifetime_stop instruction) - run_function_pass(cpp, *prg); - run_function_pass(dead_code_elimination_pass{}, *prg); - } - - run_function_pass(insert_lifetime_stop_pass{}, *prg); - run_function_pass(set_stack_ptr_pass{}, *prg); - run_function_pass(insert_barrier_pass{}, *prg); - run_function_pass(work_group_size_pass{info}, *prg); - - run_function_pass(lower_linalg_pass{info}, *prg); - if (opt_level >= 1) { - run_function_pass(cpp, *prg); - run_function_pass(dead_code_elimination_pass{}, *prg); - } + apply_default_optimization_pipeline(prg, info); // opencl auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); @@ -135,10 +145,30 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ clir::generate_opencl(oss, std::move(ast)); - *src = std::make_unique<::tinytc_source>(std::move(ctx), oss.str(), prg->loc(), + *src = std::make_unique<::tinytc_source>(prg->share_context(), oss.str(), prg->loc(), std::move(ext), info->core_features()) .release(); }, prg->context()); } + +tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_binary_t *bin, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (bin == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { + apply_default_optimization_pipeline(prg, info); + + // opencl + auto mod = convert_to_spirv_pass{info}.run_on_program(*prg); + spv::dump_asm_pass{std::cout}.run_on_module(*mod); + + //*bin = std::make_unique<::tinytc_binary>(prg->share_context(), mod.to_binary(), + // bundle_format::spirv, info->core_features()) + //.release(); + }, + prg->context()); +} } diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp index c94bf416..0f0447ab 100644 --- a/src/compiler_context_cache.cpp +++ b/src/compiler_context_cache.cpp @@ -9,6 +9,7 @@ namespace tinytc { compiler_context_cache::compiler_context_cache(tinytc_compiler_context_t ctx) { + void_ty = std::unique_ptr(new void_data_type(ctx)); for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { scalar_tys[i] = std::unique_ptr(new scalar_data_type(ctx, enum_cast(i))); diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp index a6855a0c..4d61c22c 100644 --- a/src/compiler_context_cache.hpp +++ b/src/compiler_context_cache.hpp @@ -35,6 +35,7 @@ class compiler_context_cache { compiler_context_cache(compiler_context_cache const &) = delete; compiler_context_cache &operator=(compiler_context_cache const &) = delete; + std::unique_ptr void_ty; std::array, TINYTC_NUMBER_OF_SCALAR_TYPES> scalar_tys; std::unordered_multimap memref_tys; std::unordered_multimap coopmatrix_tys; diff --git a/src/error.cpp b/src/error.cpp index 33d6d976..55deb3be 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -194,6 +194,8 @@ char const *tinytc_error_string(tinytc_status_t status) { "target architecture"; case tinytc_status_ir_incompatible_scalar_types: return "Scalar types violate compatibility rules"; + case tinytc_status_spirv_forbidden_forward_declaration: + return "Forward declaration of id is forbidden"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index 07bd33c0..f6e1bec3 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -170,4 +170,8 @@ auto scalar_data_type::get(tinytc_compiler_context_t ctx, scalar_type ty) -> tin return ctx->cache()->scalar_tys[static_cast(ty)].get(); } +auto void_data_type::get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t { + return ctx->cache()->void_ty.get(); +} + } // namespace tinytc diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 288ee2eb..ae0095cb 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -167,9 +167,11 @@ class scalar_data_type : public data_type_node { class void_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } + static auto get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t; protected: inline void_data_type(tinytc_compiler_context_t ctx) : data_type_node(DTK::void_, ctx) {} + friend class compiler_context_cache; }; } // namespace tinytc diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 5c936006..7cf3a4a7 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -259,16 +259,18 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b if (a_ty != b_ty) { throw compilation_error(loc(), status::ir_scalar_mismatch); } + bool inst_supports_i1 = true; bool inst_supports_fp = true; bool inst_supports_complex = true; - bool inst_supports_i1 = true; switch (operation) { case arithmetic::add: case arithmetic::sub: case arithmetic::mul: case arithmetic::div: + inst_supports_i1 = false; break; case arithmetic::rem: + inst_supports_i1 = false; inst_supports_complex = false; break; case arithmetic::and_: @@ -315,12 +317,14 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 auto a_ty = get_scalar_type(loc(), a()); to_ty = a_ty; + bool inst_supports_i1 = true; bool inst_supports_int = true; bool inst_supports_fp = true; bool inst_supports_complex = true; switch (operation_) { case arithmetic_unary::abs: case arithmetic_unary::neg: + inst_supports_i1 = false; break; case arithmetic_unary::not_: inst_supports_fp = false; @@ -329,10 +333,14 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 case arithmetic_unary::conj: case arithmetic_unary::im: case arithmetic_unary::re: + inst_supports_i1 = false; inst_supports_int = false; inst_supports_fp = false; break; } + if (!inst_supports_i1 && a_ty->ty() == scalar_type::i1) { + throw compilation_error(loc(), status::ir_i1_unsupported); + } if (!inst_supports_int && is_integer_type(a_ty->ty())) { throw compilation_error(loc(), status::ir_int_unsupported); } diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp new file mode 100644 index 00000000..2e5e5179 --- /dev/null +++ b/src/pass/convert_to_spirv.cpp @@ -0,0 +1,289 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/convert_to_spirv.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "spv/instructions.hpp" +#include "support/visit.hpp" + +#include + +namespace tinytc { + +class spirv_converter { + public: + inline spirv_converter(spv::mod &mod, tinytc_compiler_context_t ctx) : mod_(&mod), ctx_(ctx) {} + + auto operator()(data_type_node const &ty) -> spv::spv_inst *; + + // Instruction nodes + void operator()(inst_node const &in); + void operator()(arith_inst const &in); + + void run_on_program(program_node const &p); + + private: + auto declare(value_node const &v, spv::spv_inst *in); + auto val(value_node const &v) -> spv::spv_inst *; + void run_on_region(region_node const &fn); + void run_on_function(function_node const &fn); + template auto add_to(Args &&...args) -> T * { + auto ptr = std::make_unique(std::forward(args)...).release(); + mod_->insts(S).push_back(ptr); + return ptr; + } + + template auto add(Args &&...args) -> T * { + return add_to(std::forward(args)...); + } + + spv::mod *mod_; + tinytc_compiler_context_t ctx_; + std::unordered_map spv_tys_; + std::unordered_map vals_; +}; + +auto spirv_converter::declare(value_node const &v, spv::spv_inst *in) { vals_[&v] = in; } +auto spirv_converter::val(value_node const &v) -> spv::spv_inst * { + if (auto it = vals_.find(&v); it != vals_.end()) { + return it->second; + } + throw status::internal_compiler_error; +} + +auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { + auto it = spv_tys_.find(&ty); + if (it == spv_tys_.end()) { + auto spv_ty = visit( + overloaded{ + [&](void_data_type const &) -> spv::spv_inst * { + return add_to(); + }, + [&](scalar_data_type const &ty) -> spv::spv_inst * { + switch (ty.ty()) { + case scalar_type::i1: + return add_to(); + case scalar_type::i8: + add_to(spv::Capability::Int8); + return add_to(8, 1); + case scalar_type::i16: + add_to(spv::Capability::Int16); + return add_to(16, 1); + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return add_to(size(ty.ty()) * 8, 1); + case scalar_type::f32: + case scalar_type::f64: + return add_to(size(ty.ty()) * 8, + std::nullopt); + case scalar_type::c32: + case scalar_type::c64: { + auto float_ty = add_to( + size(ty.ty()) * 8 / 2, std::nullopt); + return add_to(float_ty, 2); + } + } + throw status::internal_compiler_error; + }, + [&](coopmatrix_data_type const &ty) -> spv::spv_inst * { + // @todo + throw status::internal_compiler_error; + return nullptr; + }, + [](auto const &) -> spv::spv_inst * { + // @todo + throw status::internal_compiler_error; + return nullptr; + }}, + ty); + spv_tys_[&ty] = spv_ty; + return spv_ty; + } + return it->second; +} + +void spirv_converter::operator()(inst_node const &) { + // @todo + throw status::internal_compiler_error; +} + +void spirv_converter::operator()(arith_inst const &in) { + auto const make_boolean = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (op) { + case arithmetic::and_: + return add(ty, a, b); + case arithmetic::or_: + return add(ty, a, b); + case arithmetic::xor_: + return add(ty, a, b); + default: + break; + } + throw status::ir_i1_unsupported; + }; + auto const make_int = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (op) { + case arithmetic::add: + return add(ty, a, b); + case arithmetic::sub: + return add(ty, a, b); + case arithmetic::mul: + return add(ty, a, b); + case arithmetic::div: + return add(ty, a, b); + case arithmetic::rem: + return add(ty, a, b); + case arithmetic::shl: + return add(ty, a, b); + case arithmetic::shr: + return add(ty, a, b); + case arithmetic::and_: + return add(ty, a, b); + case arithmetic::or_: + return add(ty, a, b); + case arithmetic::xor_: + return add(ty, a, b); + } + throw status::internal_compiler_error; + }; + auto const make_float_complex = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (op) { + case arithmetic::add: + return add(ty, a, b); + case arithmetic::sub: + return add(ty, a, b); + case arithmetic::mul: + return add(ty, a, b); + case arithmetic::div: + return add(ty, a, b); + case arithmetic::rem: + return add(ty, a, b); + default: + break; + } + throw status::ir_fp_unsupported; + }; + auto const make = [&](scalar_type sty, arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (sty) { + case scalar_type::i1: + return make_boolean(op, ty, a, b); + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + case scalar_type::c32: + case scalar_type::c64: + return make_float_complex(op, ty, a, b); + } + throw status::internal_compiler_error; + }; + + auto ty = visit(*this, *in.result(0).ty()); + auto av = val(in.a()); + auto bv = val(in.b()); + + if (auto st = dyn_cast(in.result(0).ty()); st) { + make(st->ty(), in.operation(), ty, av, bv); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + // auto clinst = std::vector{}; + // auto const len = ct->length(core_cfg_.subgroup_size); + // clinst.reserve(len + 1); + // clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); + // const auto sty = ct->component_ty(); + // for (std::int64_t i = 0; i < len; ++i) { + // auto op = make(a.operation(), av[i], bv[i], sty); + // clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); + //} + // return clinst; + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void spirv_converter::run_on_region(region_node const ®) { + add(); + for (auto const &i : reg) { + visit(*this, i); + } +} + +void spirv_converter::run_on_function(function_node const &fn) { + // Function type + auto void_ty = visit(*this, *void_data_type::get(ctx_)); + auto params = std::vector{}; + params.reserve(fn.num_params()); + for (auto const &p : fn.params()) { + params.push_back(visit(*this, *p.ty())); + } + auto fun_ty = add(void_ty, std::move(params)); + + // Function + auto fun = add(void_ty, spv::FunctionControl::None, fun_ty); + for (auto const &p : fn.params()) { + declare(p, add(visit(*this, *p.ty()))); + } + run_on_region(fn.body()); + add(); + + // Entry point + add_to( + spv::ExecutionModel::Kernel, fun, std::string{fn.name()}, std::vector{}); + + // Execution mode + auto const work_group_size = fn.work_group_size(); + add_to( + fun, spv::ExecutionMode::LocalSize, + spv::ExecutionModeAttr{ + std::array{work_group_size[0], work_group_size[1], 1}}); + add_to( + fun, spv::ExecutionMode::SubgroupSize, spv::ExecutionModeAttr{fn.subgroup_size()}); + + // Function decoration + auto linkage_decoration = + spv::DecorationAttr{std::make_pair(std::string{fn.name()}, spv::LinkageType::Export)}; + add_to(fun, spv::Decoration::LinkageAttributes, + std::move(linkage_decoration)); +} + +void spirv_converter::run_on_program(program_node const &p) { + add_to(spv::Capability::Addresses); + add_to(spv::Capability::Kernel); + add_to(spv::Capability::Linkage); + add_to(spv::Capability::SubgroupDispatch); + add_to(spv::AddressingModel::Physical64, + spv::MemoryModel::OpenCL); + + for (auto const &fn : p) { + run_on_function(fn); + } +} + +convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +auto convert_to_spirv_pass::run_on_program(program_node const &p) -> std::unique_ptr { + auto m = std::make_unique(); + + spirv_converter(*m, p.context()).run_on_program(p); + + return m; +} + +} // namespace tinytc diff --git a/src/pass/convert_to_spirv.hpp b/src/pass/convert_to_spirv.hpp new file mode 100644 index 00000000..a520a9ca --- /dev/null +++ b/src/pass/convert_to_spirv.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERT_TO_SPIRV_20241029_HPP +#define CONVERT_TO_SPIRV_20241029_HPP + +#include "device_info.hpp" +#include "node/program_node.hpp" +#include "spv/module.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +class convert_to_spirv_pass { + public: + convert_to_spirv_pass(::tinytc_core_info const *info); + + auto run_on_program(program_node const &p) -> std::unique_ptr; + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // CONVERT_TO_SPIRV_20241029_HPP diff --git a/src/passes.def b/src/passes.def index eda7d20a..2e3da426 100644 --- a/src/passes.def +++ b/src/passes.def @@ -11,4 +11,4 @@ FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) FUNCTION_PASS_WITH_INFO("lower-linalg", [](tinytc_core_info const* info) { return lower_linalg_pass{info}; }) -FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass(info); }) +FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass{info}; }) diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp new file mode 100644 index 00000000..457a9ee4 --- /dev/null +++ b/src/spv/enums.hpp @@ -0,0 +1,1425 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_ENUMS_2024114_HPP +#define GENERATED_ENUMS_2024114_HPP + +namespace tinytc::spv { + +enum class Op { + Nop = 0, + Undef = 1, + SourceContinued = 2, + Source = 3, + SourceExtension = 4, + Name = 5, + MemberName = 6, + String = 7, + Line = 8, + Extension = 10, + ExtInstImport = 11, + ExtInst = 12, + MemoryModel = 14, + EntryPoint = 15, + ExecutionMode = 16, + Capability = 17, + TypeVoid = 19, + TypeBool = 20, + TypeInt = 21, + TypeFloat = 22, + TypeVector = 23, + TypeMatrix = 24, + TypeImage = 25, + TypeSampler = 26, + TypeSampledImage = 27, + TypeArray = 28, + TypeRuntimeArray = 29, + TypeStruct = 30, + TypeOpaque = 31, + TypePointer = 32, + TypeFunction = 33, + TypeEvent = 34, + TypeDeviceEvent = 35, + TypeReserveId = 36, + TypeQueue = 37, + TypePipe = 38, + TypeForwardPointer = 39, + ConstantTrue = 41, + ConstantFalse = 42, + Constant = 43, + ConstantComposite = 44, + ConstantSampler = 45, + ConstantNull = 46, + Function = 54, + FunctionParameter = 55, + FunctionEnd = 56, + FunctionCall = 57, + Variable = 59, + ImageTexelPointer = 60, + Load = 61, + Store = 62, + CopyMemory = 63, + CopyMemorySized = 64, + AccessChain = 65, + InBoundsAccessChain = 66, + PtrAccessChain = 67, + ArrayLength = 68, + GenericPtrMemSemantics = 69, + InBoundsPtrAccessChain = 70, + Decorate = 71, + MemberDecorate = 72, + DecorationGroup = 73, + GroupDecorate = 74, + GroupMemberDecorate = 75, + VectorExtractDynamic = 77, + VectorInsertDynamic = 78, + VectorShuffle = 79, + CompositeConstruct = 80, + CompositeExtract = 81, + CompositeInsert = 82, + CopyObject = 83, + Transpose = 84, + SampledImage = 86, + ImageSampleImplicitLod = 87, + ImageSampleExplicitLod = 88, + ImageSampleDrefImplicitLod = 89, + ImageSampleDrefExplicitLod = 90, + ImageSampleProjImplicitLod = 91, + ImageSampleProjExplicitLod = 92, + ImageSampleProjDrefImplicitLod = 93, + ImageSampleProjDrefExplicitLod = 94, + ImageFetch = 95, + ImageGather = 96, + ImageDrefGather = 97, + ImageRead = 98, + ImageWrite = 99, + Image = 100, + ImageQueryFormat = 101, + ImageQueryOrder = 102, + ImageQuerySizeLod = 103, + ImageQuerySize = 104, + ImageQueryLod = 105, + ImageQueryLevels = 106, + ImageQuerySamples = 107, + ConvertFToU = 109, + ConvertFToS = 110, + ConvertSToF = 111, + ConvertUToF = 112, + UConvert = 113, + SConvert = 114, + FConvert = 115, + QuantizeToF16 = 116, + ConvertPtrToU = 117, + SatConvertSToU = 118, + SatConvertUToS = 119, + ConvertUToPtr = 120, + PtrCastToGeneric = 121, + GenericCastToPtr = 122, + GenericCastToPtrExplicit = 123, + Bitcast = 124, + SNegate = 126, + FNegate = 127, + IAdd = 128, + FAdd = 129, + ISub = 130, + FSub = 131, + IMul = 132, + FMul = 133, + UDiv = 134, + SDiv = 135, + FDiv = 136, + UMod = 137, + SRem = 138, + SMod = 139, + FRem = 140, + FMod = 141, + VectorTimesScalar = 142, + MatrixTimesScalar = 143, + VectorTimesMatrix = 144, + MatrixTimesVector = 145, + MatrixTimesMatrix = 146, + OuterProduct = 147, + Dot = 148, + IAddCarry = 149, + ISubBorrow = 150, + UMulExtended = 151, + SMulExtended = 152, + Any = 154, + All = 155, + IsNan = 156, + IsInf = 157, + IsFinite = 158, + IsNormal = 159, + SignBitSet = 160, + LessOrGreater = 161, + Ordered = 162, + Unordered = 163, + LogicalEqual = 164, + LogicalNotEqual = 165, + LogicalOr = 166, + LogicalAnd = 167, + LogicalNot = 168, + Select = 169, + IEqual = 170, + INotEqual = 171, + UGreaterThan = 172, + SGreaterThan = 173, + UGreaterThanEqual = 174, + SGreaterThanEqual = 175, + ULessThan = 176, + SLessThan = 177, + ULessThanEqual = 178, + SLessThanEqual = 179, + FOrdEqual = 180, + FUnordEqual = 181, + FOrdNotEqual = 182, + FUnordNotEqual = 183, + FOrdLessThan = 184, + FUnordLessThan = 185, + FOrdGreaterThan = 186, + FUnordGreaterThan = 187, + FOrdLessThanEqual = 188, + FUnordLessThanEqual = 189, + FOrdGreaterThanEqual = 190, + FUnordGreaterThanEqual = 191, + ShiftRightLogical = 194, + ShiftRightArithmetic = 195, + ShiftLeftLogical = 196, + BitwiseOr = 197, + BitwiseXor = 198, + BitwiseAnd = 199, + Not = 200, + BitFieldInsert = 201, + BitFieldSExtract = 202, + BitFieldUExtract = 203, + BitReverse = 204, + BitCount = 205, + DPdx = 207, + DPdy = 208, + Fwidth = 209, + DPdxFine = 210, + DPdyFine = 211, + FwidthFine = 212, + DPdxCoarse = 213, + DPdyCoarse = 214, + FwidthCoarse = 215, + EmitVertex = 218, + EndPrimitive = 219, + EmitStreamVertex = 220, + EndStreamPrimitive = 221, + ControlBarrier = 224, + MemoryBarrier = 225, + AtomicLoad = 227, + AtomicStore = 228, + AtomicExchange = 229, + AtomicCompareExchange = 230, + AtomicCompareExchangeWeak = 231, + AtomicIIncrement = 232, + AtomicIDecrement = 233, + AtomicIAdd = 234, + AtomicISub = 235, + AtomicSMin = 236, + AtomicUMin = 237, + AtomicSMax = 238, + AtomicUMax = 239, + AtomicAnd = 240, + AtomicOr = 241, + AtomicXor = 242, + Phi = 245, + LoopMerge = 246, + SelectionMerge = 247, + Label = 248, + Branch = 249, + BranchConditional = 250, + Switch = 251, + Kill = 252, + Return = 253, + ReturnValue = 254, + Unreachable = 255, + LifetimeStart = 256, + LifetimeStop = 257, + GroupAsyncCopy = 259, + GroupWaitEvents = 260, + GroupAll = 261, + GroupAny = 262, + GroupBroadcast = 263, + GroupIAdd = 264, + GroupFAdd = 265, + GroupFMin = 266, + GroupUMin = 267, + GroupSMin = 268, + GroupFMax = 269, + GroupUMax = 270, + GroupSMax = 271, + ReadPipe = 274, + WritePipe = 275, + ReservedReadPipe = 276, + ReservedWritePipe = 277, + ReserveReadPipePackets = 278, + ReserveWritePipePackets = 279, + CommitReadPipe = 280, + CommitWritePipe = 281, + IsValidReserveId = 282, + GetNumPipePackets = 283, + GetMaxPipePackets = 284, + GroupReserveReadPipePackets = 285, + GroupReserveWritePipePackets = 286, + GroupCommitReadPipe = 287, + GroupCommitWritePipe = 288, + EnqueueMarker = 291, + EnqueueKernel = 292, + GetKernelNDrangeSubGroupCount = 293, + GetKernelNDrangeMaxSubGroupSize = 294, + GetKernelWorkGroupSize = 295, + GetKernelPreferredWorkGroupSizeMultiple = 296, + RetainEvent = 297, + ReleaseEvent = 298, + CreateUserEvent = 299, + IsValidEvent = 300, + SetUserEventStatus = 301, + CaptureEventProfilingInfo = 302, + GetDefaultQueue = 303, + BuildNDRange = 304, + ImageSparseSampleImplicitLod = 305, + ImageSparseSampleExplicitLod = 306, + ImageSparseSampleDrefImplicitLod = 307, + ImageSparseSampleDrefExplicitLod = 308, + ImageSparseSampleProjImplicitLod = 309, + ImageSparseSampleProjExplicitLod = 310, + ImageSparseSampleProjDrefImplicitLod = 311, + ImageSparseSampleProjDrefExplicitLod = 312, + ImageSparseFetch = 313, + ImageSparseGather = 314, + ImageSparseDrefGather = 315, + ImageSparseTexelsResident = 316, + NoLine = 317, + AtomicFlagTestAndSet = 318, + AtomicFlagClear = 319, + ImageSparseRead = 320, + SizeOf = 321, + TypePipeStorage = 322, + ConstantPipeStorage = 323, + CreatePipeFromPipeStorage = 324, + GetKernelLocalSizeForSubgroupCount = 325, + GetKernelMaxNumSubgroups = 326, + TypeNamedBarrier = 327, + NamedBarrierInitialize = 328, + MemoryNamedBarrier = 329, + ModuleProcessed = 330, + ExecutionModeId = 331, + DecorateId = 332, + GroupNonUniformElect = 333, + GroupNonUniformAll = 334, + GroupNonUniformAny = 335, + GroupNonUniformAllEqual = 336, + GroupNonUniformBroadcast = 337, + GroupNonUniformBroadcastFirst = 338, + GroupNonUniformBallot = 339, + GroupNonUniformInverseBallot = 340, + GroupNonUniformBallotBitExtract = 341, + GroupNonUniformBallotBitCount = 342, + GroupNonUniformBallotFindLSB = 343, + GroupNonUniformBallotFindMSB = 344, + GroupNonUniformShuffle = 345, + GroupNonUniformShuffleXor = 346, + GroupNonUniformShuffleUp = 347, + GroupNonUniformShuffleDown = 348, + GroupNonUniformIAdd = 349, + GroupNonUniformFAdd = 350, + GroupNonUniformIMul = 351, + GroupNonUniformFMul = 352, + GroupNonUniformSMin = 353, + GroupNonUniformUMin = 354, + GroupNonUniformFMin = 355, + GroupNonUniformSMax = 356, + GroupNonUniformUMax = 357, + GroupNonUniformFMax = 358, + GroupNonUniformBitwiseAnd = 359, + GroupNonUniformBitwiseOr = 360, + GroupNonUniformBitwiseXor = 361, + GroupNonUniformLogicalAnd = 362, + GroupNonUniformLogicalOr = 363, + GroupNonUniformLogicalXor = 364, + GroupNonUniformQuadBroadcast = 365, + GroupNonUniformQuadSwap = 366, + CopyLogical = 400, + PtrEqual = 401, + PtrNotEqual = 402, + PtrDiff = 403, + TypeCooperativeMatrixKHR = 4456, + CooperativeMatrixLoadKHR = 4457, + CooperativeMatrixStoreKHR = 4458, + CooperativeMatrixMulAddKHR = 4459, + CooperativeMatrixLengthKHR = 4460, +}; +enum class ImageOperands { + None = 0x0000, + Bias = 0x0001, + Lod = 0x0002, + Grad = 0x0004, + ConstOffset = 0x0008, + Offset = 0x0010, + ConstOffsets = 0x0020, + Sample = 0x0040, + MinLod = 0x0080, + MakeTexelAvailable = 0x0100, + MakeTexelVisible = 0x0200, + NonPrivateTexel = 0x0400, + VolatileTexel = 0x0800, + SignExtend = 0x1000, + ZeroExtend = 0x2000, + Nontemporal = 0x4000, + Offsets = 0x10000, +}; +enum class FPFastMathMode { + None = 0x0000, + NotNaN = 0x0001, + NotInf = 0x0002, + NSZ = 0x0004, + AllowRecip = 0x0008, + Fast = 0x0010, + AllowContract = 0x10000, + AllowReassoc = 0x20000, + AllowTransform = 0x40000, +}; +enum class SelectionControl { + None = 0x0000, + Flatten = 0x0001, + DontFlatten = 0x0002, +}; +enum class LoopControl { + None = 0x0000, + Unroll = 0x0001, + DontUnroll = 0x0002, + DependencyInfinite = 0x0004, + DependencyLength = 0x0008, + MinIterations = 0x0010, + MaxIterations = 0x0020, + IterationMultiple = 0x0040, + PeelCount = 0x0080, + PartialCount = 0x0100, + InitiationIntervalINTEL = 0x10000, + MaxConcurrencyINTEL = 0x20000, + DependencyArrayINTEL = 0x40000, + PipelineEnableINTEL = 0x80000, + LoopCoalesceINTEL = 0x100000, + MaxInterleavingINTEL = 0x200000, + SpeculatedIterationsINTEL = 0x400000, + NoFusionINTEL = 0x800000, + LoopCountINTEL = 0x1000000, + MaxReinvocationDelayINTEL = 0x2000000, +}; +enum class FunctionControl { + None = 0x0000, + Inline = 0x0001, + DontInline = 0x0002, + Pure = 0x0004, + Const = 0x0008, + OptNoneEXT = 0x10000, +}; +enum class MemorySemantics { + Relaxed = 0x0000, + Acquire = 0x0002, + Release = 0x0004, + AcquireRelease = 0x0008, + SequentiallyConsistent = 0x0010, + UniformMemory = 0x0040, + SubgroupMemory = 0x0080, + WorkgroupMemory = 0x0100, + CrossWorkgroupMemory = 0x0200, + AtomicCounterMemory = 0x0400, + ImageMemory = 0x0800, + OutputMemory = 0x1000, + MakeAvailable = 0x2000, + MakeVisible = 0x4000, + Volatile = 0x8000, +}; +enum class MemoryAccess { + None = 0x0000, + Volatile = 0x0001, + Aligned = 0x0002, + Nontemporal = 0x0004, + MakePointerAvailable = 0x0008, + MakePointerVisible = 0x0010, + NonPrivatePointer = 0x0020, + AliasScopeINTELMask = 0x10000, + NoAliasINTELMask = 0x20000, +}; +enum class KernelProfilingInfo { + None = 0x0000, + CmdExecTime = 0x0001, +}; +enum class RayFlags { + NoneKHR = 0x0000, + OpaqueKHR = 0x0001, + NoOpaqueKHR = 0x0002, + TerminateOnFirstHitKHR = 0x0004, + SkipClosestHitShaderKHR = 0x0008, + CullBackFacingTrianglesKHR = 0x0010, + CullFrontFacingTrianglesKHR = 0x0020, + CullOpaqueKHR = 0x0040, + CullNoOpaqueKHR = 0x0080, + SkipTrianglesKHR = 0x0100, + SkipAABBsKHR = 0x0200, + ForceOpacityMicromap2StateEXT = 0x0400, +}; +enum class FragmentShadingRate { + Vertical2Pixels = 0x0001, + Vertical4Pixels = 0x0002, + Horizontal2Pixels = 0x0004, + Horizontal4Pixels = 0x0008, +}; +enum class RawAccessChainOperands { + None = 0x0000, + RobustnessPerComponentNV = 0x0001, + RobustnessPerElementNV = 0x0002, +}; +enum class SourceLanguage { + Unknown = 0, + ESSL = 1, + GLSL = 2, + OpenCL_C = 3, + OpenCL_CPP = 4, + HLSL = 5, + CPP_for_OpenCL = 6, + SYCL = 7, + HERO_C = 8, + NZSL = 9, + WGSL = 10, + Slang = 11, + Zig = 12, +}; +enum class ExecutionModel { + Vertex = 0, + TessellationControl = 1, + TessellationEvaluation = 2, + Geometry = 3, + Fragment = 4, + GLCompute = 5, + Kernel = 6, + TaskNV = 5267, + MeshNV = 5268, + RayGenerationKHR = 5313, + IntersectionKHR = 5314, + AnyHitKHR = 5315, + ClosestHitKHR = 5316, + MissKHR = 5317, + CallableKHR = 5318, + TaskEXT = 5364, + MeshEXT = 5365, +}; +enum class AddressingModel { + Logical = 0, + Physical32 = 1, + Physical64 = 2, + PhysicalStorageBuffer64 = 5348, +}; +enum class MemoryModel { + Simple = 0, + GLSL450 = 1, + OpenCL = 2, + Vulkan = 3, +}; +enum class ExecutionMode { + Invocations = 0, + SpacingEqual = 1, + SpacingFractionalEven = 2, + SpacingFractionalOdd = 3, + VertexOrderCw = 4, + VertexOrderCcw = 5, + PixelCenterInteger = 6, + OriginUpperLeft = 7, + OriginLowerLeft = 8, + EarlyFragmentTests = 9, + PointMode = 10, + Xfb = 11, + DepthReplacing = 12, + DepthGreater = 14, + DepthLess = 15, + DepthUnchanged = 16, + LocalSize = 17, + LocalSizeHint = 18, + InputPoints = 19, + InputLines = 20, + InputLinesAdjacency = 21, + Triangles = 22, + InputTrianglesAdjacency = 23, + Quads = 24, + Isolines = 25, + OutputVertices = 26, + OutputPoints = 27, + OutputLineStrip = 28, + OutputTriangleStrip = 29, + VecTypeHint = 30, + ContractionOff = 31, + Initializer = 33, + Finalizer = 34, + SubgroupSize = 35, + SubgroupsPerWorkgroup = 36, + SubgroupsPerWorkgroupId = 37, + LocalSizeId = 38, + LocalSizeHintId = 39, + NonCoherentColorAttachmentReadEXT = 4169, + NonCoherentDepthAttachmentReadEXT = 4170, + NonCoherentStencilAttachmentReadEXT = 4171, + SubgroupUniformControlFlowKHR = 4421, + PostDepthCoverage = 4446, + DenormPreserve = 4459, + DenormFlushToZero = 4460, + SignedZeroInfNanPreserve = 4461, + RoundingModeRTE = 4462, + RoundingModeRTZ = 4463, + EarlyAndLateFragmentTestsAMD = 5017, + StencilRefReplacingEXT = 5027, + CoalescingAMDX = 5069, + IsApiEntryAMDX = 5070, + MaxNodeRecursionAMDX = 5071, + StaticNumWorkgroupsAMDX = 5072, + ShaderIndexAMDX = 5073, + MaxNumWorkgroupsAMDX = 5077, + StencilRefUnchangedFrontAMD = 5079, + StencilRefGreaterFrontAMD = 5080, + StencilRefLessFrontAMD = 5081, + StencilRefUnchangedBackAMD = 5082, + StencilRefGreaterBackAMD = 5083, + StencilRefLessBackAMD = 5084, + QuadDerivativesKHR = 5088, + RequireFullQuadsKHR = 5089, + SharesInputWithAMDX = 5102, + OutputLinesEXT = 5269, + OutputPrimitivesEXT = 5270, + DerivativeGroupQuadsKHR = 5289, + DerivativeGroupLinearKHR = 5290, + OutputTrianglesEXT = 5298, + PixelInterlockOrderedEXT = 5366, + PixelInterlockUnorderedEXT = 5367, + SampleInterlockOrderedEXT = 5368, + SampleInterlockUnorderedEXT = 5369, + ShadingRateInterlockOrderedEXT = 5370, + ShadingRateInterlockUnorderedEXT = 5371, + SharedLocalMemorySizeINTEL = 5618, + RoundingModeRTPINTEL = 5620, + RoundingModeRTNINTEL = 5621, + FloatingPointModeALTINTEL = 5622, + FloatingPointModeIEEEINTEL = 5623, + MaxWorkgroupSizeINTEL = 5893, + MaxWorkDimINTEL = 5894, + NoGlobalOffsetINTEL = 5895, + NumSIMDWorkitemsINTEL = 5896, + SchedulerTargetFmaxMhzINTEL = 5903, + MaximallyReconvergesKHR = 6023, + FPFastMathDefault = 6028, + StreamingInterfaceINTEL = 6154, + RegisterMapInterfaceINTEL = 6160, + NamedBarrierCountINTEL = 6417, + MaximumRegistersINTEL = 6461, + MaximumRegistersIdINTEL = 6462, + NamedMaximumRegistersINTEL = 6463, +}; +enum class StorageClass { + UniformConstant = 0, + Input = 1, + Uniform = 2, + Output = 3, + Workgroup = 4, + CrossWorkgroup = 5, + Private = 6, + Function = 7, + Generic = 8, + PushConstant = 9, + AtomicCounter = 10, + Image = 11, + StorageBuffer = 12, + TileImageEXT = 4172, + NodePayloadAMDX = 5068, + CallableDataKHR = 5328, + IncomingCallableDataKHR = 5329, + RayPayloadKHR = 5338, + HitAttributeKHR = 5339, + IncomingRayPayloadKHR = 5342, + ShaderRecordBufferKHR = 5343, + PhysicalStorageBuffer = 5349, + HitObjectAttributeNV = 5385, + TaskPayloadWorkgroupEXT = 5402, + CodeSectionINTEL = 5605, + DeviceOnlyINTEL = 5936, + HostOnlyINTEL = 5937, +}; +enum class Dim { + Dim1D = 0, + Dim2D = 1, + Dim3D = 2, + Cube = 3, + Rect = 4, + Buffer = 5, + SubpassData = 6, + TileImageDataEXT = 4173, +}; +enum class SamplerAddressingMode { + None = 0, + ClampToEdge = 1, + Clamp = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class SamplerFilterMode { + Nearest = 0, + Linear = 1, +}; +enum class ImageFormat { + Unknown = 0, + Rgba32f = 1, + Rgba16f = 2, + R32f = 3, + Rgba8 = 4, + Rgba8Snorm = 5, + Rg32f = 6, + Rg16f = 7, + R11fG11fB10f = 8, + R16f = 9, + Rgba16 = 10, + Rgb10A2 = 11, + Rg16 = 12, + Rg8 = 13, + R16 = 14, + R8 = 15, + Rgba16Snorm = 16, + Rg16Snorm = 17, + Rg8Snorm = 18, + R16Snorm = 19, + R8Snorm = 20, + Rgba32i = 21, + Rgba16i = 22, + Rgba8i = 23, + R32i = 24, + Rg32i = 25, + Rg16i = 26, + Rg8i = 27, + R16i = 28, + R8i = 29, + Rgba32ui = 30, + Rgba16ui = 31, + Rgba8ui = 32, + R32ui = 33, + Rgb10a2ui = 34, + Rg32ui = 35, + Rg16ui = 36, + Rg8ui = 37, + R16ui = 38, + R8ui = 39, + R64ui = 40, + R64i = 41, +}; +enum class ImageChannelOrder { + R = 0, + A = 1, + RG = 2, + RA = 3, + RGB = 4, + RGBA = 5, + BGRA = 6, + ARGB = 7, + Intensity = 8, + Luminance = 9, + Rx = 10, + RGx = 11, + RGBx = 12, + Depth = 13, + DepthStencil = 14, + sRGB = 15, + sRGBx = 16, + sRGBA = 17, + sBGRA = 18, + ABGR = 19, +}; +enum class ImageChannelDataType { + SnormInt8 = 0, + SnormInt16 = 1, + UnormInt8 = 2, + UnormInt16 = 3, + UnormShort565 = 4, + UnormShort555 = 5, + UnormInt101010 = 6, + SignedInt8 = 7, + SignedInt16 = 8, + SignedInt32 = 9, + UnsignedInt8 = 10, + UnsignedInt16 = 11, + UnsignedInt32 = 12, + HalfFloat = 13, + Float = 14, + UnormInt24 = 15, + UnormInt101010_2 = 16, + UnsignedIntRaw10EXT = 19, + UnsignedIntRaw12EXT = 20, + UnormInt2_101010EXT = 21, +}; +enum class FPRoundingMode { + RTE = 0, + RTZ = 1, + RTP = 2, + RTN = 3, +}; +enum class FPDenormMode { + Preserve = 0, + FlushToZero = 1, +}; +enum class QuantizationModes { + TRN = 0, + TRN_ZERO = 1, + RND = 2, + RND_ZERO = 3, + RND_INF = 4, + RND_MIN_INF = 5, + RND_CONV = 6, + RND_CONV_ODD = 7, +}; +enum class FPOperationMode { + IEEE = 0, + ALT = 1, +}; +enum class OverflowModes { + WRAP = 0, + SAT = 1, + SAT_ZERO = 2, + SAT_SYM = 3, +}; +enum class LinkageType { + Export = 0, + Import = 1, + LinkOnceODR = 2, +}; +enum class AccessQualifier { + ReadOnly = 0, + WriteOnly = 1, + ReadWrite = 2, +}; +enum class HostAccessQualifier { + NoneINTEL = 0, + ReadINTEL = 1, + WriteINTEL = 2, + ReadWriteINTEL = 3, +}; +enum class FunctionParameterAttribute { + Zext = 0, + Sext = 1, + ByVal = 2, + Sret = 3, + NoAlias = 4, + NoCapture = 5, + NoWrite = 6, + NoReadWrite = 7, + RuntimeAlignedINTEL = 5940, +}; +enum class Decoration { + RelaxedPrecision = 0, + SpecId = 1, + Block = 2, + BufferBlock = 3, + RowMajor = 4, + ColMajor = 5, + ArrayStride = 6, + MatrixStride = 7, + GLSLShared = 8, + GLSLPacked = 9, + CPacked = 10, + BuiltIn = 11, + NoPerspective = 13, + Flat = 14, + Patch = 15, + Centroid = 16, + Sample = 17, + Invariant = 18, + Restrict = 19, + Aliased = 20, + Volatile = 21, + Constant = 22, + Coherent = 23, + NonWritable = 24, + NonReadable = 25, + Uniform = 26, + UniformId = 27, + SaturatedConversion = 28, + Stream = 29, + Location = 30, + Component = 31, + Index = 32, + Binding = 33, + DescriptorSet = 34, + Offset = 35, + XfbBuffer = 36, + XfbStride = 37, + FuncParamAttr = 38, + FPRoundingMode = 39, + FPFastMathMode = 40, + LinkageAttributes = 41, + NoContraction = 42, + InputAttachmentIndex = 43, + Alignment = 44, + MaxByteOffset = 45, + AlignmentId = 46, + MaxByteOffsetId = 47, + NoSignedWrap = 4469, + NoUnsignedWrap = 4470, + WeightTextureQCOM = 4487, + BlockMatchTextureQCOM = 4488, + BlockMatchSamplerQCOM = 4499, + ExplicitInterpAMD = 4999, + NodeSharesPayloadLimitsWithAMDX = 5019, + NodeMaxPayloadsAMDX = 5020, + TrackFinishWritingAMDX = 5078, + PayloadNodeNameAMDX = 5091, + PayloadNodeBaseIndexAMDX = 5098, + PayloadNodeSparseArrayAMDX = 5099, + PayloadNodeArraySizeAMDX = 5100, + PayloadDispatchIndirectAMDX = 5105, + OverrideCoverageNV = 5248, + PassthroughNV = 5250, + ViewportRelativeNV = 5252, + SecondaryViewportRelativeNV = 5256, + PerPrimitiveEXT = 5271, + PerViewNV = 5272, + PerTaskNV = 5273, + PerVertexKHR = 5285, + NonUniform = 5300, + RestrictPointer = 5355, + AliasedPointer = 5356, + HitObjectShaderRecordBufferNV = 5386, + BindlessSamplerNV = 5398, + BindlessImageNV = 5399, + BoundSamplerNV = 5400, + BoundImageNV = 5401, + SIMTCallINTEL = 5599, + ReferencedIndirectlyINTEL = 5602, + ClobberINTEL = 5607, + SideEffectsINTEL = 5608, + VectorComputeVariableINTEL = 5624, + FuncParamIOKindINTEL = 5625, + VectorComputeFunctionINTEL = 5626, + StackCallINTEL = 5627, + GlobalVariableOffsetINTEL = 5628, + CounterBuffer = 5634, + UserSemantic = 5635, + UserTypeGOOGLE = 5636, + FunctionRoundingModeINTEL = 5822, + FunctionDenormModeINTEL = 5823, + RegisterINTEL = 5825, + MemoryINTEL = 5826, + NumbanksINTEL = 5827, + BankwidthINTEL = 5828, + MaxPrivateCopiesINTEL = 5829, + SinglepumpINTEL = 5830, + DoublepumpINTEL = 5831, + MaxReplicatesINTEL = 5832, + SimpleDualPortINTEL = 5833, + MergeINTEL = 5834, + BankBitsINTEL = 5835, + ForcePow2DepthINTEL = 5836, + StridesizeINTEL = 5883, + WordsizeINTEL = 5884, + TrueDualPortINTEL = 5885, + BurstCoalesceINTEL = 5899, + CacheSizeINTEL = 5900, + DontStaticallyCoalesceINTEL = 5901, + PrefetchINTEL = 5902, + StallEnableINTEL = 5905, + FuseLoopsInFunctionINTEL = 5907, + MathOpDSPModeINTEL = 5909, + AliasScopeINTEL = 5914, + NoAliasINTEL = 5915, + InitiationIntervalINTEL = 5917, + MaxConcurrencyINTEL = 5918, + PipelineEnableINTEL = 5919, + BufferLocationINTEL = 5921, + IOPipeStorageINTEL = 5944, + FunctionFloatingPointModeINTEL = 6080, + SingleElementVectorINTEL = 6085, + VectorComputeCallableFunctionINTEL = 6087, + MediaBlockIOINTEL = 6140, + StallFreeINTEL = 6151, + FPMaxErrorDecorationINTEL = 6170, + LatencyControlLabelINTEL = 6172, + LatencyControlConstraintINTEL = 6173, + ConduitKernelArgumentINTEL = 6175, + RegisterMapKernelArgumentINTEL = 6176, + MMHostInterfaceAddressWidthINTEL = 6177, + MMHostInterfaceDataWidthINTEL = 6178, + MMHostInterfaceLatencyINTEL = 6179, + MMHostInterfaceReadWriteModeINTEL = 6180, + MMHostInterfaceMaxBurstINTEL = 6181, + MMHostInterfaceWaitRequestINTEL = 6182, + StableKernelArgumentINTEL = 6183, + HostAccessINTEL = 6188, + InitModeINTEL = 6190, + ImplementInRegisterMapINTEL = 6191, + CacheControlLoadINTEL = 6442, + CacheControlStoreINTEL = 6443, +}; +enum class BuiltIn { + Position = 0, + PointSize = 1, + ClipDistance = 3, + CullDistance = 4, + VertexId = 5, + InstanceId = 6, + PrimitiveId = 7, + InvocationId = 8, + Layer = 9, + ViewportIndex = 10, + TessLevelOuter = 11, + TessLevelInner = 12, + TessCoord = 13, + PatchVertices = 14, + FragCoord = 15, + PointCoord = 16, + FrontFacing = 17, + SampleId = 18, + SamplePosition = 19, + SampleMask = 20, + FragDepth = 22, + HelperInvocation = 23, + NumWorkgroups = 24, + WorkgroupSize = 25, + WorkgroupId = 26, + LocalInvocationId = 27, + GlobalInvocationId = 28, + LocalInvocationIndex = 29, + WorkDim = 30, + GlobalSize = 31, + EnqueuedWorkgroupSize = 32, + GlobalOffset = 33, + GlobalLinearId = 34, + SubgroupSize = 36, + SubgroupMaxSize = 37, + NumSubgroups = 38, + NumEnqueuedSubgroups = 39, + SubgroupId = 40, + SubgroupLocalInvocationId = 41, + VertexIndex = 42, + InstanceIndex = 43, + CoreIDARM = 4160, + CoreCountARM = 4161, + CoreMaxIDARM = 4162, + WarpIDARM = 4163, + WarpMaxIDARM = 4164, + SubgroupEqMask = 4416, + SubgroupGeMask = 4417, + SubgroupGtMask = 4418, + SubgroupLeMask = 4419, + SubgroupLtMask = 4420, + BaseVertex = 4424, + BaseInstance = 4425, + DrawIndex = 4426, + PrimitiveShadingRateKHR = 4432, + DeviceIndex = 4438, + ViewIndex = 4440, + ShadingRateKHR = 4444, + BaryCoordNoPerspAMD = 4992, + BaryCoordNoPerspCentroidAMD = 4993, + BaryCoordNoPerspSampleAMD = 4994, + BaryCoordSmoothAMD = 4995, + BaryCoordSmoothCentroidAMD = 4996, + BaryCoordSmoothSampleAMD = 4997, + BaryCoordPullModelAMD = 4998, + FragStencilRefEXT = 5014, + RemainingRecursionLevelsAMDX = 5021, + ShaderIndexAMDX = 5073, + ViewportMaskNV = 5253, + SecondaryPositionNV = 5257, + SecondaryViewportMaskNV = 5258, + PositionPerViewNV = 5261, + ViewportMaskPerViewNV = 5262, + FullyCoveredEXT = 5264, + TaskCountNV = 5274, + PrimitiveCountNV = 5275, + PrimitiveIndicesNV = 5276, + ClipDistancePerViewNV = 5277, + CullDistancePerViewNV = 5278, + LayerPerViewNV = 5279, + MeshViewCountNV = 5280, + MeshViewIndicesNV = 5281, + BaryCoordKHR = 5286, + BaryCoordNoPerspKHR = 5287, + FragSizeEXT = 5292, + FragInvocationCountEXT = 5293, + PrimitivePointIndicesEXT = 5294, + PrimitiveLineIndicesEXT = 5295, + PrimitiveTriangleIndicesEXT = 5296, + CullPrimitiveEXT = 5299, + LaunchIdKHR = 5319, + LaunchSizeKHR = 5320, + WorldRayOriginKHR = 5321, + WorldRayDirectionKHR = 5322, + ObjectRayOriginKHR = 5323, + ObjectRayDirectionKHR = 5324, + RayTminKHR = 5325, + RayTmaxKHR = 5326, + InstanceCustomIndexKHR = 5327, + ObjectToWorldKHR = 5330, + WorldToObjectKHR = 5331, + HitTNV = 5332, + HitKindKHR = 5333, + CurrentRayTimeNV = 5334, + HitTriangleVertexPositionsKHR = 5335, + HitMicroTriangleVertexPositionsNV = 5337, + HitMicroTriangleVertexBarycentricsNV = 5344, + IncomingRayFlagsKHR = 5351, + RayGeometryIndexKHR = 5352, + WarpsPerSMNV = 5374, + SMCountNV = 5375, + WarpIDNV = 5376, + SMIDNV = 5377, + HitKindFrontFacingMicroTriangleNV = 5405, + HitKindBackFacingMicroTriangleNV = 5406, + CullMaskKHR = 6021, +}; +enum class Scope { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + ShaderCallKHR = 6, +}; +enum class GroupOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, + ClusteredReduce = 3, + PartitionedReduceNV = 6, + PartitionedInclusiveScanNV = 7, + PartitionedExclusiveScanNV = 8, +}; +enum class KernelEnqueueFlags { + NoWait = 0, + WaitKernel = 1, + WaitWorkGroup = 2, +}; +enum class Capability { + Matrix = 0, + Shader = 1, + Geometry = 2, + Tessellation = 3, + Addresses = 4, + Linkage = 5, + Kernel = 6, + Vector16 = 7, + Float16Buffer = 8, + Float16 = 9, + Float64 = 10, + Int64 = 11, + Int64Atomics = 12, + ImageBasic = 13, + ImageReadWrite = 14, + ImageMipmap = 15, + Pipes = 17, + Groups = 18, + DeviceEnqueue = 19, + LiteralSampler = 20, + AtomicStorage = 21, + Int16 = 22, + TessellationPointSize = 23, + GeometryPointSize = 24, + ImageGatherExtended = 25, + StorageImageMultisample = 27, + UniformBufferArrayDynamicIndexing = 28, + SampledImageArrayDynamicIndexing = 29, + StorageBufferArrayDynamicIndexing = 30, + StorageImageArrayDynamicIndexing = 31, + ClipDistance = 32, + CullDistance = 33, + ImageCubeArray = 34, + SampleRateShading = 35, + ImageRect = 36, + SampledRect = 37, + GenericPointer = 38, + Int8 = 39, + InputAttachment = 40, + SparseResidency = 41, + MinLod = 42, + Sampled1D = 43, + Image1D = 44, + SampledCubeArray = 45, + SampledBuffer = 46, + ImageBuffer = 47, + ImageMSArray = 48, + StorageImageExtendedFormats = 49, + ImageQuery = 50, + DerivativeControl = 51, + InterpolationFunction = 52, + TransformFeedback = 53, + GeometryStreams = 54, + StorageImageReadWithoutFormat = 55, + StorageImageWriteWithoutFormat = 56, + MultiViewport = 57, + SubgroupDispatch = 58, + NamedBarrier = 59, + PipeStorage = 60, + GroupNonUniform = 61, + GroupNonUniformVote = 62, + GroupNonUniformArithmetic = 63, + GroupNonUniformBallot = 64, + GroupNonUniformShuffle = 65, + GroupNonUniformShuffleRelative = 66, + GroupNonUniformClustered = 67, + GroupNonUniformQuad = 68, + ShaderLayer = 69, + ShaderViewportIndex = 70, + UniformDecoration = 71, + CoreBuiltinsARM = 4165, + TileImageColorReadAccessEXT = 4166, + TileImageDepthReadAccessEXT = 4167, + TileImageStencilReadAccessEXT = 4168, + CooperativeMatrixLayoutsARM = 4201, + FragmentShadingRateKHR = 4422, + SubgroupBallotKHR = 4423, + DrawParameters = 4427, + WorkgroupMemoryExplicitLayoutKHR = 4428, + WorkgroupMemoryExplicitLayout8BitAccessKHR = 4429, + WorkgroupMemoryExplicitLayout16BitAccessKHR = 4430, + SubgroupVoteKHR = 4431, + StorageBuffer16BitAccess = 4433, + UniformAndStorageBuffer16BitAccess = 4434, + StoragePushConstant16 = 4435, + StorageInputOutput16 = 4436, + DeviceGroup = 4437, + MultiView = 4439, + VariablePointersStorageBuffer = 4441, + VariablePointers = 4442, + AtomicStorageOps = 4445, + SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, + DenormPreserve = 4464, + DenormFlushToZero = 4465, + SignedZeroInfNanPreserve = 4466, + RoundingModeRTE = 4467, + RoundingModeRTZ = 4468, + RayQueryProvisionalKHR = 4471, + RayQueryKHR = 4472, + UntypedPointersKHR = 4473, + RayTraversalPrimitiveCullingKHR = 4478, + RayTracingKHR = 4479, + TextureSampleWeightedQCOM = 4484, + TextureBoxFilterQCOM = 4485, + TextureBlockMatchQCOM = 4486, + TextureBlockMatch2QCOM = 4498, + Float16ImageAMD = 5008, + ImageGatherBiasLodAMD = 5009, + FragmentMaskAMD = 5010, + StencilExportEXT = 5013, + ImageReadWriteLodAMD = 5015, + Int64ImageEXT = 5016, + ShaderClockKHR = 5055, + ShaderEnqueueAMDX = 5067, + QuadControlKHR = 5087, + SampleMaskOverrideCoverageNV = 5249, + GeometryShaderPassthroughNV = 5251, + ShaderViewportIndexLayerEXT = 5254, + ShaderViewportMaskNV = 5255, + ShaderStereoViewNV = 5259, + PerViewAttributesNV = 5260, + FragmentFullyCoveredEXT = 5265, + MeshShadingNV = 5266, + ImageFootprintNV = 5282, + MeshShadingEXT = 5283, + FragmentBarycentricKHR = 5284, + ComputeDerivativeGroupQuadsKHR = 5288, + FragmentDensityEXT = 5291, + GroupNonUniformPartitionedNV = 5297, + ShaderNonUniform = 5301, + RuntimeDescriptorArray = 5302, + InputAttachmentArrayDynamicIndexing = 5303, + UniformTexelBufferArrayDynamicIndexing = 5304, + StorageTexelBufferArrayDynamicIndexing = 5305, + UniformBufferArrayNonUniformIndexing = 5306, + SampledImageArrayNonUniformIndexing = 5307, + StorageBufferArrayNonUniformIndexing = 5308, + StorageImageArrayNonUniformIndexing = 5309, + InputAttachmentArrayNonUniformIndexing = 5310, + UniformTexelBufferArrayNonUniformIndexing = 5311, + StorageTexelBufferArrayNonUniformIndexing = 5312, + RayTracingPositionFetchKHR = 5336, + RayTracingNV = 5340, + RayTracingMotionBlurNV = 5341, + VulkanMemoryModel = 5345, + VulkanMemoryModelDeviceScope = 5346, + PhysicalStorageBufferAddresses = 5347, + ComputeDerivativeGroupLinearKHR = 5350, + RayTracingProvisionalKHR = 5353, + CooperativeMatrixNV = 5357, + FragmentShaderSampleInterlockEXT = 5363, + FragmentShaderShadingRateInterlockEXT = 5372, + ShaderSMBuiltinsNV = 5373, + FragmentShaderPixelInterlockEXT = 5378, + DemoteToHelperInvocation = 5379, + DisplacementMicromapNV = 5380, + RayTracingOpacityMicromapEXT = 5381, + ShaderInvocationReorderNV = 5383, + BindlessTextureNV = 5390, + RayQueryPositionFetchKHR = 5391, + AtomicFloat16VectorNV = 5404, + RayTracingDisplacementMicromapNV = 5409, + RawAccessChainsNV = 5414, + CooperativeMatrixReductionsNV = 5430, + CooperativeMatrixConversionsNV = 5431, + CooperativeMatrixPerElementOperationsNV = 5432, + CooperativeMatrixTensorAddressingNV = 5433, + CooperativeMatrixBlockLoadsNV = 5434, + TensorAddressingNV = 5439, + SubgroupShuffleINTEL = 5568, + SubgroupBufferBlockIOINTEL = 5569, + SubgroupImageBlockIOINTEL = 5570, + SubgroupImageMediaBlockIOINTEL = 5579, + RoundToInfinityINTEL = 5582, + FloatingPointModeINTEL = 5583, + IntegerFunctions2INTEL = 5584, + FunctionPointersINTEL = 5603, + IndirectReferencesINTEL = 5604, + AsmINTEL = 5606, + AtomicFloat32MinMaxEXT = 5612, + AtomicFloat64MinMaxEXT = 5613, + AtomicFloat16MinMaxEXT = 5616, + VectorComputeINTEL = 5617, + VectorAnyINTEL = 5619, + ExpectAssumeKHR = 5629, + SubgroupAvcMotionEstimationINTEL = 5696, + SubgroupAvcMotionEstimationIntraINTEL = 5697, + SubgroupAvcMotionEstimationChromaINTEL = 5698, + VariableLengthArrayINTEL = 5817, + FunctionFloatControlINTEL = 5821, + FPGAMemoryAttributesINTEL = 5824, + FPFastMathModeINTEL = 5837, + ArbitraryPrecisionIntegersINTEL = 5844, + ArbitraryPrecisionFloatingPointINTEL = 5845, + UnstructuredLoopControlsINTEL = 5886, + FPGALoopControlsINTEL = 5888, + KernelAttributesINTEL = 5892, + FPGAKernelAttributesINTEL = 5897, + FPGAMemoryAccessesINTEL = 5898, + FPGAClusterAttributesINTEL = 5904, + LoopFuseINTEL = 5906, + FPGADSPControlINTEL = 5908, + MemoryAccessAliasingINTEL = 5910, + FPGAInvocationPipeliningAttributesINTEL = 5916, + FPGABufferLocationINTEL = 5920, + ArbitraryPrecisionFixedPointINTEL = 5922, + USMStorageClassesINTEL = 5935, + RuntimeAlignedAttributeINTEL = 5939, + IOPipesINTEL = 5943, + BlockingPipesINTEL = 5945, + FPGARegINTEL = 5948, + DotProductInputAll = 6016, + DotProductInput4x8Bit = 6017, + DotProductInput4x8BitPacked = 6018, + DotProduct = 6019, + RayCullMaskKHR = 6020, + CooperativeMatrixKHR = 6022, + ReplicatedCompositesEXT = 6024, + BitInstructions = 6025, + GroupNonUniformRotateKHR = 6026, + FloatControls2 = 6029, + AtomicFloat32AddEXT = 6033, + AtomicFloat64AddEXT = 6034, + LongCompositesINTEL = 6089, + OptNoneEXT = 6094, + AtomicFloat16AddEXT = 6095, + DebugInfoModuleINTEL = 6114, + BFloat16ConversionINTEL = 6115, + SplitBarrierINTEL = 6141, + ArithmeticFenceEXT = 6144, + FPGAClusterAttributesV2INTEL = 6150, + FPGAKernelAttributesv2INTEL = 6161, + FPMaxErrorINTEL = 6169, + FPGALatencyControlINTEL = 6171, + FPGAArgumentInterfacesINTEL = 6174, + GlobalVariableHostAccessINTEL = 6187, + GlobalVariableFPGADecorationsINTEL = 6189, + SubgroupBufferPrefetchINTEL = 6220, + GroupUniformArithmeticKHR = 6400, + MaskedGatherScatterINTEL = 6427, + CacheControlsINTEL = 6441, + RegisterLimitsINTEL = 6460, +}; +enum class RayQueryIntersection { + RayQueryCandidateIntersectionKHR = 0, + RayQueryCommittedIntersectionKHR = 1, +}; +enum class RayQueryCommittedIntersectionType { + RayQueryCommittedIntersectionNoneKHR = 0, + RayQueryCommittedIntersectionTriangleKHR = 1, + RayQueryCommittedIntersectionGeneratedKHR = 2, +}; +enum class RayQueryCandidateIntersectionType { + RayQueryCandidateIntersectionTriangleKHR = 0, + RayQueryCandidateIntersectionAABBKHR = 1, +}; +enum class PackedVectorFormat { + PackedVectorFormat4x8Bit = 0, +}; +enum class CooperativeMatrixOperands { + NoneKHR = 0x0000, + MatrixASignedComponentsKHR = 0x0001, + MatrixBSignedComponentsKHR = 0x0002, + MatrixCSignedComponentsKHR = 0x0004, + MatrixResultSignedComponentsKHR = 0x0008, + SaturatingAccumulationKHR = 0x0010, +}; +enum class CooperativeMatrixLayout { + RowMajorKHR = 0, + ColumnMajorKHR = 1, + RowBlockedInterleavedARM = 4202, + ColumnBlockedInterleavedARM = 4203, +}; +enum class CooperativeMatrixUse { + MatrixAKHR = 0, + MatrixBKHR = 1, + MatrixAccumulatorKHR = 2, +}; +enum class CooperativeMatrixReduce { + Row = 0x0001, + Column = 0x0002, + CooperativeMatrixReduce2x2 = 0x0004, +}; +enum class TensorClampMode { + Undefined = 0, + Constant = 1, + ClampToEdge = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class TensorAddressingOperands { + None = 0x0000, + TensorView = 0x0001, + DecodeFunc = 0x0002, +}; +enum class InitializationModeQualifier { + InitOnDeviceReprogramINTEL = 0, + InitOnDeviceResetINTEL = 1, +}; +enum class LoadCacheControl { + UncachedINTEL = 0, + CachedINTEL = 1, + StreamingINTEL = 2, + InvalidateAfterReadINTEL = 3, + ConstCachedINTEL = 4, +}; +enum class StoreCacheControl { + UncachedINTEL = 0, + WriteThroughINTEL = 1, + WriteBackINTEL = 2, + StreamingINTEL = 3, +}; +enum class NamedMaximumNumberOfRegisters { + AutoINTEL = 0, +}; +enum class FPEncoding {}; + +} // namespace tinytc::spv + +#endif // GENERATED_ENUMS_2024114_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp new file mode 100644 index 00000000..8890fd69 --- /dev/null +++ b/src/spv/instructions.hpp @@ -0,0 +1,5590 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_INSTRUCTIONS_2024114_HPP +#define GENERATED_INSTRUCTIONS_2024114_HPP + +#include "enums.hpp" +#include "error.hpp" +#include "support/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) + : opcode_{opcode}, has_result_id_{has_result_id} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + + inline auto opcode() const -> Op { return opcode_; } + inline auto has_result_id() const -> bool { return has_result_id_; } + + private: + Op opcode_; + bool has_result_id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber = + std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst *; +using IdRef = spv_inst *; +using IdScope = spv_inst *; +using IdMemorySemantics = spv_inst *; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef = + std::pair, spv_inst *>; +using PairIdRefLiteralInteger = std::pair; + +class OpNop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Nop; } + OpNop() : spv_inst{Op::Nop, false} {} + + private: +}; +class OpUndef : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Undef; } + OpUndef(IdResultType type) : spv_inst{Op::Undef, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpSourceContinued : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceContinued; } + OpSourceContinued(LiteralString op0) + : spv_inst{Op::SourceContinued, false}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpSource : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Source; } + OpSource(SourceLanguage op0, LiteralInteger op1, std::optional op2, + std::optional op3) + : spv_inst{Op::Source, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> SourceLanguage const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + SourceLanguage op0_; + LiteralInteger op1_; + std::optional op2_; + std::optional op3_; +}; +class OpSourceExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceExtension; } + OpSourceExtension(LiteralString op0) + : spv_inst{Op::SourceExtension, false}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Name; } + OpName(IdRef op0, LiteralString op1) + : spv_inst{Op::Name, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralString const & { return op1_; } + + private: + IdRef op0_; + LiteralString op1_; +}; +class OpMemberName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberName; } + OpMemberName(IdRef op0, LiteralInteger op1, LiteralString op2) + : spv_inst{Op::MemberName, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> LiteralString const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralString op2_; +}; +class OpString : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::String; } + OpString(LiteralString op0) : spv_inst{Op::String, true}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Line; } + OpLine(IdRef op0, LiteralInteger op1, LiteralInteger op2) + : spv_inst{Op::Line, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Extension; } + OpExtension(LiteralString op0) : spv_inst{Op::Extension, false}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInstImport : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInstImport; } + OpExtInstImport(LiteralString op0) : spv_inst{Op::ExtInstImport, true}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInst; } + OpExtInst(IdResultType type, IdRef op0, LiteralExtInstInteger op1, std::vector op2) + : spv_inst{Op::ExtInst, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralExtInstInteger const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralExtInstInteger op1_; + std::vector op2_; +}; +class OpMemoryModel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryModel; } + OpMemoryModel(AddressingModel op0, MemoryModel op1) + : spv_inst{Op::MemoryModel, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> AddressingModel const & { return op0_; } + inline auto op1() const -> MemoryModel const & { return op1_; } + + private: + AddressingModel op0_; + MemoryModel op1_; +}; +class OpEntryPoint : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EntryPoint; } + OpEntryPoint(ExecutionModel op0, IdRef op1, LiteralString op2, std::vector op3) + : spv_inst{Op::EntryPoint, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> ExecutionModel const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + ExecutionModel op0_; + IdRef op1_; + LiteralString op2_; + std::vector op3_; +}; +class OpExecutionMode : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionMode; } + OpExecutionMode(IdRef op0, ExecutionMode op1, ExecutionModeAttr op2) + : spv_inst{Op::ExecutionMode, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + inline auto op2() const -> ExecutionModeAttr const & { return op2_; } + + private: + IdRef op0_; + ExecutionMode op1_; + ExecutionModeAttr op2_; +}; +class OpCapability : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Capability; } + OpCapability(Capability op0) : spv_inst{Op::Capability, false}, op0_(std::move(op0)) {} + inline auto op0() const -> Capability const & { return op0_; } + + private: + Capability op0_; +}; +class OpTypeVoid : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVoid; } + OpTypeVoid() : spv_inst{Op::TypeVoid, true} {} + + private: +}; +class OpTypeBool : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeBool; } + OpTypeBool() : spv_inst{Op::TypeBool, true} {} + + private: +}; +class OpTypeInt : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeInt; } + OpTypeInt(LiteralInteger op0, LiteralInteger op1) + : spv_inst{Op::TypeInt, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + LiteralInteger op0_; + LiteralInteger op1_; +}; +class OpTypeFloat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFloat; } + OpTypeFloat(LiteralInteger op0, std::optional op1) + : spv_inst{Op::TypeFloat, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + LiteralInteger op0_; + std::optional op1_; +}; +class OpTypeVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVector; } + OpTypeVector(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeVector, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTypeMatrix(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeMatrix, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeImage; } + OpTypeImage(IdRef op0, Dim op1, LiteralInteger op2, LiteralInteger op3, LiteralInteger op4, + LiteralInteger op5, ImageFormat op6, std::optional op7) + : spv_inst{Op::TypeImage, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), + op6_(std::move(op6)), op7_(std::move(op7)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> Dim const & { return op1_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + inline auto op3() const -> LiteralInteger const & { return op3_; } + inline auto op4() const -> LiteralInteger const & { return op4_; } + inline auto op5() const -> LiteralInteger const & { return op5_; } + inline auto op6() const -> ImageFormat const & { return op6_; } + inline auto op7() const -> std::optional const & { return op7_; } + + private: + IdRef op0_; + Dim op1_; + LiteralInteger op2_; + LiteralInteger op3_; + LiteralInteger op4_; + LiteralInteger op5_; + ImageFormat op6_; + std::optional op7_; +}; +class OpTypeSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampler; } + OpTypeSampler() : spv_inst{Op::TypeSampler, true} {} + + private: +}; +class OpTypeSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampledImage; } + OpTypeSampledImage(IdRef op0) : spv_inst{Op::TypeSampledImage, true}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeArray; } + OpTypeArray(IdRef op0, IdRef op1) + : spv_inst{Op::TypeArray, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpTypeRuntimeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeRuntimeArray; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpTypeRuntimeArray(IdRef op0) : spv_inst{Op::TypeRuntimeArray, true}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeStruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeStruct; } + OpTypeStruct(std::vector op0) : spv_inst{Op::TypeStruct, true}, op0_(std::move(op0)) {} + inline auto op0() const -> std::vector const & { return op0_; } + + private: + std::vector op0_; +}; +class OpTypeOpaque : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeOpaque; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeOpaque(LiteralString op0) : spv_inst{Op::TypeOpaque, true}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpTypePointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePointer; } + OpTypePointer(StorageClass op0, IdRef op1) + : spv_inst{Op::TypePointer, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + StorageClass op0_; + IdRef op1_; +}; +class OpTypeFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFunction; } + OpTypeFunction(IdRef op0, std::vector op1) + : spv_inst{Op::TypeFunction, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpTypeEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeEvent; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeEvent() : spv_inst{Op::TypeEvent, true} {} + + private: +}; +class OpTypeDeviceEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeDeviceEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeDeviceEvent() : spv_inst{Op::TypeDeviceEvent, true} {} + + private: +}; +class OpTypeReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypeReserveId() : spv_inst{Op::TypeReserveId, true} {} + + private: +}; +class OpTypeQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeQueue() : spv_inst{Op::TypeQueue, true} {} + + private: +}; +class OpTypePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypePipe(AccessQualifier op0) : spv_inst{Op::TypePipe, true}, op0_(std::move(op0)) {} + inline auto op0() const -> AccessQualifier const & { return op0_; } + + private: + AccessQualifier op0_; +}; +class OpTypeForwardPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeForwardPointer; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpTypeForwardPointer(IdRef op0, StorageClass op1) + : spv_inst{Op::TypeForwardPointer, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdRef op0_; + StorageClass op1_; +}; +class OpConstantTrue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantTrue; } + OpConstantTrue(IdResultType type) : spv_inst{Op::ConstantTrue, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstantFalse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantFalse; } + OpConstantFalse(IdResultType type) + : spv_inst{Op::ConstantFalse, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstant : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Constant; } + OpConstant(IdResultType type, LiteralContextDependentNumber op0) + : spv_inst{Op::Constant, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> LiteralContextDependentNumber const & { return op0_; } + + private: + IdResultType type_; + LiteralContextDependentNumber op0_; +}; +class OpConstantComposite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantComposite; } + OpConstantComposite(IdResultType type, std::vector op0) + : spv_inst{Op::ConstantComposite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpConstantSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantSampler; } + constexpr static std::array required_capabilities = {Capability::LiteralSampler}; + OpConstantSampler(IdResultType type, SamplerAddressingMode op0, LiteralInteger op1, + SamplerFilterMode op2) + : spv_inst{Op::ConstantSampler, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> SamplerAddressingMode const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> SamplerFilterMode const & { return op2_; } + + private: + IdResultType type_; + SamplerAddressingMode op0_; + LiteralInteger op1_; + SamplerFilterMode op2_; +}; +class OpConstantNull : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantNull; } + OpConstantNull(IdResultType type) : spv_inst{Op::ConstantNull, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Function; } + OpFunction(IdResultType type, FunctionControl op0, IdRef op1) + : spv_inst{Op::Function, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> FunctionControl const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + FunctionControl op0_; + IdRef op1_; +}; +class OpFunctionParameter : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionParameter; } + OpFunctionParameter(IdResultType type) + : spv_inst{Op::FunctionParameter, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunctionEnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionEnd; } + OpFunctionEnd() : spv_inst{Op::FunctionEnd, false} {} + + private: +}; +class OpFunctionCall : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionCall; } + OpFunctionCall(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::FunctionCall, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpVariable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Variable; } + OpVariable(IdResultType type, StorageClass op0, std::optional op1) + : spv_inst{Op::Variable, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + IdResultType type_; + StorageClass op0_; + std::optional op1_; +}; +class OpImageTexelPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageTexelPointer; } + OpImageTexelPointer(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::ImageTexelPointer, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Load; } + OpLoad(IdResultType type, IdRef op0, std::optional op1) + : spv_inst{Op::Load, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::optional op1_; +}; +class OpStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Store; } + OpStore(IdRef op0, IdRef op1, std::optional op2) + : spv_inst{Op::Store, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpCopyMemory : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemory; } + OpCopyMemory(IdRef op0, IdRef op1, std::optional op2, + std::optional op3) + : spv_inst{Op::CopyMemory, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; +}; +class OpCopyMemorySized : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemorySized; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::UntypedPointersKHR}; + OpCopyMemorySized(IdRef op0, IdRef op1, IdRef op2, std::optional op3, + std::optional op4) + : spv_inst{Op::CopyMemorySized, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; +}; +class OpAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AccessChain; } + OpAccessChain(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::AccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpInBoundsAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::InBoundsAccessChain; } + OpInBoundsAccessChain(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::InBoundsAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrAccessChain; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer, Capability::PhysicalStorageBufferAddresses}; + OpPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::PtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpArrayLength : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ArrayLength; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpArrayLength(IdResultType type, IdRef op0, LiteralInteger op1) + : spv_inst{Op::ArrayLength, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralInteger op1_; +}; +class OpGenericPtrMemSemantics : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericPtrMemSemantics; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericPtrMemSemantics(IdResultType type, IdRef op0) + : spv_inst{Op::GenericPtrMemSemantics, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpInBoundsPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::InBoundsPtrAccessChain; + } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpInBoundsPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::InBoundsPtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Decorate; } + OpDecorate(IdRef op0, Decoration op1, DecorationAttr op2) + : spv_inst{Op::Decorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> Decoration const & { return op1_; } + inline auto op2() const -> DecorationAttr const & { return op2_; } + + private: + IdRef op0_; + Decoration op1_; + DecorationAttr op2_; +}; +class OpMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberDecorate; } + OpMemberDecorate(IdRef op0, LiteralInteger op1, Decoration op2) + : spv_inst{Op::MemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> Decoration const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + Decoration op2_; +}; +class OpDecorationGroup : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorationGroup; } + OpDecorationGroup() : spv_inst{Op::DecorationGroup, true} {} + + private: +}; +class OpGroupDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupDecorate; } + OpGroupDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpGroupMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupMemberDecorate; } + OpGroupMemberDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupMemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpVectorExtractDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorExtractDynamic; } + OpVectorExtractDynamic(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorExtractDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorInsertDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorInsertDynamic; } + OpVectorInsertDynamic(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::VectorInsertDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpVectorShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorShuffle; } + OpVectorShuffle(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::VectorShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCompositeConstruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeConstruct; } + OpCompositeConstruct(IdResultType type, std::vector op0) + : spv_inst{Op::CompositeConstruct, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpCompositeExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeExtract; } + OpCompositeExtract(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::CompositeExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpCompositeInsert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeInsert; } + OpCompositeInsert(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::CompositeInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCopyObject : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyObject; } + OpCopyObject(IdResultType type, IdRef op0) + : spv_inst{Op::CopyObject, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTranspose : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Transpose; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTranspose(IdResultType type, IdRef op0) + : spv_inst{Op::Transpose, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SampledImage; } + OpSampledImage(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SampledImage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2) + : spv_inst{Op::ImageSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleExplicitLod; + } + OpImageSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2) + : spv_inst{Op::ImageSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleProjDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageFetch; } + OpImageFetch(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + : spv_inst{Op::ImageFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageDrefGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageRead; } + OpImageRead(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + : spv_inst{Op::ImageRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageWrite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageWrite; } + OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3) + : spv_inst{Op::ImageWrite, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Image; } + OpImage(IdResultType type, IdRef op0) + : spv_inst{Op::Image, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryFormat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryFormat; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryFormat(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryFormat, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryOrder : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryOrder; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryOrder(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryOrder, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySizeLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySizeLod; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySizeLod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ImageQuerySizeLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageQuerySize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySize; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySize(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySize, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLod; } + constexpr static std::array required_capabilities = {Capability::ImageQuery}; + OpImageQueryLod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ImageQueryLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageQueryLevels : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLevels; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQueryLevels(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryLevels, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySamples : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySamples; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySamples(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySamples, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToU; } + OpConvertFToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToS : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToS; } + OpConvertFToS(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertSToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertSToF; } + OpConvertSToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertSToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToF; } + OpConvertUToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpUConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UConvert; } + OpUConvert(IdResultType type, IdRef op0) + : spv_inst{Op::UConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SConvert; } + OpSConvert(IdResultType type, IdRef op0) + : spv_inst{Op::SConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FConvert; } + OpFConvert(IdResultType type, IdRef op0) + : spv_inst{Op::FConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpQuantizeToF16 : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::QuantizeToF16; } + OpQuantizeToF16(IdResultType type, IdRef op0) + : spv_inst{Op::QuantizeToF16, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertPtrToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertPtrToU; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertPtrToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertPtrToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertSToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertSToU; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertSToU(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertSToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertUToS : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertUToS; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertUToS(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertUToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToPtr; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertUToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrCastToGeneric : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrCastToGeneric; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpPtrCastToGeneric(IdResultType type, IdRef op0) + : spv_inst{Op::PtrCastToGeneric, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GenericCastToPtr; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::GenericCastToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtrExplicit : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericCastToPtrExplicit; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtrExplicit(IdResultType type, IdRef op0, StorageClass op1) + : spv_inst{Op::GenericCastToPtrExplicit, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + StorageClass op1_; +}; +class OpBitcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Bitcast; } + OpBitcast(IdResultType type, IdRef op0) + : spv_inst{Op::Bitcast, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SNegate; } + OpSNegate(IdResultType type, IdRef op0) + : spv_inst{Op::SNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FNegate; } + OpFNegate(IdResultType type, IdRef op0) + : spv_inst{Op::FNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAdd; } + OpIAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FAdd; } + OpFAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISub; } + OpISub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFSub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FSub; } + OpFSub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FSub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IMul; } + OpIMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMul; } + OpFMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UDiv; } + OpUDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SDiv; } + OpSDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FDiv; } + OpFDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMod; } + OpUMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SRem; } + OpSRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SMod; } + OpSMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FRem; } + OpFRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMod; } + OpFMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesScalar; } + OpVectorTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesScalar; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpVectorTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesVector; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesVector(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesVector, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpOuterProduct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::OuterProduct; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpOuterProduct(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::OuterProduct, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpDot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Dot; } + OpDot(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Dot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIAddCarry : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAddCarry; } + OpIAddCarry(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAddCarry, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISubBorrow : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISubBorrow; } + OpISubBorrow(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISubBorrow, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMulExtended; } + OpUMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SMulExtended; } + OpSMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Any; } + OpAny(IdResultType type, IdRef op0) + : spv_inst{Op::Any, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::All; } + OpAll(IdResultType type, IdRef op0) + : spv_inst{Op::All, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNan; } + OpIsNan(IdResultType type, IdRef op0) + : spv_inst{Op::IsNan, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsInf : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsInf; } + OpIsInf(IdResultType type, IdRef op0) + : spv_inst{Op::IsInf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsFinite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsFinite; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsFinite(IdResultType type, IdRef op0) + : spv_inst{Op::IsFinite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNormal : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNormal; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsNormal(IdResultType type, IdRef op0) + : spv_inst{Op::IsNormal, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSignBitSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SignBitSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSignBitSet(IdResultType type, IdRef op0) + : spv_inst{Op::SignBitSet, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpLessOrGreater : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LessOrGreater; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLessOrGreater(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LessOrGreater, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpOrdered : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Ordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpOrdered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Ordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUnordered : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpUnordered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Unordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalEqual; } + OpLogicalEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNotEqual; } + OpLogicalNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalOr; } + OpLogicalOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalAnd; } + OpLogicalAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNot; } + OpLogicalNot(IdResultType type, IdRef op0) + : spv_inst{Op::LogicalNot, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSelect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Select; } + OpSelect(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::Select, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpIEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IEqual; } + OpIEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpINotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::INotEqual; } + OpINotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::INotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UGreaterThan; } + OpUGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThan; } + OpSGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UGreaterThanEqual; } + OpUGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThanEqual; } + OpSGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThan; } + OpULessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThan; } + OpSLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThanEqual; } + OpULessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThanEqual; } + OpSLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdEqual; } + OpFOrdEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordEqual; } + OpFUnordEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdNotEqual; } + OpFOrdNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordNotEqual; } + OpFUnordNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThan; } + OpFOrdLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThan; } + OpFUnordLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThan; } + OpFOrdGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordGreaterThan; } + OpFUnordGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThanEqual; } + OpFOrdLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThanEqual; } + OpFUnordLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThanEqual; } + OpFOrdGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::FUnordGreaterThanEqual; + } + OpFUnordGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightLogical; } + OpShiftRightLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightArithmetic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightArithmetic; } + OpShiftRightArithmetic(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightArithmetic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftLeftLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftLeftLogical; } + OpShiftLeftLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftLeftLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseOr; } + OpBitwiseOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseXor; } + OpBitwiseXor(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseAnd; } + OpBitwiseAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Not; } + OpNot(IdResultType type, IdRef op0) + : spv_inst{Op::Not, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitFieldInsert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldInsert; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldInsert(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::BitFieldInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpBitFieldSExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldSExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldSExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldSExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitFieldUExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldUExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldUExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldUExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitReverse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitReverse; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitReverse(IdResultType type, IdRef op0) + : spv_inst{Op::BitReverse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitCount; } + OpBitCount(IdResultType type, IdRef op0) + : spv_inst{Op::BitCount, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdx : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdx; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdx(IdResultType type, IdRef op0) + : spv_inst{Op::DPdx, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdy; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdy(IdResultType type, IdRef op0) + : spv_inst{Op::DPdy, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidth : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Fwidth; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpFwidth(IdResultType type, IdRef op0) + : spv_inst{Op::Fwidth, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthFine(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpEmitVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitVertex; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEmitVertex() : spv_inst{Op::EmitVertex, false} {} + + private: +}; +class OpEndPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndPrimitive; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEndPrimitive() : spv_inst{Op::EndPrimitive, false} {} + + private: +}; +class OpEmitStreamVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitStreamVertex; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEmitStreamVertex(IdRef op0) : spv_inst{Op::EmitStreamVertex, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpEndStreamPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndStreamPrimitive; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEndStreamPrimitive(IdRef op0) + : spv_inst{Op::EndStreamPrimitive, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpControlBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ControlBarrier; } + OpControlBarrier(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpMemoryBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryBarrier; } + OpMemoryBarrier(IdScope op0, IdMemorySemantics op1) + : spv_inst{Op::MemoryBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdMemorySemantics const & { return op1_; } + + private: + IdScope op0_; + IdMemorySemantics op1_; +}; +class OpAtomicLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicLoad; } + OpAtomicLoad(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicLoad, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicStore; } + OpAtomicStore(IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicStore, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicExchange; } + OpAtomicExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicCompareExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchange; + } + OpAtomicCompareExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicCompareExchangeWeak : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchangeWeak; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicCompareExchangeWeak(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchangeWeak, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicIIncrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIIncrement; } + OpAtomicIIncrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIIncrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIDecrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIDecrement; } + OpAtomicIDecrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIDecrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIAdd; } + OpAtomicIAdd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicISub; } + OpAtomicISub(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMin; } + OpAtomicSMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMin; } + OpAtomicUMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMax; } + OpAtomicSMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMax; } + OpAtomicUMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicAnd; } + OpAtomicAnd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicOr; } + OpAtomicOr(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicXor; } + OpAtomicXor(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpPhi : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Phi; } + OpPhi(IdResultType type, std::vector op0) + : spv_inst{Op::Phi, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpLoopMerge : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LoopMerge; } + OpLoopMerge(IdRef op0, IdRef op1, LoopControl op2) + : spv_inst{Op::LoopMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> LoopControl const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + LoopControl op2_; +}; +class OpSelectionMerge : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SelectionMerge; } + OpSelectionMerge(IdRef op0, SelectionControl op1) + : spv_inst{Op::SelectionMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> SelectionControl const & { return op1_; } + + private: + IdRef op0_; + SelectionControl op1_; +}; +class OpLabel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Label; } + OpLabel() : spv_inst{Op::Label, true} {} + + private: +}; +class OpBranch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Branch; } + OpBranch(IdRef op0) : spv_inst{Op::Branch, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpBranchConditional : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BranchConditional; } + OpBranchConditional(IdRef op0, IdRef op1, IdRef op2, std::vector op3) + : spv_inst{Op::BranchConditional, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::vector op3_; +}; +class OpSwitch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Switch; } + OpSwitch(IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::Switch, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpKill : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Kill; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpKill() : spv_inst{Op::Kill, false} {} + + private: +}; +class OpReturn : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Return; } + OpReturn() : spv_inst{Op::Return, false} {} + + private: +}; +class OpReturnValue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReturnValue; } + OpReturnValue(IdRef op0) : spv_inst{Op::ReturnValue, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpUnreachable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unreachable; } + OpUnreachable() : spv_inst{Op::Unreachable, false} {} + + private: +}; +class OpLifetimeStart : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStart; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStart(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStart, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpLifetimeStop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStop; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStop(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStop, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpGroupAsyncCopy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAsyncCopy; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupAsyncCopy(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::GroupAsyncCopy, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpGroupWaitEvents : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupWaitEvents; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupWaitEvents(IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupWaitEvents, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAll; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAny; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupBroadcast; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupIAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::WritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::WritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReservedReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReservedWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedWritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveReadPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveReadPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveWritePipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveWritePipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitReadPipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitWritePipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpIsValidReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpIsValidReserveId(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidReserveId, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetNumPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetNumPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetNumPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetNumPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetMaxPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetMaxPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetMaxPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetMaxPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveReadPipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveReadPipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveWritePipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveWritePipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitReadPipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitWritePipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpEnqueueMarker : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueMarker; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueMarker(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::EnqueueMarker, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpEnqueueKernel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueKernel; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueKernel(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5, IdRef op6, IdRef op7, IdRef op8, IdRef op9, std::vector op10) + : spv_inst{Op::EnqueueKernel, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), + op9_(std::move(op9)), op10_(std::move(op10)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() const -> IdRef const & { return op7_; } + inline auto op8() const -> IdRef const & { return op8_; } + inline auto op9() const -> IdRef const & { return op9_; } + inline auto op10() const -> std::vector const & { return op10_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + IdRef op6_; + IdRef op7_; + IdRef op8_; + IdRef op9_; + std::vector op10_; +}; +class OpGetKernelNDrangeSubGroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeSubGroupCount; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelNDrangeSubGroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeSubGroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelNDrangeMaxSubGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeMaxSubGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelNDrangeMaxSubGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeMaxSubGroupSize, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelWorkGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelWorkGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelWorkGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelWorkGroupSize, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpGetKernelPreferredWorkGroupSizeMultiple : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelPreferredWorkGroupSizeMultiple; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelPreferredWorkGroupSizeMultiple(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3) + : spv_inst{Op::GetKernelPreferredWorkGroupSizeMultiple, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpRetainEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::RetainEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpRetainEvent(IdRef op0) : spv_inst{Op::RetainEvent, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpReleaseEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReleaseEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpReleaseEvent(IdRef op0) : spv_inst{Op::ReleaseEvent, false}, op0_(std::move(op0)) {} + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpCreateUserEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CreateUserEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpCreateUserEvent(IdResultType type) + : spv_inst{Op::CreateUserEvent, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpIsValidEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpIsValidEvent(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidEvent, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSetUserEventStatus : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SetUserEventStatus; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpSetUserEventStatus(IdRef op0, IdRef op1) + : spv_inst{Op::SetUserEventStatus, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpCaptureEventProfilingInfo : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CaptureEventProfilingInfo; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpCaptureEventProfilingInfo(IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::CaptureEventProfilingInfo, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetDefaultQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetDefaultQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetDefaultQueue(IdResultType type) + : spv_inst{Op::GetDefaultQueue, true}, type_(std::move(type)) {} + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpBuildNDRange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BuildNDRange; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpBuildNDRange(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BuildNDRange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpImageSparseSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2) + : spv_inst{Op::ImageSparseSampleImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSparseSampleExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSparseSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSparseSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2) + : spv_inst{Op::ImageSparseSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSparseSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSparseSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSparseSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleProjDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseFetch; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseFetch(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + : spv_inst{Op::ImageSparseFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseGather; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSparseGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseDrefGather; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::ImageSparseDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseTexelsResident : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseTexelsResident; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseTexelsResident(IdResultType type, IdRef op0) + : spv_inst{Op::ImageSparseTexelsResident, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpNoLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::NoLine; } + OpNoLine() : spv_inst{Op::NoLine, false} {} + + private: +}; +class OpAtomicFlagTestAndSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagTestAndSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagTestAndSet(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagTestAndSet, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicFlagClear : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagClear; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagClear(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagClear, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpImageSparseRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseRead; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseRead(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + : spv_inst{Op::ImageSparseRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpSizeOf : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SizeOf; } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpSizeOf(IdResultType type, IdRef op0) + : spv_inst{Op::SizeOf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTypePipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpTypePipeStorage() : spv_inst{Op::TypePipeStorage, true} {} + + private: +}; +class OpConstantPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantPipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpConstantPipeStorage(IdResultType type, LiteralInteger op0, LiteralInteger op1, + LiteralInteger op2) + : spv_inst{Op::ConstantPipeStorage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdResultType type_; + LiteralInteger op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpCreatePipeFromPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CreatePipeFromPipeStorage; + } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpCreatePipeFromPipeStorage(IdResultType type, IdRef op0) + : spv_inst{Op::CreatePipeFromPipeStorage, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetKernelLocalSizeForSubgroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelLocalSizeForSubgroupCount; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelLocalSizeForSubgroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3, IdRef op4) + : spv_inst{Op::GetKernelLocalSizeForSubgroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelMaxNumSubgroups : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelMaxNumSubgroups; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelMaxNumSubgroups(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelMaxNumSubgroups, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpTypeNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpTypeNamedBarrier() : spv_inst{Op::TypeNamedBarrier, true} {} + + private: +}; +class OpNamedBarrierInitialize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::NamedBarrierInitialize; + } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpNamedBarrierInitialize(IdResultType type, IdRef op0) + : spv_inst{Op::NamedBarrierInitialize, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpMemoryNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpMemoryNamedBarrier(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::MemoryNamedBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpModuleProcessed : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ModuleProcessed; } + OpModuleProcessed(LiteralString op0) + : spv_inst{Op::ModuleProcessed, false}, op0_(std::move(op0)) {} + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExecutionModeId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionModeId; } + OpExecutionModeId(IdRef op0, ExecutionMode op1) + : spv_inst{Op::ExecutionModeId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + + private: + IdRef op0_; + ExecutionMode op1_; +}; +class OpDecorateId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorateId; } + OpDecorateId(IdRef op0, Decoration op1) + : spv_inst{Op::DecorateId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> Decoration const & { return op1_; } + + private: + IdRef op0_; + Decoration op1_; +}; +class OpGroupNonUniformElect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformElect; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniform}; + OpGroupNonUniformElect(IdResultType type, IdScope op0) + : spv_inst{Op::GroupNonUniformElect, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + + private: + IdResultType type_; + IdScope op0_; +}; +class OpGroupNonUniformAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAll; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAny; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAllEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformAllEqual; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAllEqual(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAllEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBroadcastFirst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcastFirst; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcastFirst(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBroadcastFirst, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformInverseBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformInverseBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformInverseBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformInverseBallot, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotBitExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitExtract; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitExtract(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitExtract, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitCount; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitCount(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotFindLSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindLSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindLSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindLSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotFindMSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindMSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindMSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindMSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffle; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffle(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffleXor(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleUp : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleUp; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleUp(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleUp, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleDown : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleDown; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleDown(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleDown, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformIMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformIMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformFMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformBitwiseAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformBitwiseOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformBitwiseXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformLogicalAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformLogicalOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3) + : spv_inst{Op::GroupNonUniformLogicalXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformQuadBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + OpGroupNonUniformQuadBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformQuadSwap : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadSwap; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + OpGroupNonUniformQuadSwap(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadSwap, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpCopyLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyLogical; } + OpCopyLogical(IdResultType type, IdRef op0) + : spv_inst{Op::CopyLogical, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrEqual; } + OpPtrEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrNotEqual; } + OpPtrNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrDiff : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrDiff; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer}; + OpPtrDiff(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrDiff, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpTypeCooperativeMatrixKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::TypeCooperativeMatrixKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpTypeCooperativeMatrixKHR(IdRef op0, IdScope op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::TypeCooperativeMatrixKHR, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdRef op0_; + IdScope op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpCooperativeMatrixLoadKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLoadKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLoadKHR(IdResultType type, IdRef op0, IdRef op1, std::optional op2, + std::optional op3) + : spv_inst{Op::CooperativeMatrixLoadKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; +}; +class OpCooperativeMatrixStoreKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixStoreKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixStoreKHR(IdRef op0, IdRef op1, IdRef op2, std::optional op3, + std::optional op4) + : spv_inst{Op::CooperativeMatrixStoreKHR, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; +}; +class OpCooperativeMatrixMulAddKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixMulAddKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixMulAddKHR(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3) + : spv_inst{Op::CooperativeMatrixMulAddKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpCooperativeMatrixLengthKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLengthKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLengthKHR(IdResultType type, IdRef op0) + : spv_inst{Op::CooperativeMatrixLengthKHR, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; + +} // namespace tinytc::spv + +#endif // GENERATED_INSTRUCTIONS_2024114_HPP diff --git a/src/spv/module.cpp b/src/spv/module.cpp new file mode 100644 index 00000000..63abaf83 --- /dev/null +++ b/src/spv/module.cpp @@ -0,0 +1,16 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/module.hpp" +#include "spv/instructions.hpp" + +namespace tinytc { +void ilist_callbacks::node_added(spv::spv_inst *) {} +void ilist_callbacks::node_removed(spv::spv_inst *node) { delete node; } +} // namespace tinytc + +namespace tinytc::spv { +mod::mod() {} +mod::~mod() {} +} // namespace tinytc::spv + diff --git a/src/spv/module.hpp b/src/spv/module.hpp new file mode 100644 index 00000000..c0d14d60 --- /dev/null +++ b/src/spv/module.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MODULE_20241029_HPP +#define MODULE_20241029_HPP + +#include "reference_counted.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" + +#include +#include + +namespace tinytc { + +namespace spv { +class spv_inst; +} + +template <> struct ilist_callbacks { + void node_added(spv::spv_inst *node); + void node_removed(spv::spv_inst *node); +}; + +namespace spv { + +enum class section { + capability = 0, + memory_model = 1, + entry_point = 2, + execution_mode = 3, + decoration = 4, + type = 5, + function = 6 +}; +inline constexpr std::size_t num_module_sections = 7; + +class mod final { + public: + using iterator = ilist::iterator; + using const_iterator = ilist::const_iterator; + + mod(); + ~mod(); + + inline auto insts(section s) -> ilist & { return insts_[static_cast(s)]; } + inline auto insts(section s) const -> ilist const & { + return insts_[static_cast(s)]; + } + inline auto empty(section s) const -> bool { return insts_[static_cast(s)].empty(); } + + private: + std::array, num_module_sections> insts_; +}; + +} // namespace spv +} // namespace tinytc + +#endif // MODULE_20241029_HPP diff --git a/src/spv/names.cpp b/src/spv/names.cpp new file mode 100644 index 00000000..259e9096 --- /dev/null +++ b/src/spv/names.cpp @@ -0,0 +1,2893 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_NAMES_2024114_HPP +#define GENERATED_NAMES_2024114_HPP + +#include "names.hpp" +#include "enums.hpp" + +namespace tinytc::spv { + +auto to_string(Op op) -> char const * { + switch (op) { + case Op::Nop: + return "Nop"; + case Op::Undef: + return "Undef"; + case Op::SourceContinued: + return "SourceContinued"; + case Op::Source: + return "Source"; + case Op::SourceExtension: + return "SourceExtension"; + case Op::Name: + return "Name"; + case Op::MemberName: + return "MemberName"; + case Op::String: + return "String"; + case Op::Line: + return "Line"; + case Op::Extension: + return "Extension"; + case Op::ExtInstImport: + return "ExtInstImport"; + case Op::ExtInst: + return "ExtInst"; + case Op::MemoryModel: + return "MemoryModel"; + case Op::EntryPoint: + return "EntryPoint"; + case Op::ExecutionMode: + return "ExecutionMode"; + case Op::Capability: + return "Capability"; + case Op::TypeVoid: + return "TypeVoid"; + case Op::TypeBool: + return "TypeBool"; + case Op::TypeInt: + return "TypeInt"; + case Op::TypeFloat: + return "TypeFloat"; + case Op::TypeVector: + return "TypeVector"; + case Op::TypeMatrix: + return "TypeMatrix"; + case Op::TypeImage: + return "TypeImage"; + case Op::TypeSampler: + return "TypeSampler"; + case Op::TypeSampledImage: + return "TypeSampledImage"; + case Op::TypeArray: + return "TypeArray"; + case Op::TypeRuntimeArray: + return "TypeRuntimeArray"; + case Op::TypeStruct: + return "TypeStruct"; + case Op::TypeOpaque: + return "TypeOpaque"; + case Op::TypePointer: + return "TypePointer"; + case Op::TypeFunction: + return "TypeFunction"; + case Op::TypeEvent: + return "TypeEvent"; + case Op::TypeDeviceEvent: + return "TypeDeviceEvent"; + case Op::TypeReserveId: + return "TypeReserveId"; + case Op::TypeQueue: + return "TypeQueue"; + case Op::TypePipe: + return "TypePipe"; + case Op::TypeForwardPointer: + return "TypeForwardPointer"; + case Op::ConstantTrue: + return "ConstantTrue"; + case Op::ConstantFalse: + return "ConstantFalse"; + case Op::Constant: + return "Constant"; + case Op::ConstantComposite: + return "ConstantComposite"; + case Op::ConstantSampler: + return "ConstantSampler"; + case Op::ConstantNull: + return "ConstantNull"; + case Op::Function: + return "Function"; + case Op::FunctionParameter: + return "FunctionParameter"; + case Op::FunctionEnd: + return "FunctionEnd"; + case Op::FunctionCall: + return "FunctionCall"; + case Op::Variable: + return "Variable"; + case Op::ImageTexelPointer: + return "ImageTexelPointer"; + case Op::Load: + return "Load"; + case Op::Store: + return "Store"; + case Op::CopyMemory: + return "CopyMemory"; + case Op::CopyMemorySized: + return "CopyMemorySized"; + case Op::AccessChain: + return "AccessChain"; + case Op::InBoundsAccessChain: + return "InBoundsAccessChain"; + case Op::PtrAccessChain: + return "PtrAccessChain"; + case Op::ArrayLength: + return "ArrayLength"; + case Op::GenericPtrMemSemantics: + return "GenericPtrMemSemantics"; + case Op::InBoundsPtrAccessChain: + return "InBoundsPtrAccessChain"; + case Op::Decorate: + return "Decorate"; + case Op::MemberDecorate: + return "MemberDecorate"; + case Op::DecorationGroup: + return "DecorationGroup"; + case Op::GroupDecorate: + return "GroupDecorate"; + case Op::GroupMemberDecorate: + return "GroupMemberDecorate"; + case Op::VectorExtractDynamic: + return "VectorExtractDynamic"; + case Op::VectorInsertDynamic: + return "VectorInsertDynamic"; + case Op::VectorShuffle: + return "VectorShuffle"; + case Op::CompositeConstruct: + return "CompositeConstruct"; + case Op::CompositeExtract: + return "CompositeExtract"; + case Op::CompositeInsert: + return "CompositeInsert"; + case Op::CopyObject: + return "CopyObject"; + case Op::Transpose: + return "Transpose"; + case Op::SampledImage: + return "SampledImage"; + case Op::ImageSampleImplicitLod: + return "ImageSampleImplicitLod"; + case Op::ImageSampleExplicitLod: + return "ImageSampleExplicitLod"; + case Op::ImageSampleDrefImplicitLod: + return "ImageSampleDrefImplicitLod"; + case Op::ImageSampleDrefExplicitLod: + return "ImageSampleDrefExplicitLod"; + case Op::ImageSampleProjImplicitLod: + return "ImageSampleProjImplicitLod"; + case Op::ImageSampleProjExplicitLod: + return "ImageSampleProjExplicitLod"; + case Op::ImageSampleProjDrefImplicitLod: + return "ImageSampleProjDrefImplicitLod"; + case Op::ImageSampleProjDrefExplicitLod: + return "ImageSampleProjDrefExplicitLod"; + case Op::ImageFetch: + return "ImageFetch"; + case Op::ImageGather: + return "ImageGather"; + case Op::ImageDrefGather: + return "ImageDrefGather"; + case Op::ImageRead: + return "ImageRead"; + case Op::ImageWrite: + return "ImageWrite"; + case Op::Image: + return "Image"; + case Op::ImageQueryFormat: + return "ImageQueryFormat"; + case Op::ImageQueryOrder: + return "ImageQueryOrder"; + case Op::ImageQuerySizeLod: + return "ImageQuerySizeLod"; + case Op::ImageQuerySize: + return "ImageQuerySize"; + case Op::ImageQueryLod: + return "ImageQueryLod"; + case Op::ImageQueryLevels: + return "ImageQueryLevels"; + case Op::ImageQuerySamples: + return "ImageQuerySamples"; + case Op::ConvertFToU: + return "ConvertFToU"; + case Op::ConvertFToS: + return "ConvertFToS"; + case Op::ConvertSToF: + return "ConvertSToF"; + case Op::ConvertUToF: + return "ConvertUToF"; + case Op::UConvert: + return "UConvert"; + case Op::SConvert: + return "SConvert"; + case Op::FConvert: + return "FConvert"; + case Op::QuantizeToF16: + return "QuantizeToF16"; + case Op::ConvertPtrToU: + return "ConvertPtrToU"; + case Op::SatConvertSToU: + return "SatConvertSToU"; + case Op::SatConvertUToS: + return "SatConvertUToS"; + case Op::ConvertUToPtr: + return "ConvertUToPtr"; + case Op::PtrCastToGeneric: + return "PtrCastToGeneric"; + case Op::GenericCastToPtr: + return "GenericCastToPtr"; + case Op::GenericCastToPtrExplicit: + return "GenericCastToPtrExplicit"; + case Op::Bitcast: + return "Bitcast"; + case Op::SNegate: + return "SNegate"; + case Op::FNegate: + return "FNegate"; + case Op::IAdd: + return "IAdd"; + case Op::FAdd: + return "FAdd"; + case Op::ISub: + return "ISub"; + case Op::FSub: + return "FSub"; + case Op::IMul: + return "IMul"; + case Op::FMul: + return "FMul"; + case Op::UDiv: + return "UDiv"; + case Op::SDiv: + return "SDiv"; + case Op::FDiv: + return "FDiv"; + case Op::UMod: + return "UMod"; + case Op::SRem: + return "SRem"; + case Op::SMod: + return "SMod"; + case Op::FRem: + return "FRem"; + case Op::FMod: + return "FMod"; + case Op::VectorTimesScalar: + return "VectorTimesScalar"; + case Op::MatrixTimesScalar: + return "MatrixTimesScalar"; + case Op::VectorTimesMatrix: + return "VectorTimesMatrix"; + case Op::MatrixTimesVector: + return "MatrixTimesVector"; + case Op::MatrixTimesMatrix: + return "MatrixTimesMatrix"; + case Op::OuterProduct: + return "OuterProduct"; + case Op::Dot: + return "Dot"; + case Op::IAddCarry: + return "IAddCarry"; + case Op::ISubBorrow: + return "ISubBorrow"; + case Op::UMulExtended: + return "UMulExtended"; + case Op::SMulExtended: + return "SMulExtended"; + case Op::Any: + return "Any"; + case Op::All: + return "All"; + case Op::IsNan: + return "IsNan"; + case Op::IsInf: + return "IsInf"; + case Op::IsFinite: + return "IsFinite"; + case Op::IsNormal: + return "IsNormal"; + case Op::SignBitSet: + return "SignBitSet"; + case Op::LessOrGreater: + return "LessOrGreater"; + case Op::Ordered: + return "Ordered"; + case Op::Unordered: + return "Unordered"; + case Op::LogicalEqual: + return "LogicalEqual"; + case Op::LogicalNotEqual: + return "LogicalNotEqual"; + case Op::LogicalOr: + return "LogicalOr"; + case Op::LogicalAnd: + return "LogicalAnd"; + case Op::LogicalNot: + return "LogicalNot"; + case Op::Select: + return "Select"; + case Op::IEqual: + return "IEqual"; + case Op::INotEqual: + return "INotEqual"; + case Op::UGreaterThan: + return "UGreaterThan"; + case Op::SGreaterThan: + return "SGreaterThan"; + case Op::UGreaterThanEqual: + return "UGreaterThanEqual"; + case Op::SGreaterThanEqual: + return "SGreaterThanEqual"; + case Op::ULessThan: + return "ULessThan"; + case Op::SLessThan: + return "SLessThan"; + case Op::ULessThanEqual: + return "ULessThanEqual"; + case Op::SLessThanEqual: + return "SLessThanEqual"; + case Op::FOrdEqual: + return "FOrdEqual"; + case Op::FUnordEqual: + return "FUnordEqual"; + case Op::FOrdNotEqual: + return "FOrdNotEqual"; + case Op::FUnordNotEqual: + return "FUnordNotEqual"; + case Op::FOrdLessThan: + return "FOrdLessThan"; + case Op::FUnordLessThan: + return "FUnordLessThan"; + case Op::FOrdGreaterThan: + return "FOrdGreaterThan"; + case Op::FUnordGreaterThan: + return "FUnordGreaterThan"; + case Op::FOrdLessThanEqual: + return "FOrdLessThanEqual"; + case Op::FUnordLessThanEqual: + return "FUnordLessThanEqual"; + case Op::FOrdGreaterThanEqual: + return "FOrdGreaterThanEqual"; + case Op::FUnordGreaterThanEqual: + return "FUnordGreaterThanEqual"; + case Op::ShiftRightLogical: + return "ShiftRightLogical"; + case Op::ShiftRightArithmetic: + return "ShiftRightArithmetic"; + case Op::ShiftLeftLogical: + return "ShiftLeftLogical"; + case Op::BitwiseOr: + return "BitwiseOr"; + case Op::BitwiseXor: + return "BitwiseXor"; + case Op::BitwiseAnd: + return "BitwiseAnd"; + case Op::Not: + return "Not"; + case Op::BitFieldInsert: + return "BitFieldInsert"; + case Op::BitFieldSExtract: + return "BitFieldSExtract"; + case Op::BitFieldUExtract: + return "BitFieldUExtract"; + case Op::BitReverse: + return "BitReverse"; + case Op::BitCount: + return "BitCount"; + case Op::DPdx: + return "DPdx"; + case Op::DPdy: + return "DPdy"; + case Op::Fwidth: + return "Fwidth"; + case Op::DPdxFine: + return "DPdxFine"; + case Op::DPdyFine: + return "DPdyFine"; + case Op::FwidthFine: + return "FwidthFine"; + case Op::DPdxCoarse: + return "DPdxCoarse"; + case Op::DPdyCoarse: + return "DPdyCoarse"; + case Op::FwidthCoarse: + return "FwidthCoarse"; + case Op::EmitVertex: + return "EmitVertex"; + case Op::EndPrimitive: + return "EndPrimitive"; + case Op::EmitStreamVertex: + return "EmitStreamVertex"; + case Op::EndStreamPrimitive: + return "EndStreamPrimitive"; + case Op::ControlBarrier: + return "ControlBarrier"; + case Op::MemoryBarrier: + return "MemoryBarrier"; + case Op::AtomicLoad: + return "AtomicLoad"; + case Op::AtomicStore: + return "AtomicStore"; + case Op::AtomicExchange: + return "AtomicExchange"; + case Op::AtomicCompareExchange: + return "AtomicCompareExchange"; + case Op::AtomicCompareExchangeWeak: + return "AtomicCompareExchangeWeak"; + case Op::AtomicIIncrement: + return "AtomicIIncrement"; + case Op::AtomicIDecrement: + return "AtomicIDecrement"; + case Op::AtomicIAdd: + return "AtomicIAdd"; + case Op::AtomicISub: + return "AtomicISub"; + case Op::AtomicSMin: + return "AtomicSMin"; + case Op::AtomicUMin: + return "AtomicUMin"; + case Op::AtomicSMax: + return "AtomicSMax"; + case Op::AtomicUMax: + return "AtomicUMax"; + case Op::AtomicAnd: + return "AtomicAnd"; + case Op::AtomicOr: + return "AtomicOr"; + case Op::AtomicXor: + return "AtomicXor"; + case Op::Phi: + return "Phi"; + case Op::LoopMerge: + return "LoopMerge"; + case Op::SelectionMerge: + return "SelectionMerge"; + case Op::Label: + return "Label"; + case Op::Branch: + return "Branch"; + case Op::BranchConditional: + return "BranchConditional"; + case Op::Switch: + return "Switch"; + case Op::Kill: + return "Kill"; + case Op::Return: + return "Return"; + case Op::ReturnValue: + return "ReturnValue"; + case Op::Unreachable: + return "Unreachable"; + case Op::LifetimeStart: + return "LifetimeStart"; + case Op::LifetimeStop: + return "LifetimeStop"; + case Op::GroupAsyncCopy: + return "GroupAsyncCopy"; + case Op::GroupWaitEvents: + return "GroupWaitEvents"; + case Op::GroupAll: + return "GroupAll"; + case Op::GroupAny: + return "GroupAny"; + case Op::GroupBroadcast: + return "GroupBroadcast"; + case Op::GroupIAdd: + return "GroupIAdd"; + case Op::GroupFAdd: + return "GroupFAdd"; + case Op::GroupFMin: + return "GroupFMin"; + case Op::GroupUMin: + return "GroupUMin"; + case Op::GroupSMin: + return "GroupSMin"; + case Op::GroupFMax: + return "GroupFMax"; + case Op::GroupUMax: + return "GroupUMax"; + case Op::GroupSMax: + return "GroupSMax"; + case Op::ReadPipe: + return "ReadPipe"; + case Op::WritePipe: + return "WritePipe"; + case Op::ReservedReadPipe: + return "ReservedReadPipe"; + case Op::ReservedWritePipe: + return "ReservedWritePipe"; + case Op::ReserveReadPipePackets: + return "ReserveReadPipePackets"; + case Op::ReserveWritePipePackets: + return "ReserveWritePipePackets"; + case Op::CommitReadPipe: + return "CommitReadPipe"; + case Op::CommitWritePipe: + return "CommitWritePipe"; + case Op::IsValidReserveId: + return "IsValidReserveId"; + case Op::GetNumPipePackets: + return "GetNumPipePackets"; + case Op::GetMaxPipePackets: + return "GetMaxPipePackets"; + case Op::GroupReserveReadPipePackets: + return "GroupReserveReadPipePackets"; + case Op::GroupReserveWritePipePackets: + return "GroupReserveWritePipePackets"; + case Op::GroupCommitReadPipe: + return "GroupCommitReadPipe"; + case Op::GroupCommitWritePipe: + return "GroupCommitWritePipe"; + case Op::EnqueueMarker: + return "EnqueueMarker"; + case Op::EnqueueKernel: + return "EnqueueKernel"; + case Op::GetKernelNDrangeSubGroupCount: + return "GetKernelNDrangeSubGroupCount"; + case Op::GetKernelNDrangeMaxSubGroupSize: + return "GetKernelNDrangeMaxSubGroupSize"; + case Op::GetKernelWorkGroupSize: + return "GetKernelWorkGroupSize"; + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return "GetKernelPreferredWorkGroupSizeMultiple"; + case Op::RetainEvent: + return "RetainEvent"; + case Op::ReleaseEvent: + return "ReleaseEvent"; + case Op::CreateUserEvent: + return "CreateUserEvent"; + case Op::IsValidEvent: + return "IsValidEvent"; + case Op::SetUserEventStatus: + return "SetUserEventStatus"; + case Op::CaptureEventProfilingInfo: + return "CaptureEventProfilingInfo"; + case Op::GetDefaultQueue: + return "GetDefaultQueue"; + case Op::BuildNDRange: + return "BuildNDRange"; + case Op::ImageSparseSampleImplicitLod: + return "ImageSparseSampleImplicitLod"; + case Op::ImageSparseSampleExplicitLod: + return "ImageSparseSampleExplicitLod"; + case Op::ImageSparseSampleDrefImplicitLod: + return "ImageSparseSampleDrefImplicitLod"; + case Op::ImageSparseSampleDrefExplicitLod: + return "ImageSparseSampleDrefExplicitLod"; + case Op::ImageSparseSampleProjImplicitLod: + return "ImageSparseSampleProjImplicitLod"; + case Op::ImageSparseSampleProjExplicitLod: + return "ImageSparseSampleProjExplicitLod"; + case Op::ImageSparseSampleProjDrefImplicitLod: + return "ImageSparseSampleProjDrefImplicitLod"; + case Op::ImageSparseSampleProjDrefExplicitLod: + return "ImageSparseSampleProjDrefExplicitLod"; + case Op::ImageSparseFetch: + return "ImageSparseFetch"; + case Op::ImageSparseGather: + return "ImageSparseGather"; + case Op::ImageSparseDrefGather: + return "ImageSparseDrefGather"; + case Op::ImageSparseTexelsResident: + return "ImageSparseTexelsResident"; + case Op::NoLine: + return "NoLine"; + case Op::AtomicFlagTestAndSet: + return "AtomicFlagTestAndSet"; + case Op::AtomicFlagClear: + return "AtomicFlagClear"; + case Op::ImageSparseRead: + return "ImageSparseRead"; + case Op::SizeOf: + return "SizeOf"; + case Op::TypePipeStorage: + return "TypePipeStorage"; + case Op::ConstantPipeStorage: + return "ConstantPipeStorage"; + case Op::CreatePipeFromPipeStorage: + return "CreatePipeFromPipeStorage"; + case Op::GetKernelLocalSizeForSubgroupCount: + return "GetKernelLocalSizeForSubgroupCount"; + case Op::GetKernelMaxNumSubgroups: + return "GetKernelMaxNumSubgroups"; + case Op::TypeNamedBarrier: + return "TypeNamedBarrier"; + case Op::NamedBarrierInitialize: + return "NamedBarrierInitialize"; + case Op::MemoryNamedBarrier: + return "MemoryNamedBarrier"; + case Op::ModuleProcessed: + return "ModuleProcessed"; + case Op::ExecutionModeId: + return "ExecutionModeId"; + case Op::DecorateId: + return "DecorateId"; + case Op::GroupNonUniformElect: + return "GroupNonUniformElect"; + case Op::GroupNonUniformAll: + return "GroupNonUniformAll"; + case Op::GroupNonUniformAny: + return "GroupNonUniformAny"; + case Op::GroupNonUniformAllEqual: + return "GroupNonUniformAllEqual"; + case Op::GroupNonUniformBroadcast: + return "GroupNonUniformBroadcast"; + case Op::GroupNonUniformBroadcastFirst: + return "GroupNonUniformBroadcastFirst"; + case Op::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Op::GroupNonUniformInverseBallot: + return "GroupNonUniformInverseBallot"; + case Op::GroupNonUniformBallotBitExtract: + return "GroupNonUniformBallotBitExtract"; + case Op::GroupNonUniformBallotBitCount: + return "GroupNonUniformBallotBitCount"; + case Op::GroupNonUniformBallotFindLSB: + return "GroupNonUniformBallotFindLSB"; + case Op::GroupNonUniformBallotFindMSB: + return "GroupNonUniformBallotFindMSB"; + case Op::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Op::GroupNonUniformShuffleXor: + return "GroupNonUniformShuffleXor"; + case Op::GroupNonUniformShuffleUp: + return "GroupNonUniformShuffleUp"; + case Op::GroupNonUniformShuffleDown: + return "GroupNonUniformShuffleDown"; + case Op::GroupNonUniformIAdd: + return "GroupNonUniformIAdd"; + case Op::GroupNonUniformFAdd: + return "GroupNonUniformFAdd"; + case Op::GroupNonUniformIMul: + return "GroupNonUniformIMul"; + case Op::GroupNonUniformFMul: + return "GroupNonUniformFMul"; + case Op::GroupNonUniformSMin: + return "GroupNonUniformSMin"; + case Op::GroupNonUniformUMin: + return "GroupNonUniformUMin"; + case Op::GroupNonUniformFMin: + return "GroupNonUniformFMin"; + case Op::GroupNonUniformSMax: + return "GroupNonUniformSMax"; + case Op::GroupNonUniformUMax: + return "GroupNonUniformUMax"; + case Op::GroupNonUniformFMax: + return "GroupNonUniformFMax"; + case Op::GroupNonUniformBitwiseAnd: + return "GroupNonUniformBitwiseAnd"; + case Op::GroupNonUniformBitwiseOr: + return "GroupNonUniformBitwiseOr"; + case Op::GroupNonUniformBitwiseXor: + return "GroupNonUniformBitwiseXor"; + case Op::GroupNonUniformLogicalAnd: + return "GroupNonUniformLogicalAnd"; + case Op::GroupNonUniformLogicalOr: + return "GroupNonUniformLogicalOr"; + case Op::GroupNonUniformLogicalXor: + return "GroupNonUniformLogicalXor"; + case Op::GroupNonUniformQuadBroadcast: + return "GroupNonUniformQuadBroadcast"; + case Op::GroupNonUniformQuadSwap: + return "GroupNonUniformQuadSwap"; + case Op::CopyLogical: + return "CopyLogical"; + case Op::PtrEqual: + return "PtrEqual"; + case Op::PtrNotEqual: + return "PtrNotEqual"; + case Op::PtrDiff: + return "PtrDiff"; + case Op::TypeCooperativeMatrixKHR: + return "TypeCooperativeMatrixKHR"; + case Op::CooperativeMatrixLoadKHR: + return "CooperativeMatrixLoadKHR"; + case Op::CooperativeMatrixStoreKHR: + return "CooperativeMatrixStoreKHR"; + case Op::CooperativeMatrixMulAddKHR: + return "CooperativeMatrixMulAddKHR"; + case Op::CooperativeMatrixLengthKHR: + return "CooperativeMatrixLengthKHR"; + } + return "unknown"; +} +auto to_string(ImageOperands e) -> char const * { + switch (e) { + case ImageOperands::None: + return "None"; + case ImageOperands::Bias: + return "Bias"; + case ImageOperands::Lod: + return "Lod"; + case ImageOperands::Grad: + return "Grad"; + case ImageOperands::ConstOffset: + return "ConstOffset"; + case ImageOperands::Offset: + return "Offset"; + case ImageOperands::ConstOffsets: + return "ConstOffsets"; + case ImageOperands::Sample: + return "Sample"; + case ImageOperands::MinLod: + return "MinLod"; + case ImageOperands::MakeTexelAvailable: + return "MakeTexelAvailable"; + case ImageOperands::MakeTexelVisible: + return "MakeTexelVisible"; + case ImageOperands::NonPrivateTexel: + return "NonPrivateTexel"; + case ImageOperands::VolatileTexel: + return "VolatileTexel"; + case ImageOperands::SignExtend: + return "SignExtend"; + case ImageOperands::ZeroExtend: + return "ZeroExtend"; + case ImageOperands::Nontemporal: + return "Nontemporal"; + case ImageOperands::Offsets: + return "Offsets"; + } + return "unknown"; +} +auto to_string(FPFastMathMode e) -> char const * { + switch (e) { + case FPFastMathMode::None: + return "None"; + case FPFastMathMode::NotNaN: + return "NotNaN"; + case FPFastMathMode::NotInf: + return "NotInf"; + case FPFastMathMode::NSZ: + return "NSZ"; + case FPFastMathMode::AllowRecip: + return "AllowRecip"; + case FPFastMathMode::Fast: + return "Fast"; + case FPFastMathMode::AllowContract: + return "AllowContract"; + case FPFastMathMode::AllowReassoc: + return "AllowReassoc"; + case FPFastMathMode::AllowTransform: + return "AllowTransform"; + } + return "unknown"; +} +auto to_string(SelectionControl e) -> char const * { + switch (e) { + case SelectionControl::None: + return "None"; + case SelectionControl::Flatten: + return "Flatten"; + case SelectionControl::DontFlatten: + return "DontFlatten"; + } + return "unknown"; +} +auto to_string(LoopControl e) -> char const * { + switch (e) { + case LoopControl::None: + return "None"; + case LoopControl::Unroll: + return "Unroll"; + case LoopControl::DontUnroll: + return "DontUnroll"; + case LoopControl::DependencyInfinite: + return "DependencyInfinite"; + case LoopControl::DependencyLength: + return "DependencyLength"; + case LoopControl::MinIterations: + return "MinIterations"; + case LoopControl::MaxIterations: + return "MaxIterations"; + case LoopControl::IterationMultiple: + return "IterationMultiple"; + case LoopControl::PeelCount: + return "PeelCount"; + case LoopControl::PartialCount: + return "PartialCount"; + case LoopControl::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case LoopControl::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case LoopControl::DependencyArrayINTEL: + return "DependencyArrayINTEL"; + case LoopControl::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case LoopControl::LoopCoalesceINTEL: + return "LoopCoalesceINTEL"; + case LoopControl::MaxInterleavingINTEL: + return "MaxInterleavingINTEL"; + case LoopControl::SpeculatedIterationsINTEL: + return "SpeculatedIterationsINTEL"; + case LoopControl::NoFusionINTEL: + return "NoFusionINTEL"; + case LoopControl::LoopCountINTEL: + return "LoopCountINTEL"; + case LoopControl::MaxReinvocationDelayINTEL: + return "MaxReinvocationDelayINTEL"; + } + return "unknown"; +} +auto to_string(FunctionControl e) -> char const * { + switch (e) { + case FunctionControl::None: + return "None"; + case FunctionControl::Inline: + return "Inline"; + case FunctionControl::DontInline: + return "DontInline"; + case FunctionControl::Pure: + return "Pure"; + case FunctionControl::Const: + return "Const"; + case FunctionControl::OptNoneEXT: + return "OptNoneEXT"; + } + return "unknown"; +} +auto to_string(MemorySemantics e) -> char const * { + switch (e) { + case MemorySemantics::Relaxed: + return "Relaxed"; + case MemorySemantics::Acquire: + return "Acquire"; + case MemorySemantics::Release: + return "Release"; + case MemorySemantics::AcquireRelease: + return "AcquireRelease"; + case MemorySemantics::SequentiallyConsistent: + return "SequentiallyConsistent"; + case MemorySemantics::UniformMemory: + return "UniformMemory"; + case MemorySemantics::SubgroupMemory: + return "SubgroupMemory"; + case MemorySemantics::WorkgroupMemory: + return "WorkgroupMemory"; + case MemorySemantics::CrossWorkgroupMemory: + return "CrossWorkgroupMemory"; + case MemorySemantics::AtomicCounterMemory: + return "AtomicCounterMemory"; + case MemorySemantics::ImageMemory: + return "ImageMemory"; + case MemorySemantics::OutputMemory: + return "OutputMemory"; + case MemorySemantics::MakeAvailable: + return "MakeAvailable"; + case MemorySemantics::MakeVisible: + return "MakeVisible"; + case MemorySemantics::Volatile: + return "Volatile"; + } + return "unknown"; +} +auto to_string(MemoryAccess e) -> char const * { + switch (e) { + case MemoryAccess::None: + return "None"; + case MemoryAccess::Volatile: + return "Volatile"; + case MemoryAccess::Aligned: + return "Aligned"; + case MemoryAccess::Nontemporal: + return "Nontemporal"; + case MemoryAccess::MakePointerAvailable: + return "MakePointerAvailable"; + case MemoryAccess::MakePointerVisible: + return "MakePointerVisible"; + case MemoryAccess::NonPrivatePointer: + return "NonPrivatePointer"; + case MemoryAccess::AliasScopeINTELMask: + return "AliasScopeINTELMask"; + case MemoryAccess::NoAliasINTELMask: + return "NoAliasINTELMask"; + } + return "unknown"; +} +auto to_string(KernelProfilingInfo e) -> char const * { + switch (e) { + case KernelProfilingInfo::None: + return "None"; + case KernelProfilingInfo::CmdExecTime: + return "CmdExecTime"; + } + return "unknown"; +} +auto to_string(RayFlags e) -> char const * { + switch (e) { + case RayFlags::NoneKHR: + return "NoneKHR"; + case RayFlags::OpaqueKHR: + return "OpaqueKHR"; + case RayFlags::NoOpaqueKHR: + return "NoOpaqueKHR"; + case RayFlags::TerminateOnFirstHitKHR: + return "TerminateOnFirstHitKHR"; + case RayFlags::SkipClosestHitShaderKHR: + return "SkipClosestHitShaderKHR"; + case RayFlags::CullBackFacingTrianglesKHR: + return "CullBackFacingTrianglesKHR"; + case RayFlags::CullFrontFacingTrianglesKHR: + return "CullFrontFacingTrianglesKHR"; + case RayFlags::CullOpaqueKHR: + return "CullOpaqueKHR"; + case RayFlags::CullNoOpaqueKHR: + return "CullNoOpaqueKHR"; + case RayFlags::SkipTrianglesKHR: + return "SkipTrianglesKHR"; + case RayFlags::SkipAABBsKHR: + return "SkipAABBsKHR"; + case RayFlags::ForceOpacityMicromap2StateEXT: + return "ForceOpacityMicromap2StateEXT"; + } + return "unknown"; +} +auto to_string(FragmentShadingRate e) -> char const * { + switch (e) { + case FragmentShadingRate::Vertical2Pixels: + return "Vertical2Pixels"; + case FragmentShadingRate::Vertical4Pixels: + return "Vertical4Pixels"; + case FragmentShadingRate::Horizontal2Pixels: + return "Horizontal2Pixels"; + case FragmentShadingRate::Horizontal4Pixels: + return "Horizontal4Pixels"; + } + return "unknown"; +} +auto to_string(RawAccessChainOperands e) -> char const * { + switch (e) { + case RawAccessChainOperands::None: + return "None"; + case RawAccessChainOperands::RobustnessPerComponentNV: + return "RobustnessPerComponentNV"; + case RawAccessChainOperands::RobustnessPerElementNV: + return "RobustnessPerElementNV"; + } + return "unknown"; +} +auto to_string(SourceLanguage e) -> char const * { + switch (e) { + case SourceLanguage::Unknown: + return "Unknown"; + case SourceLanguage::ESSL: + return "ESSL"; + case SourceLanguage::GLSL: + return "GLSL"; + case SourceLanguage::OpenCL_C: + return "OpenCL_C"; + case SourceLanguage::OpenCL_CPP: + return "OpenCL_CPP"; + case SourceLanguage::HLSL: + return "HLSL"; + case SourceLanguage::CPP_for_OpenCL: + return "CPP_for_OpenCL"; + case SourceLanguage::SYCL: + return "SYCL"; + case SourceLanguage::HERO_C: + return "HERO_C"; + case SourceLanguage::NZSL: + return "NZSL"; + case SourceLanguage::WGSL: + return "WGSL"; + case SourceLanguage::Slang: + return "Slang"; + case SourceLanguage::Zig: + return "Zig"; + } + return "unknown"; +} +auto to_string(ExecutionModel e) -> char const * { + switch (e) { + case ExecutionModel::Vertex: + return "Vertex"; + case ExecutionModel::TessellationControl: + return "TessellationControl"; + case ExecutionModel::TessellationEvaluation: + return "TessellationEvaluation"; + case ExecutionModel::Geometry: + return "Geometry"; + case ExecutionModel::Fragment: + return "Fragment"; + case ExecutionModel::GLCompute: + return "GLCompute"; + case ExecutionModel::Kernel: + return "Kernel"; + case ExecutionModel::TaskNV: + return "TaskNV"; + case ExecutionModel::MeshNV: + return "MeshNV"; + case ExecutionModel::RayGenerationKHR: + return "RayGenerationKHR"; + case ExecutionModel::IntersectionKHR: + return "IntersectionKHR"; + case ExecutionModel::AnyHitKHR: + return "AnyHitKHR"; + case ExecutionModel::ClosestHitKHR: + return "ClosestHitKHR"; + case ExecutionModel::MissKHR: + return "MissKHR"; + case ExecutionModel::CallableKHR: + return "CallableKHR"; + case ExecutionModel::TaskEXT: + return "TaskEXT"; + case ExecutionModel::MeshEXT: + return "MeshEXT"; + } + return "unknown"; +} +auto to_string(AddressingModel e) -> char const * { + switch (e) { + case AddressingModel::Logical: + return "Logical"; + case AddressingModel::Physical32: + return "Physical32"; + case AddressingModel::Physical64: + return "Physical64"; + case AddressingModel::PhysicalStorageBuffer64: + return "PhysicalStorageBuffer64"; + } + return "unknown"; +} +auto to_string(MemoryModel e) -> char const * { + switch (e) { + case MemoryModel::Simple: + return "Simple"; + case MemoryModel::GLSL450: + return "GLSL450"; + case MemoryModel::OpenCL: + return "OpenCL"; + case MemoryModel::Vulkan: + return "Vulkan"; + } + return "unknown"; +} +auto to_string(ExecutionMode e) -> char const * { + switch (e) { + case ExecutionMode::Invocations: + return "Invocations"; + case ExecutionMode::SpacingEqual: + return "SpacingEqual"; + case ExecutionMode::SpacingFractionalEven: + return "SpacingFractionalEven"; + case ExecutionMode::SpacingFractionalOdd: + return "SpacingFractionalOdd"; + case ExecutionMode::VertexOrderCw: + return "VertexOrderCw"; + case ExecutionMode::VertexOrderCcw: + return "VertexOrderCcw"; + case ExecutionMode::PixelCenterInteger: + return "PixelCenterInteger"; + case ExecutionMode::OriginUpperLeft: + return "OriginUpperLeft"; + case ExecutionMode::OriginLowerLeft: + return "OriginLowerLeft"; + case ExecutionMode::EarlyFragmentTests: + return "EarlyFragmentTests"; + case ExecutionMode::PointMode: + return "PointMode"; + case ExecutionMode::Xfb: + return "Xfb"; + case ExecutionMode::DepthReplacing: + return "DepthReplacing"; + case ExecutionMode::DepthGreater: + return "DepthGreater"; + case ExecutionMode::DepthLess: + return "DepthLess"; + case ExecutionMode::DepthUnchanged: + return "DepthUnchanged"; + case ExecutionMode::LocalSize: + return "LocalSize"; + case ExecutionMode::LocalSizeHint: + return "LocalSizeHint"; + case ExecutionMode::InputPoints: + return "InputPoints"; + case ExecutionMode::InputLines: + return "InputLines"; + case ExecutionMode::InputLinesAdjacency: + return "InputLinesAdjacency"; + case ExecutionMode::Triangles: + return "Triangles"; + case ExecutionMode::InputTrianglesAdjacency: + return "InputTrianglesAdjacency"; + case ExecutionMode::Quads: + return "Quads"; + case ExecutionMode::Isolines: + return "Isolines"; + case ExecutionMode::OutputVertices: + return "OutputVertices"; + case ExecutionMode::OutputPoints: + return "OutputPoints"; + case ExecutionMode::OutputLineStrip: + return "OutputLineStrip"; + case ExecutionMode::OutputTriangleStrip: + return "OutputTriangleStrip"; + case ExecutionMode::VecTypeHint: + return "VecTypeHint"; + case ExecutionMode::ContractionOff: + return "ContractionOff"; + case ExecutionMode::Initializer: + return "Initializer"; + case ExecutionMode::Finalizer: + return "Finalizer"; + case ExecutionMode::SubgroupSize: + return "SubgroupSize"; + case ExecutionMode::SubgroupsPerWorkgroup: + return "SubgroupsPerWorkgroup"; + case ExecutionMode::SubgroupsPerWorkgroupId: + return "SubgroupsPerWorkgroupId"; + case ExecutionMode::LocalSizeId: + return "LocalSizeId"; + case ExecutionMode::LocalSizeHintId: + return "LocalSizeHintId"; + case ExecutionMode::NonCoherentColorAttachmentReadEXT: + return "NonCoherentColorAttachmentReadEXT"; + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: + return "NonCoherentDepthAttachmentReadEXT"; + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: + return "NonCoherentStencilAttachmentReadEXT"; + case ExecutionMode::SubgroupUniformControlFlowKHR: + return "SubgroupUniformControlFlowKHR"; + case ExecutionMode::PostDepthCoverage: + return "PostDepthCoverage"; + case ExecutionMode::DenormPreserve: + return "DenormPreserve"; + case ExecutionMode::DenormFlushToZero: + return "DenormFlushToZero"; + case ExecutionMode::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case ExecutionMode::RoundingModeRTE: + return "RoundingModeRTE"; + case ExecutionMode::RoundingModeRTZ: + return "RoundingModeRTZ"; + case ExecutionMode::EarlyAndLateFragmentTestsAMD: + return "EarlyAndLateFragmentTestsAMD"; + case ExecutionMode::StencilRefReplacingEXT: + return "StencilRefReplacingEXT"; + case ExecutionMode::CoalescingAMDX: + return "CoalescingAMDX"; + case ExecutionMode::IsApiEntryAMDX: + return "IsApiEntryAMDX"; + case ExecutionMode::MaxNodeRecursionAMDX: + return "MaxNodeRecursionAMDX"; + case ExecutionMode::StaticNumWorkgroupsAMDX: + return "StaticNumWorkgroupsAMDX"; + case ExecutionMode::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case ExecutionMode::MaxNumWorkgroupsAMDX: + return "MaxNumWorkgroupsAMDX"; + case ExecutionMode::StencilRefUnchangedFrontAMD: + return "StencilRefUnchangedFrontAMD"; + case ExecutionMode::StencilRefGreaterFrontAMD: + return "StencilRefGreaterFrontAMD"; + case ExecutionMode::StencilRefLessFrontAMD: + return "StencilRefLessFrontAMD"; + case ExecutionMode::StencilRefUnchangedBackAMD: + return "StencilRefUnchangedBackAMD"; + case ExecutionMode::StencilRefGreaterBackAMD: + return "StencilRefGreaterBackAMD"; + case ExecutionMode::StencilRefLessBackAMD: + return "StencilRefLessBackAMD"; + case ExecutionMode::QuadDerivativesKHR: + return "QuadDerivativesKHR"; + case ExecutionMode::RequireFullQuadsKHR: + return "RequireFullQuadsKHR"; + case ExecutionMode::SharesInputWithAMDX: + return "SharesInputWithAMDX"; + case ExecutionMode::OutputLinesEXT: + return "OutputLinesEXT"; + case ExecutionMode::OutputPrimitivesEXT: + return "OutputPrimitivesEXT"; + case ExecutionMode::DerivativeGroupQuadsKHR: + return "DerivativeGroupQuadsKHR"; + case ExecutionMode::DerivativeGroupLinearKHR: + return "DerivativeGroupLinearKHR"; + case ExecutionMode::OutputTrianglesEXT: + return "OutputTrianglesEXT"; + case ExecutionMode::PixelInterlockOrderedEXT: + return "PixelInterlockOrderedEXT"; + case ExecutionMode::PixelInterlockUnorderedEXT: + return "PixelInterlockUnorderedEXT"; + case ExecutionMode::SampleInterlockOrderedEXT: + return "SampleInterlockOrderedEXT"; + case ExecutionMode::SampleInterlockUnorderedEXT: + return "SampleInterlockUnorderedEXT"; + case ExecutionMode::ShadingRateInterlockOrderedEXT: + return "ShadingRateInterlockOrderedEXT"; + case ExecutionMode::ShadingRateInterlockUnorderedEXT: + return "ShadingRateInterlockUnorderedEXT"; + case ExecutionMode::SharedLocalMemorySizeINTEL: + return "SharedLocalMemorySizeINTEL"; + case ExecutionMode::RoundingModeRTPINTEL: + return "RoundingModeRTPINTEL"; + case ExecutionMode::RoundingModeRTNINTEL: + return "RoundingModeRTNINTEL"; + case ExecutionMode::FloatingPointModeALTINTEL: + return "FloatingPointModeALTINTEL"; + case ExecutionMode::FloatingPointModeIEEEINTEL: + return "FloatingPointModeIEEEINTEL"; + case ExecutionMode::MaxWorkgroupSizeINTEL: + return "MaxWorkgroupSizeINTEL"; + case ExecutionMode::MaxWorkDimINTEL: + return "MaxWorkDimINTEL"; + case ExecutionMode::NoGlobalOffsetINTEL: + return "NoGlobalOffsetINTEL"; + case ExecutionMode::NumSIMDWorkitemsINTEL: + return "NumSIMDWorkitemsINTEL"; + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: + return "SchedulerTargetFmaxMhzINTEL"; + case ExecutionMode::MaximallyReconvergesKHR: + return "MaximallyReconvergesKHR"; + case ExecutionMode::FPFastMathDefault: + return "FPFastMathDefault"; + case ExecutionMode::StreamingInterfaceINTEL: + return "StreamingInterfaceINTEL"; + case ExecutionMode::RegisterMapInterfaceINTEL: + return "RegisterMapInterfaceINTEL"; + case ExecutionMode::NamedBarrierCountINTEL: + return "NamedBarrierCountINTEL"; + case ExecutionMode::MaximumRegistersINTEL: + return "MaximumRegistersINTEL"; + case ExecutionMode::MaximumRegistersIdINTEL: + return "MaximumRegistersIdINTEL"; + case ExecutionMode::NamedMaximumRegistersINTEL: + return "NamedMaximumRegistersINTEL"; + } + return "unknown"; +} +auto to_string(StorageClass e) -> char const * { + switch (e) { + case StorageClass::UniformConstant: + return "UniformConstant"; + case StorageClass::Input: + return "Input"; + case StorageClass::Uniform: + return "Uniform"; + case StorageClass::Output: + return "Output"; + case StorageClass::Workgroup: + return "Workgroup"; + case StorageClass::CrossWorkgroup: + return "CrossWorkgroup"; + case StorageClass::Private: + return "Private"; + case StorageClass::Function: + return "Function"; + case StorageClass::Generic: + return "Generic"; + case StorageClass::PushConstant: + return "PushConstant"; + case StorageClass::AtomicCounter: + return "AtomicCounter"; + case StorageClass::Image: + return "Image"; + case StorageClass::StorageBuffer: + return "StorageBuffer"; + case StorageClass::TileImageEXT: + return "TileImageEXT"; + case StorageClass::NodePayloadAMDX: + return "NodePayloadAMDX"; + case StorageClass::CallableDataKHR: + return "CallableDataKHR"; + case StorageClass::IncomingCallableDataKHR: + return "IncomingCallableDataKHR"; + case StorageClass::RayPayloadKHR: + return "RayPayloadKHR"; + case StorageClass::HitAttributeKHR: + return "HitAttributeKHR"; + case StorageClass::IncomingRayPayloadKHR: + return "IncomingRayPayloadKHR"; + case StorageClass::ShaderRecordBufferKHR: + return "ShaderRecordBufferKHR"; + case StorageClass::PhysicalStorageBuffer: + return "PhysicalStorageBuffer"; + case StorageClass::HitObjectAttributeNV: + return "HitObjectAttributeNV"; + case StorageClass::TaskPayloadWorkgroupEXT: + return "TaskPayloadWorkgroupEXT"; + case StorageClass::CodeSectionINTEL: + return "CodeSectionINTEL"; + case StorageClass::DeviceOnlyINTEL: + return "DeviceOnlyINTEL"; + case StorageClass::HostOnlyINTEL: + return "HostOnlyINTEL"; + } + return "unknown"; +} +auto to_string(Dim e) -> char const * { + switch (e) { + case Dim::Dim1D: + return "Dim1D"; + case Dim::Dim2D: + return "Dim2D"; + case Dim::Dim3D: + return "Dim3D"; + case Dim::Cube: + return "Cube"; + case Dim::Rect: + return "Rect"; + case Dim::Buffer: + return "Buffer"; + case Dim::SubpassData: + return "SubpassData"; + case Dim::TileImageDataEXT: + return "TileImageDataEXT"; + } + return "unknown"; +} +auto to_string(SamplerAddressingMode e) -> char const * { + switch (e) { + case SamplerAddressingMode::None: + return "None"; + case SamplerAddressingMode::ClampToEdge: + return "ClampToEdge"; + case SamplerAddressingMode::Clamp: + return "Clamp"; + case SamplerAddressingMode::Repeat: + return "Repeat"; + case SamplerAddressingMode::RepeatMirrored: + return "RepeatMirrored"; + } + return "unknown"; +} +auto to_string(SamplerFilterMode e) -> char const * { + switch (e) { + case SamplerFilterMode::Nearest: + return "Nearest"; + case SamplerFilterMode::Linear: + return "Linear"; + } + return "unknown"; +} +auto to_string(ImageFormat e) -> char const * { + switch (e) { + case ImageFormat::Unknown: + return "Unknown"; + case ImageFormat::Rgba32f: + return "Rgba32f"; + case ImageFormat::Rgba16f: + return "Rgba16f"; + case ImageFormat::R32f: + return "R32f"; + case ImageFormat::Rgba8: + return "Rgba8"; + case ImageFormat::Rgba8Snorm: + return "Rgba8Snorm"; + case ImageFormat::Rg32f: + return "Rg32f"; + case ImageFormat::Rg16f: + return "Rg16f"; + case ImageFormat::R11fG11fB10f: + return "R11fG11fB10f"; + case ImageFormat::R16f: + return "R16f"; + case ImageFormat::Rgba16: + return "Rgba16"; + case ImageFormat::Rgb10A2: + return "Rgb10A2"; + case ImageFormat::Rg16: + return "Rg16"; + case ImageFormat::Rg8: + return "Rg8"; + case ImageFormat::R16: + return "R16"; + case ImageFormat::R8: + return "R8"; + case ImageFormat::Rgba16Snorm: + return "Rgba16Snorm"; + case ImageFormat::Rg16Snorm: + return "Rg16Snorm"; + case ImageFormat::Rg8Snorm: + return "Rg8Snorm"; + case ImageFormat::R16Snorm: + return "R16Snorm"; + case ImageFormat::R8Snorm: + return "R8Snorm"; + case ImageFormat::Rgba32i: + return "Rgba32i"; + case ImageFormat::Rgba16i: + return "Rgba16i"; + case ImageFormat::Rgba8i: + return "Rgba8i"; + case ImageFormat::R32i: + return "R32i"; + case ImageFormat::Rg32i: + return "Rg32i"; + case ImageFormat::Rg16i: + return "Rg16i"; + case ImageFormat::Rg8i: + return "Rg8i"; + case ImageFormat::R16i: + return "R16i"; + case ImageFormat::R8i: + return "R8i"; + case ImageFormat::Rgba32ui: + return "Rgba32ui"; + case ImageFormat::Rgba16ui: + return "Rgba16ui"; + case ImageFormat::Rgba8ui: + return "Rgba8ui"; + case ImageFormat::R32ui: + return "R32ui"; + case ImageFormat::Rgb10a2ui: + return "Rgb10a2ui"; + case ImageFormat::Rg32ui: + return "Rg32ui"; + case ImageFormat::Rg16ui: + return "Rg16ui"; + case ImageFormat::Rg8ui: + return "Rg8ui"; + case ImageFormat::R16ui: + return "R16ui"; + case ImageFormat::R8ui: + return "R8ui"; + case ImageFormat::R64ui: + return "R64ui"; + case ImageFormat::R64i: + return "R64i"; + } + return "unknown"; +} +auto to_string(ImageChannelOrder e) -> char const * { + switch (e) { + case ImageChannelOrder::R: + return "R"; + case ImageChannelOrder::A: + return "A"; + case ImageChannelOrder::RG: + return "RG"; + case ImageChannelOrder::RA: + return "RA"; + case ImageChannelOrder::RGB: + return "RGB"; + case ImageChannelOrder::RGBA: + return "RGBA"; + case ImageChannelOrder::BGRA: + return "BGRA"; + case ImageChannelOrder::ARGB: + return "ARGB"; + case ImageChannelOrder::Intensity: + return "Intensity"; + case ImageChannelOrder::Luminance: + return "Luminance"; + case ImageChannelOrder::Rx: + return "Rx"; + case ImageChannelOrder::RGx: + return "RGx"; + case ImageChannelOrder::RGBx: + return "RGBx"; + case ImageChannelOrder::Depth: + return "Depth"; + case ImageChannelOrder::DepthStencil: + return "DepthStencil"; + case ImageChannelOrder::sRGB: + return "sRGB"; + case ImageChannelOrder::sRGBx: + return "sRGBx"; + case ImageChannelOrder::sRGBA: + return "sRGBA"; + case ImageChannelOrder::sBGRA: + return "sBGRA"; + case ImageChannelOrder::ABGR: + return "ABGR"; + } + return "unknown"; +} +auto to_string(ImageChannelDataType e) -> char const * { + switch (e) { + case ImageChannelDataType::SnormInt8: + return "SnormInt8"; + case ImageChannelDataType::SnormInt16: + return "SnormInt16"; + case ImageChannelDataType::UnormInt8: + return "UnormInt8"; + case ImageChannelDataType::UnormInt16: + return "UnormInt16"; + case ImageChannelDataType::UnormShort565: + return "UnormShort565"; + case ImageChannelDataType::UnormShort555: + return "UnormShort555"; + case ImageChannelDataType::UnormInt101010: + return "UnormInt101010"; + case ImageChannelDataType::SignedInt8: + return "SignedInt8"; + case ImageChannelDataType::SignedInt16: + return "SignedInt16"; + case ImageChannelDataType::SignedInt32: + return "SignedInt32"; + case ImageChannelDataType::UnsignedInt8: + return "UnsignedInt8"; + case ImageChannelDataType::UnsignedInt16: + return "UnsignedInt16"; + case ImageChannelDataType::UnsignedInt32: + return "UnsignedInt32"; + case ImageChannelDataType::HalfFloat: + return "HalfFloat"; + case ImageChannelDataType::Float: + return "Float"; + case ImageChannelDataType::UnormInt24: + return "UnormInt24"; + case ImageChannelDataType::UnormInt101010_2: + return "UnormInt101010_2"; + case ImageChannelDataType::UnsignedIntRaw10EXT: + return "UnsignedIntRaw10EXT"; + case ImageChannelDataType::UnsignedIntRaw12EXT: + return "UnsignedIntRaw12EXT"; + case ImageChannelDataType::UnormInt2_101010EXT: + return "UnormInt2_101010EXT"; + } + return "unknown"; +} +auto to_string(FPRoundingMode e) -> char const * { + switch (e) { + case FPRoundingMode::RTE: + return "RTE"; + case FPRoundingMode::RTZ: + return "RTZ"; + case FPRoundingMode::RTP: + return "RTP"; + case FPRoundingMode::RTN: + return "RTN"; + } + return "unknown"; +} +auto to_string(FPDenormMode e) -> char const * { + switch (e) { + case FPDenormMode::Preserve: + return "Preserve"; + case FPDenormMode::FlushToZero: + return "FlushToZero"; + } + return "unknown"; +} +auto to_string(QuantizationModes e) -> char const * { + switch (e) { + case QuantizationModes::TRN: + return "TRN"; + case QuantizationModes::TRN_ZERO: + return "TRN_ZERO"; + case QuantizationModes::RND: + return "RND"; + case QuantizationModes::RND_ZERO: + return "RND_ZERO"; + case QuantizationModes::RND_INF: + return "RND_INF"; + case QuantizationModes::RND_MIN_INF: + return "RND_MIN_INF"; + case QuantizationModes::RND_CONV: + return "RND_CONV"; + case QuantizationModes::RND_CONV_ODD: + return "RND_CONV_ODD"; + } + return "unknown"; +} +auto to_string(FPOperationMode e) -> char const * { + switch (e) { + case FPOperationMode::IEEE: + return "IEEE"; + case FPOperationMode::ALT: + return "ALT"; + } + return "unknown"; +} +auto to_string(OverflowModes e) -> char const * { + switch (e) { + case OverflowModes::WRAP: + return "WRAP"; + case OverflowModes::SAT: + return "SAT"; + case OverflowModes::SAT_ZERO: + return "SAT_ZERO"; + case OverflowModes::SAT_SYM: + return "SAT_SYM"; + } + return "unknown"; +} +auto to_string(LinkageType e) -> char const * { + switch (e) { + case LinkageType::Export: + return "Export"; + case LinkageType::Import: + return "Import"; + case LinkageType::LinkOnceODR: + return "LinkOnceODR"; + } + return "unknown"; +} +auto to_string(AccessQualifier e) -> char const * { + switch (e) { + case AccessQualifier::ReadOnly: + return "ReadOnly"; + case AccessQualifier::WriteOnly: + return "WriteOnly"; + case AccessQualifier::ReadWrite: + return "ReadWrite"; + } + return "unknown"; +} +auto to_string(HostAccessQualifier e) -> char const * { + switch (e) { + case HostAccessQualifier::NoneINTEL: + return "NoneINTEL"; + case HostAccessQualifier::ReadINTEL: + return "ReadINTEL"; + case HostAccessQualifier::WriteINTEL: + return "WriteINTEL"; + case HostAccessQualifier::ReadWriteINTEL: + return "ReadWriteINTEL"; + } + return "unknown"; +} +auto to_string(FunctionParameterAttribute e) -> char const * { + switch (e) { + case FunctionParameterAttribute::Zext: + return "Zext"; + case FunctionParameterAttribute::Sext: + return "Sext"; + case FunctionParameterAttribute::ByVal: + return "ByVal"; + case FunctionParameterAttribute::Sret: + return "Sret"; + case FunctionParameterAttribute::NoAlias: + return "NoAlias"; + case FunctionParameterAttribute::NoCapture: + return "NoCapture"; + case FunctionParameterAttribute::NoWrite: + return "NoWrite"; + case FunctionParameterAttribute::NoReadWrite: + return "NoReadWrite"; + case FunctionParameterAttribute::RuntimeAlignedINTEL: + return "RuntimeAlignedINTEL"; + } + return "unknown"; +} +auto to_string(Decoration e) -> char const * { + switch (e) { + case Decoration::RelaxedPrecision: + return "RelaxedPrecision"; + case Decoration::SpecId: + return "SpecId"; + case Decoration::Block: + return "Block"; + case Decoration::BufferBlock: + return "BufferBlock"; + case Decoration::RowMajor: + return "RowMajor"; + case Decoration::ColMajor: + return "ColMajor"; + case Decoration::ArrayStride: + return "ArrayStride"; + case Decoration::MatrixStride: + return "MatrixStride"; + case Decoration::GLSLShared: + return "GLSLShared"; + case Decoration::GLSLPacked: + return "GLSLPacked"; + case Decoration::CPacked: + return "CPacked"; + case Decoration::BuiltIn: + return "BuiltIn"; + case Decoration::NoPerspective: + return "NoPerspective"; + case Decoration::Flat: + return "Flat"; + case Decoration::Patch: + return "Patch"; + case Decoration::Centroid: + return "Centroid"; + case Decoration::Sample: + return "Sample"; + case Decoration::Invariant: + return "Invariant"; + case Decoration::Restrict: + return "Restrict"; + case Decoration::Aliased: + return "Aliased"; + case Decoration::Volatile: + return "Volatile"; + case Decoration::Constant: + return "Constant"; + case Decoration::Coherent: + return "Coherent"; + case Decoration::NonWritable: + return "NonWritable"; + case Decoration::NonReadable: + return "NonReadable"; + case Decoration::Uniform: + return "Uniform"; + case Decoration::UniformId: + return "UniformId"; + case Decoration::SaturatedConversion: + return "SaturatedConversion"; + case Decoration::Stream: + return "Stream"; + case Decoration::Location: + return "Location"; + case Decoration::Component: + return "Component"; + case Decoration::Index: + return "Index"; + case Decoration::Binding: + return "Binding"; + case Decoration::DescriptorSet: + return "DescriptorSet"; + case Decoration::Offset: + return "Offset"; + case Decoration::XfbBuffer: + return "XfbBuffer"; + case Decoration::XfbStride: + return "XfbStride"; + case Decoration::FuncParamAttr: + return "FuncParamAttr"; + case Decoration::FPRoundingMode: + return "FPRoundingMode"; + case Decoration::FPFastMathMode: + return "FPFastMathMode"; + case Decoration::LinkageAttributes: + return "LinkageAttributes"; + case Decoration::NoContraction: + return "NoContraction"; + case Decoration::InputAttachmentIndex: + return "InputAttachmentIndex"; + case Decoration::Alignment: + return "Alignment"; + case Decoration::MaxByteOffset: + return "MaxByteOffset"; + case Decoration::AlignmentId: + return "AlignmentId"; + case Decoration::MaxByteOffsetId: + return "MaxByteOffsetId"; + case Decoration::NoSignedWrap: + return "NoSignedWrap"; + case Decoration::NoUnsignedWrap: + return "NoUnsignedWrap"; + case Decoration::WeightTextureQCOM: + return "WeightTextureQCOM"; + case Decoration::BlockMatchTextureQCOM: + return "BlockMatchTextureQCOM"; + case Decoration::BlockMatchSamplerQCOM: + return "BlockMatchSamplerQCOM"; + case Decoration::ExplicitInterpAMD: + return "ExplicitInterpAMD"; + case Decoration::NodeSharesPayloadLimitsWithAMDX: + return "NodeSharesPayloadLimitsWithAMDX"; + case Decoration::NodeMaxPayloadsAMDX: + return "NodeMaxPayloadsAMDX"; + case Decoration::TrackFinishWritingAMDX: + return "TrackFinishWritingAMDX"; + case Decoration::PayloadNodeNameAMDX: + return "PayloadNodeNameAMDX"; + case Decoration::PayloadNodeBaseIndexAMDX: + return "PayloadNodeBaseIndexAMDX"; + case Decoration::PayloadNodeSparseArrayAMDX: + return "PayloadNodeSparseArrayAMDX"; + case Decoration::PayloadNodeArraySizeAMDX: + return "PayloadNodeArraySizeAMDX"; + case Decoration::PayloadDispatchIndirectAMDX: + return "PayloadDispatchIndirectAMDX"; + case Decoration::OverrideCoverageNV: + return "OverrideCoverageNV"; + case Decoration::PassthroughNV: + return "PassthroughNV"; + case Decoration::ViewportRelativeNV: + return "ViewportRelativeNV"; + case Decoration::SecondaryViewportRelativeNV: + return "SecondaryViewportRelativeNV"; + case Decoration::PerPrimitiveEXT: + return "PerPrimitiveEXT"; + case Decoration::PerViewNV: + return "PerViewNV"; + case Decoration::PerTaskNV: + return "PerTaskNV"; + case Decoration::PerVertexKHR: + return "PerVertexKHR"; + case Decoration::NonUniform: + return "NonUniform"; + case Decoration::RestrictPointer: + return "RestrictPointer"; + case Decoration::AliasedPointer: + return "AliasedPointer"; + case Decoration::HitObjectShaderRecordBufferNV: + return "HitObjectShaderRecordBufferNV"; + case Decoration::BindlessSamplerNV: + return "BindlessSamplerNV"; + case Decoration::BindlessImageNV: + return "BindlessImageNV"; + case Decoration::BoundSamplerNV: + return "BoundSamplerNV"; + case Decoration::BoundImageNV: + return "BoundImageNV"; + case Decoration::SIMTCallINTEL: + return "SIMTCallINTEL"; + case Decoration::ReferencedIndirectlyINTEL: + return "ReferencedIndirectlyINTEL"; + case Decoration::ClobberINTEL: + return "ClobberINTEL"; + case Decoration::SideEffectsINTEL: + return "SideEffectsINTEL"; + case Decoration::VectorComputeVariableINTEL: + return "VectorComputeVariableINTEL"; + case Decoration::FuncParamIOKindINTEL: + return "FuncParamIOKindINTEL"; + case Decoration::VectorComputeFunctionINTEL: + return "VectorComputeFunctionINTEL"; + case Decoration::StackCallINTEL: + return "StackCallINTEL"; + case Decoration::GlobalVariableOffsetINTEL: + return "GlobalVariableOffsetINTEL"; + case Decoration::CounterBuffer: + return "CounterBuffer"; + case Decoration::UserSemantic: + return "UserSemantic"; + case Decoration::UserTypeGOOGLE: + return "UserTypeGOOGLE"; + case Decoration::FunctionRoundingModeINTEL: + return "FunctionRoundingModeINTEL"; + case Decoration::FunctionDenormModeINTEL: + return "FunctionDenormModeINTEL"; + case Decoration::RegisterINTEL: + return "RegisterINTEL"; + case Decoration::MemoryINTEL: + return "MemoryINTEL"; + case Decoration::NumbanksINTEL: + return "NumbanksINTEL"; + case Decoration::BankwidthINTEL: + return "BankwidthINTEL"; + case Decoration::MaxPrivateCopiesINTEL: + return "MaxPrivateCopiesINTEL"; + case Decoration::SinglepumpINTEL: + return "SinglepumpINTEL"; + case Decoration::DoublepumpINTEL: + return "DoublepumpINTEL"; + case Decoration::MaxReplicatesINTEL: + return "MaxReplicatesINTEL"; + case Decoration::SimpleDualPortINTEL: + return "SimpleDualPortINTEL"; + case Decoration::MergeINTEL: + return "MergeINTEL"; + case Decoration::BankBitsINTEL: + return "BankBitsINTEL"; + case Decoration::ForcePow2DepthINTEL: + return "ForcePow2DepthINTEL"; + case Decoration::StridesizeINTEL: + return "StridesizeINTEL"; + case Decoration::WordsizeINTEL: + return "WordsizeINTEL"; + case Decoration::TrueDualPortINTEL: + return "TrueDualPortINTEL"; + case Decoration::BurstCoalesceINTEL: + return "BurstCoalesceINTEL"; + case Decoration::CacheSizeINTEL: + return "CacheSizeINTEL"; + case Decoration::DontStaticallyCoalesceINTEL: + return "DontStaticallyCoalesceINTEL"; + case Decoration::PrefetchINTEL: + return "PrefetchINTEL"; + case Decoration::StallEnableINTEL: + return "StallEnableINTEL"; + case Decoration::FuseLoopsInFunctionINTEL: + return "FuseLoopsInFunctionINTEL"; + case Decoration::MathOpDSPModeINTEL: + return "MathOpDSPModeINTEL"; + case Decoration::AliasScopeINTEL: + return "AliasScopeINTEL"; + case Decoration::NoAliasINTEL: + return "NoAliasINTEL"; + case Decoration::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case Decoration::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case Decoration::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case Decoration::BufferLocationINTEL: + return "BufferLocationINTEL"; + case Decoration::IOPipeStorageINTEL: + return "IOPipeStorageINTEL"; + case Decoration::FunctionFloatingPointModeINTEL: + return "FunctionFloatingPointModeINTEL"; + case Decoration::SingleElementVectorINTEL: + return "SingleElementVectorINTEL"; + case Decoration::VectorComputeCallableFunctionINTEL: + return "VectorComputeCallableFunctionINTEL"; + case Decoration::MediaBlockIOINTEL: + return "MediaBlockIOINTEL"; + case Decoration::StallFreeINTEL: + return "StallFreeINTEL"; + case Decoration::FPMaxErrorDecorationINTEL: + return "FPMaxErrorDecorationINTEL"; + case Decoration::LatencyControlLabelINTEL: + return "LatencyControlLabelINTEL"; + case Decoration::LatencyControlConstraintINTEL: + return "LatencyControlConstraintINTEL"; + case Decoration::ConduitKernelArgumentINTEL: + return "ConduitKernelArgumentINTEL"; + case Decoration::RegisterMapKernelArgumentINTEL: + return "RegisterMapKernelArgumentINTEL"; + case Decoration::MMHostInterfaceAddressWidthINTEL: + return "MMHostInterfaceAddressWidthINTEL"; + case Decoration::MMHostInterfaceDataWidthINTEL: + return "MMHostInterfaceDataWidthINTEL"; + case Decoration::MMHostInterfaceLatencyINTEL: + return "MMHostInterfaceLatencyINTEL"; + case Decoration::MMHostInterfaceReadWriteModeINTEL: + return "MMHostInterfaceReadWriteModeINTEL"; + case Decoration::MMHostInterfaceMaxBurstINTEL: + return "MMHostInterfaceMaxBurstINTEL"; + case Decoration::MMHostInterfaceWaitRequestINTEL: + return "MMHostInterfaceWaitRequestINTEL"; + case Decoration::StableKernelArgumentINTEL: + return "StableKernelArgumentINTEL"; + case Decoration::HostAccessINTEL: + return "HostAccessINTEL"; + case Decoration::InitModeINTEL: + return "InitModeINTEL"; + case Decoration::ImplementInRegisterMapINTEL: + return "ImplementInRegisterMapINTEL"; + case Decoration::CacheControlLoadINTEL: + return "CacheControlLoadINTEL"; + case Decoration::CacheControlStoreINTEL: + return "CacheControlStoreINTEL"; + } + return "unknown"; +} +auto to_string(BuiltIn e) -> char const * { + switch (e) { + case BuiltIn::Position: + return "Position"; + case BuiltIn::PointSize: + return "PointSize"; + case BuiltIn::ClipDistance: + return "ClipDistance"; + case BuiltIn::CullDistance: + return "CullDistance"; + case BuiltIn::VertexId: + return "VertexId"; + case BuiltIn::InstanceId: + return "InstanceId"; + case BuiltIn::PrimitiveId: + return "PrimitiveId"; + case BuiltIn::InvocationId: + return "InvocationId"; + case BuiltIn::Layer: + return "Layer"; + case BuiltIn::ViewportIndex: + return "ViewportIndex"; + case BuiltIn::TessLevelOuter: + return "TessLevelOuter"; + case BuiltIn::TessLevelInner: + return "TessLevelInner"; + case BuiltIn::TessCoord: + return "TessCoord"; + case BuiltIn::PatchVertices: + return "PatchVertices"; + case BuiltIn::FragCoord: + return "FragCoord"; + case BuiltIn::PointCoord: + return "PointCoord"; + case BuiltIn::FrontFacing: + return "FrontFacing"; + case BuiltIn::SampleId: + return "SampleId"; + case BuiltIn::SamplePosition: + return "SamplePosition"; + case BuiltIn::SampleMask: + return "SampleMask"; + case BuiltIn::FragDepth: + return "FragDepth"; + case BuiltIn::HelperInvocation: + return "HelperInvocation"; + case BuiltIn::NumWorkgroups: + return "NumWorkgroups"; + case BuiltIn::WorkgroupSize: + return "WorkgroupSize"; + case BuiltIn::WorkgroupId: + return "WorkgroupId"; + case BuiltIn::LocalInvocationId: + return "LocalInvocationId"; + case BuiltIn::GlobalInvocationId: + return "GlobalInvocationId"; + case BuiltIn::LocalInvocationIndex: + return "LocalInvocationIndex"; + case BuiltIn::WorkDim: + return "WorkDim"; + case BuiltIn::GlobalSize: + return "GlobalSize"; + case BuiltIn::EnqueuedWorkgroupSize: + return "EnqueuedWorkgroupSize"; + case BuiltIn::GlobalOffset: + return "GlobalOffset"; + case BuiltIn::GlobalLinearId: + return "GlobalLinearId"; + case BuiltIn::SubgroupSize: + return "SubgroupSize"; + case BuiltIn::SubgroupMaxSize: + return "SubgroupMaxSize"; + case BuiltIn::NumSubgroups: + return "NumSubgroups"; + case BuiltIn::NumEnqueuedSubgroups: + return "NumEnqueuedSubgroups"; + case BuiltIn::SubgroupId: + return "SubgroupId"; + case BuiltIn::SubgroupLocalInvocationId: + return "SubgroupLocalInvocationId"; + case BuiltIn::VertexIndex: + return "VertexIndex"; + case BuiltIn::InstanceIndex: + return "InstanceIndex"; + case BuiltIn::CoreIDARM: + return "CoreIDARM"; + case BuiltIn::CoreCountARM: + return "CoreCountARM"; + case BuiltIn::CoreMaxIDARM: + return "CoreMaxIDARM"; + case BuiltIn::WarpIDARM: + return "WarpIDARM"; + case BuiltIn::WarpMaxIDARM: + return "WarpMaxIDARM"; + case BuiltIn::SubgroupEqMask: + return "SubgroupEqMask"; + case BuiltIn::SubgroupGeMask: + return "SubgroupGeMask"; + case BuiltIn::SubgroupGtMask: + return "SubgroupGtMask"; + case BuiltIn::SubgroupLeMask: + return "SubgroupLeMask"; + case BuiltIn::SubgroupLtMask: + return "SubgroupLtMask"; + case BuiltIn::BaseVertex: + return "BaseVertex"; + case BuiltIn::BaseInstance: + return "BaseInstance"; + case BuiltIn::DrawIndex: + return "DrawIndex"; + case BuiltIn::PrimitiveShadingRateKHR: + return "PrimitiveShadingRateKHR"; + case BuiltIn::DeviceIndex: + return "DeviceIndex"; + case BuiltIn::ViewIndex: + return "ViewIndex"; + case BuiltIn::ShadingRateKHR: + return "ShadingRateKHR"; + case BuiltIn::BaryCoordNoPerspAMD: + return "BaryCoordNoPerspAMD"; + case BuiltIn::BaryCoordNoPerspCentroidAMD: + return "BaryCoordNoPerspCentroidAMD"; + case BuiltIn::BaryCoordNoPerspSampleAMD: + return "BaryCoordNoPerspSampleAMD"; + case BuiltIn::BaryCoordSmoothAMD: + return "BaryCoordSmoothAMD"; + case BuiltIn::BaryCoordSmoothCentroidAMD: + return "BaryCoordSmoothCentroidAMD"; + case BuiltIn::BaryCoordSmoothSampleAMD: + return "BaryCoordSmoothSampleAMD"; + case BuiltIn::BaryCoordPullModelAMD: + return "BaryCoordPullModelAMD"; + case BuiltIn::FragStencilRefEXT: + return "FragStencilRefEXT"; + case BuiltIn::RemainingRecursionLevelsAMDX: + return "RemainingRecursionLevelsAMDX"; + case BuiltIn::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case BuiltIn::ViewportMaskNV: + return "ViewportMaskNV"; + case BuiltIn::SecondaryPositionNV: + return "SecondaryPositionNV"; + case BuiltIn::SecondaryViewportMaskNV: + return "SecondaryViewportMaskNV"; + case BuiltIn::PositionPerViewNV: + return "PositionPerViewNV"; + case BuiltIn::ViewportMaskPerViewNV: + return "ViewportMaskPerViewNV"; + case BuiltIn::FullyCoveredEXT: + return "FullyCoveredEXT"; + case BuiltIn::TaskCountNV: + return "TaskCountNV"; + case BuiltIn::PrimitiveCountNV: + return "PrimitiveCountNV"; + case BuiltIn::PrimitiveIndicesNV: + return "PrimitiveIndicesNV"; + case BuiltIn::ClipDistancePerViewNV: + return "ClipDistancePerViewNV"; + case BuiltIn::CullDistancePerViewNV: + return "CullDistancePerViewNV"; + case BuiltIn::LayerPerViewNV: + return "LayerPerViewNV"; + case BuiltIn::MeshViewCountNV: + return "MeshViewCountNV"; + case BuiltIn::MeshViewIndicesNV: + return "MeshViewIndicesNV"; + case BuiltIn::BaryCoordKHR: + return "BaryCoordKHR"; + case BuiltIn::BaryCoordNoPerspKHR: + return "BaryCoordNoPerspKHR"; + case BuiltIn::FragSizeEXT: + return "FragSizeEXT"; + case BuiltIn::FragInvocationCountEXT: + return "FragInvocationCountEXT"; + case BuiltIn::PrimitivePointIndicesEXT: + return "PrimitivePointIndicesEXT"; + case BuiltIn::PrimitiveLineIndicesEXT: + return "PrimitiveLineIndicesEXT"; + case BuiltIn::PrimitiveTriangleIndicesEXT: + return "PrimitiveTriangleIndicesEXT"; + case BuiltIn::CullPrimitiveEXT: + return "CullPrimitiveEXT"; + case BuiltIn::LaunchIdKHR: + return "LaunchIdKHR"; + case BuiltIn::LaunchSizeKHR: + return "LaunchSizeKHR"; + case BuiltIn::WorldRayOriginKHR: + return "WorldRayOriginKHR"; + case BuiltIn::WorldRayDirectionKHR: + return "WorldRayDirectionKHR"; + case BuiltIn::ObjectRayOriginKHR: + return "ObjectRayOriginKHR"; + case BuiltIn::ObjectRayDirectionKHR: + return "ObjectRayDirectionKHR"; + case BuiltIn::RayTminKHR: + return "RayTminKHR"; + case BuiltIn::RayTmaxKHR: + return "RayTmaxKHR"; + case BuiltIn::InstanceCustomIndexKHR: + return "InstanceCustomIndexKHR"; + case BuiltIn::ObjectToWorldKHR: + return "ObjectToWorldKHR"; + case BuiltIn::WorldToObjectKHR: + return "WorldToObjectKHR"; + case BuiltIn::HitTNV: + return "HitTNV"; + case BuiltIn::HitKindKHR: + return "HitKindKHR"; + case BuiltIn::CurrentRayTimeNV: + return "CurrentRayTimeNV"; + case BuiltIn::HitTriangleVertexPositionsKHR: + return "HitTriangleVertexPositionsKHR"; + case BuiltIn::HitMicroTriangleVertexPositionsNV: + return "HitMicroTriangleVertexPositionsNV"; + case BuiltIn::HitMicroTriangleVertexBarycentricsNV: + return "HitMicroTriangleVertexBarycentricsNV"; + case BuiltIn::IncomingRayFlagsKHR: + return "IncomingRayFlagsKHR"; + case BuiltIn::RayGeometryIndexKHR: + return "RayGeometryIndexKHR"; + case BuiltIn::WarpsPerSMNV: + return "WarpsPerSMNV"; + case BuiltIn::SMCountNV: + return "SMCountNV"; + case BuiltIn::WarpIDNV: + return "WarpIDNV"; + case BuiltIn::SMIDNV: + return "SMIDNV"; + case BuiltIn::HitKindFrontFacingMicroTriangleNV: + return "HitKindFrontFacingMicroTriangleNV"; + case BuiltIn::HitKindBackFacingMicroTriangleNV: + return "HitKindBackFacingMicroTriangleNV"; + case BuiltIn::CullMaskKHR: + return "CullMaskKHR"; + } + return "unknown"; +} +auto to_string(Scope e) -> char const * { + switch (e) { + case Scope::CrossDevice: + return "CrossDevice"; + case Scope::Device: + return "Device"; + case Scope::Workgroup: + return "Workgroup"; + case Scope::Subgroup: + return "Subgroup"; + case Scope::Invocation: + return "Invocation"; + case Scope::QueueFamily: + return "QueueFamily"; + case Scope::ShaderCallKHR: + return "ShaderCallKHR"; + } + return "unknown"; +} +auto to_string(GroupOperation e) -> char const * { + switch (e) { + case GroupOperation::Reduce: + return "Reduce"; + case GroupOperation::InclusiveScan: + return "InclusiveScan"; + case GroupOperation::ExclusiveScan: + return "ExclusiveScan"; + case GroupOperation::ClusteredReduce: + return "ClusteredReduce"; + case GroupOperation::PartitionedReduceNV: + return "PartitionedReduceNV"; + case GroupOperation::PartitionedInclusiveScanNV: + return "PartitionedInclusiveScanNV"; + case GroupOperation::PartitionedExclusiveScanNV: + return "PartitionedExclusiveScanNV"; + } + return "unknown"; +} +auto to_string(KernelEnqueueFlags e) -> char const * { + switch (e) { + case KernelEnqueueFlags::NoWait: + return "NoWait"; + case KernelEnqueueFlags::WaitKernel: + return "WaitKernel"; + case KernelEnqueueFlags::WaitWorkGroup: + return "WaitWorkGroup"; + } + return "unknown"; +} +auto to_string(Capability e) -> char const * { + switch (e) { + case Capability::Matrix: + return "Matrix"; + case Capability::Shader: + return "Shader"; + case Capability::Geometry: + return "Geometry"; + case Capability::Tessellation: + return "Tessellation"; + case Capability::Addresses: + return "Addresses"; + case Capability::Linkage: + return "Linkage"; + case Capability::Kernel: + return "Kernel"; + case Capability::Vector16: + return "Vector16"; + case Capability::Float16Buffer: + return "Float16Buffer"; + case Capability::Float16: + return "Float16"; + case Capability::Float64: + return "Float64"; + case Capability::Int64: + return "Int64"; + case Capability::Int64Atomics: + return "Int64Atomics"; + case Capability::ImageBasic: + return "ImageBasic"; + case Capability::ImageReadWrite: + return "ImageReadWrite"; + case Capability::ImageMipmap: + return "ImageMipmap"; + case Capability::Pipes: + return "Pipes"; + case Capability::Groups: + return "Groups"; + case Capability::DeviceEnqueue: + return "DeviceEnqueue"; + case Capability::LiteralSampler: + return "LiteralSampler"; + case Capability::AtomicStorage: + return "AtomicStorage"; + case Capability::Int16: + return "Int16"; + case Capability::TessellationPointSize: + return "TessellationPointSize"; + case Capability::GeometryPointSize: + return "GeometryPointSize"; + case Capability::ImageGatherExtended: + return "ImageGatherExtended"; + case Capability::StorageImageMultisample: + return "StorageImageMultisample"; + case Capability::UniformBufferArrayDynamicIndexing: + return "UniformBufferArrayDynamicIndexing"; + case Capability::SampledImageArrayDynamicIndexing: + return "SampledImageArrayDynamicIndexing"; + case Capability::StorageBufferArrayDynamicIndexing: + return "StorageBufferArrayDynamicIndexing"; + case Capability::StorageImageArrayDynamicIndexing: + return "StorageImageArrayDynamicIndexing"; + case Capability::ClipDistance: + return "ClipDistance"; + case Capability::CullDistance: + return "CullDistance"; + case Capability::ImageCubeArray: + return "ImageCubeArray"; + case Capability::SampleRateShading: + return "SampleRateShading"; + case Capability::ImageRect: + return "ImageRect"; + case Capability::SampledRect: + return "SampledRect"; + case Capability::GenericPointer: + return "GenericPointer"; + case Capability::Int8: + return "Int8"; + case Capability::InputAttachment: + return "InputAttachment"; + case Capability::SparseResidency: + return "SparseResidency"; + case Capability::MinLod: + return "MinLod"; + case Capability::Sampled1D: + return "Sampled1D"; + case Capability::Image1D: + return "Image1D"; + case Capability::SampledCubeArray: + return "SampledCubeArray"; + case Capability::SampledBuffer: + return "SampledBuffer"; + case Capability::ImageBuffer: + return "ImageBuffer"; + case Capability::ImageMSArray: + return "ImageMSArray"; + case Capability::StorageImageExtendedFormats: + return "StorageImageExtendedFormats"; + case Capability::ImageQuery: + return "ImageQuery"; + case Capability::DerivativeControl: + return "DerivativeControl"; + case Capability::InterpolationFunction: + return "InterpolationFunction"; + case Capability::TransformFeedback: + return "TransformFeedback"; + case Capability::GeometryStreams: + return "GeometryStreams"; + case Capability::StorageImageReadWithoutFormat: + return "StorageImageReadWithoutFormat"; + case Capability::StorageImageWriteWithoutFormat: + return "StorageImageWriteWithoutFormat"; + case Capability::MultiViewport: + return "MultiViewport"; + case Capability::SubgroupDispatch: + return "SubgroupDispatch"; + case Capability::NamedBarrier: + return "NamedBarrier"; + case Capability::PipeStorage: + return "PipeStorage"; + case Capability::GroupNonUniform: + return "GroupNonUniform"; + case Capability::GroupNonUniformVote: + return "GroupNonUniformVote"; + case Capability::GroupNonUniformArithmetic: + return "GroupNonUniformArithmetic"; + case Capability::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Capability::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Capability::GroupNonUniformShuffleRelative: + return "GroupNonUniformShuffleRelative"; + case Capability::GroupNonUniformClustered: + return "GroupNonUniformClustered"; + case Capability::GroupNonUniformQuad: + return "GroupNonUniformQuad"; + case Capability::ShaderLayer: + return "ShaderLayer"; + case Capability::ShaderViewportIndex: + return "ShaderViewportIndex"; + case Capability::UniformDecoration: + return "UniformDecoration"; + case Capability::CoreBuiltinsARM: + return "CoreBuiltinsARM"; + case Capability::TileImageColorReadAccessEXT: + return "TileImageColorReadAccessEXT"; + case Capability::TileImageDepthReadAccessEXT: + return "TileImageDepthReadAccessEXT"; + case Capability::TileImageStencilReadAccessEXT: + return "TileImageStencilReadAccessEXT"; + case Capability::CooperativeMatrixLayoutsARM: + return "CooperativeMatrixLayoutsARM"; + case Capability::FragmentShadingRateKHR: + return "FragmentShadingRateKHR"; + case Capability::SubgroupBallotKHR: + return "SubgroupBallotKHR"; + case Capability::DrawParameters: + return "DrawParameters"; + case Capability::WorkgroupMemoryExplicitLayoutKHR: + return "WorkgroupMemoryExplicitLayoutKHR"; + case Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR: + return "WorkgroupMemoryExplicitLayout8BitAccessKHR"; + case Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR: + return "WorkgroupMemoryExplicitLayout16BitAccessKHR"; + case Capability::SubgroupVoteKHR: + return "SubgroupVoteKHR"; + case Capability::StorageBuffer16BitAccess: + return "StorageBuffer16BitAccess"; + case Capability::UniformAndStorageBuffer16BitAccess: + return "UniformAndStorageBuffer16BitAccess"; + case Capability::StoragePushConstant16: + return "StoragePushConstant16"; + case Capability::StorageInputOutput16: + return "StorageInputOutput16"; + case Capability::DeviceGroup: + return "DeviceGroup"; + case Capability::MultiView: + return "MultiView"; + case Capability::VariablePointersStorageBuffer: + return "VariablePointersStorageBuffer"; + case Capability::VariablePointers: + return "VariablePointers"; + case Capability::AtomicStorageOps: + return "AtomicStorageOps"; + case Capability::SampleMaskPostDepthCoverage: + return "SampleMaskPostDepthCoverage"; + case Capability::StorageBuffer8BitAccess: + return "StorageBuffer8BitAccess"; + case Capability::UniformAndStorageBuffer8BitAccess: + return "UniformAndStorageBuffer8BitAccess"; + case Capability::StoragePushConstant8: + return "StoragePushConstant8"; + case Capability::DenormPreserve: + return "DenormPreserve"; + case Capability::DenormFlushToZero: + return "DenormFlushToZero"; + case Capability::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case Capability::RoundingModeRTE: + return "RoundingModeRTE"; + case Capability::RoundingModeRTZ: + return "RoundingModeRTZ"; + case Capability::RayQueryProvisionalKHR: + return "RayQueryProvisionalKHR"; + case Capability::RayQueryKHR: + return "RayQueryKHR"; + case Capability::UntypedPointersKHR: + return "UntypedPointersKHR"; + case Capability::RayTraversalPrimitiveCullingKHR: + return "RayTraversalPrimitiveCullingKHR"; + case Capability::RayTracingKHR: + return "RayTracingKHR"; + case Capability::TextureSampleWeightedQCOM: + return "TextureSampleWeightedQCOM"; + case Capability::TextureBoxFilterQCOM: + return "TextureBoxFilterQCOM"; + case Capability::TextureBlockMatchQCOM: + return "TextureBlockMatchQCOM"; + case Capability::TextureBlockMatch2QCOM: + return "TextureBlockMatch2QCOM"; + case Capability::Float16ImageAMD: + return "Float16ImageAMD"; + case Capability::ImageGatherBiasLodAMD: + return "ImageGatherBiasLodAMD"; + case Capability::FragmentMaskAMD: + return "FragmentMaskAMD"; + case Capability::StencilExportEXT: + return "StencilExportEXT"; + case Capability::ImageReadWriteLodAMD: + return "ImageReadWriteLodAMD"; + case Capability::Int64ImageEXT: + return "Int64ImageEXT"; + case Capability::ShaderClockKHR: + return "ShaderClockKHR"; + case Capability::ShaderEnqueueAMDX: + return "ShaderEnqueueAMDX"; + case Capability::QuadControlKHR: + return "QuadControlKHR"; + case Capability::SampleMaskOverrideCoverageNV: + return "SampleMaskOverrideCoverageNV"; + case Capability::GeometryShaderPassthroughNV: + return "GeometryShaderPassthroughNV"; + case Capability::ShaderViewportIndexLayerEXT: + return "ShaderViewportIndexLayerEXT"; + case Capability::ShaderViewportMaskNV: + return "ShaderViewportMaskNV"; + case Capability::ShaderStereoViewNV: + return "ShaderStereoViewNV"; + case Capability::PerViewAttributesNV: + return "PerViewAttributesNV"; + case Capability::FragmentFullyCoveredEXT: + return "FragmentFullyCoveredEXT"; + case Capability::MeshShadingNV: + return "MeshShadingNV"; + case Capability::ImageFootprintNV: + return "ImageFootprintNV"; + case Capability::MeshShadingEXT: + return "MeshShadingEXT"; + case Capability::FragmentBarycentricKHR: + return "FragmentBarycentricKHR"; + case Capability::ComputeDerivativeGroupQuadsKHR: + return "ComputeDerivativeGroupQuadsKHR"; + case Capability::FragmentDensityEXT: + return "FragmentDensityEXT"; + case Capability::GroupNonUniformPartitionedNV: + return "GroupNonUniformPartitionedNV"; + case Capability::ShaderNonUniform: + return "ShaderNonUniform"; + case Capability::RuntimeDescriptorArray: + return "RuntimeDescriptorArray"; + case Capability::InputAttachmentArrayDynamicIndexing: + return "InputAttachmentArrayDynamicIndexing"; + case Capability::UniformTexelBufferArrayDynamicIndexing: + return "UniformTexelBufferArrayDynamicIndexing"; + case Capability::StorageTexelBufferArrayDynamicIndexing: + return "StorageTexelBufferArrayDynamicIndexing"; + case Capability::UniformBufferArrayNonUniformIndexing: + return "UniformBufferArrayNonUniformIndexing"; + case Capability::SampledImageArrayNonUniformIndexing: + return "SampledImageArrayNonUniformIndexing"; + case Capability::StorageBufferArrayNonUniformIndexing: + return "StorageBufferArrayNonUniformIndexing"; + case Capability::StorageImageArrayNonUniformIndexing: + return "StorageImageArrayNonUniformIndexing"; + case Capability::InputAttachmentArrayNonUniformIndexing: + return "InputAttachmentArrayNonUniformIndexing"; + case Capability::UniformTexelBufferArrayNonUniformIndexing: + return "UniformTexelBufferArrayNonUniformIndexing"; + case Capability::StorageTexelBufferArrayNonUniformIndexing: + return "StorageTexelBufferArrayNonUniformIndexing"; + case Capability::RayTracingPositionFetchKHR: + return "RayTracingPositionFetchKHR"; + case Capability::RayTracingNV: + return "RayTracingNV"; + case Capability::RayTracingMotionBlurNV: + return "RayTracingMotionBlurNV"; + case Capability::VulkanMemoryModel: + return "VulkanMemoryModel"; + case Capability::VulkanMemoryModelDeviceScope: + return "VulkanMemoryModelDeviceScope"; + case Capability::PhysicalStorageBufferAddresses: + return "PhysicalStorageBufferAddresses"; + case Capability::ComputeDerivativeGroupLinearKHR: + return "ComputeDerivativeGroupLinearKHR"; + case Capability::RayTracingProvisionalKHR: + return "RayTracingProvisionalKHR"; + case Capability::CooperativeMatrixNV: + return "CooperativeMatrixNV"; + case Capability::FragmentShaderSampleInterlockEXT: + return "FragmentShaderSampleInterlockEXT"; + case Capability::FragmentShaderShadingRateInterlockEXT: + return "FragmentShaderShadingRateInterlockEXT"; + case Capability::ShaderSMBuiltinsNV: + return "ShaderSMBuiltinsNV"; + case Capability::FragmentShaderPixelInterlockEXT: + return "FragmentShaderPixelInterlockEXT"; + case Capability::DemoteToHelperInvocation: + return "DemoteToHelperInvocation"; + case Capability::DisplacementMicromapNV: + return "DisplacementMicromapNV"; + case Capability::RayTracingOpacityMicromapEXT: + return "RayTracingOpacityMicromapEXT"; + case Capability::ShaderInvocationReorderNV: + return "ShaderInvocationReorderNV"; + case Capability::BindlessTextureNV: + return "BindlessTextureNV"; + case Capability::RayQueryPositionFetchKHR: + return "RayQueryPositionFetchKHR"; + case Capability::AtomicFloat16VectorNV: + return "AtomicFloat16VectorNV"; + case Capability::RayTracingDisplacementMicromapNV: + return "RayTracingDisplacementMicromapNV"; + case Capability::RawAccessChainsNV: + return "RawAccessChainsNV"; + case Capability::CooperativeMatrixReductionsNV: + return "CooperativeMatrixReductionsNV"; + case Capability::CooperativeMatrixConversionsNV: + return "CooperativeMatrixConversionsNV"; + case Capability::CooperativeMatrixPerElementOperationsNV: + return "CooperativeMatrixPerElementOperationsNV"; + case Capability::CooperativeMatrixTensorAddressingNV: + return "CooperativeMatrixTensorAddressingNV"; + case Capability::CooperativeMatrixBlockLoadsNV: + return "CooperativeMatrixBlockLoadsNV"; + case Capability::TensorAddressingNV: + return "TensorAddressingNV"; + case Capability::SubgroupShuffleINTEL: + return "SubgroupShuffleINTEL"; + case Capability::SubgroupBufferBlockIOINTEL: + return "SubgroupBufferBlockIOINTEL"; + case Capability::SubgroupImageBlockIOINTEL: + return "SubgroupImageBlockIOINTEL"; + case Capability::SubgroupImageMediaBlockIOINTEL: + return "SubgroupImageMediaBlockIOINTEL"; + case Capability::RoundToInfinityINTEL: + return "RoundToInfinityINTEL"; + case Capability::FloatingPointModeINTEL: + return "FloatingPointModeINTEL"; + case Capability::IntegerFunctions2INTEL: + return "IntegerFunctions2INTEL"; + case Capability::FunctionPointersINTEL: + return "FunctionPointersINTEL"; + case Capability::IndirectReferencesINTEL: + return "IndirectReferencesINTEL"; + case Capability::AsmINTEL: + return "AsmINTEL"; + case Capability::AtomicFloat32MinMaxEXT: + return "AtomicFloat32MinMaxEXT"; + case Capability::AtomicFloat64MinMaxEXT: + return "AtomicFloat64MinMaxEXT"; + case Capability::AtomicFloat16MinMaxEXT: + return "AtomicFloat16MinMaxEXT"; + case Capability::VectorComputeINTEL: + return "VectorComputeINTEL"; + case Capability::VectorAnyINTEL: + return "VectorAnyINTEL"; + case Capability::ExpectAssumeKHR: + return "ExpectAssumeKHR"; + case Capability::SubgroupAvcMotionEstimationINTEL: + return "SubgroupAvcMotionEstimationINTEL"; + case Capability::SubgroupAvcMotionEstimationIntraINTEL: + return "SubgroupAvcMotionEstimationIntraINTEL"; + case Capability::SubgroupAvcMotionEstimationChromaINTEL: + return "SubgroupAvcMotionEstimationChromaINTEL"; + case Capability::VariableLengthArrayINTEL: + return "VariableLengthArrayINTEL"; + case Capability::FunctionFloatControlINTEL: + return "FunctionFloatControlINTEL"; + case Capability::FPGAMemoryAttributesINTEL: + return "FPGAMemoryAttributesINTEL"; + case Capability::FPFastMathModeINTEL: + return "FPFastMathModeINTEL"; + case Capability::ArbitraryPrecisionIntegersINTEL: + return "ArbitraryPrecisionIntegersINTEL"; + case Capability::ArbitraryPrecisionFloatingPointINTEL: + return "ArbitraryPrecisionFloatingPointINTEL"; + case Capability::UnstructuredLoopControlsINTEL: + return "UnstructuredLoopControlsINTEL"; + case Capability::FPGALoopControlsINTEL: + return "FPGALoopControlsINTEL"; + case Capability::KernelAttributesINTEL: + return "KernelAttributesINTEL"; + case Capability::FPGAKernelAttributesINTEL: + return "FPGAKernelAttributesINTEL"; + case Capability::FPGAMemoryAccessesINTEL: + return "FPGAMemoryAccessesINTEL"; + case Capability::FPGAClusterAttributesINTEL: + return "FPGAClusterAttributesINTEL"; + case Capability::LoopFuseINTEL: + return "LoopFuseINTEL"; + case Capability::FPGADSPControlINTEL: + return "FPGADSPControlINTEL"; + case Capability::MemoryAccessAliasingINTEL: + return "MemoryAccessAliasingINTEL"; + case Capability::FPGAInvocationPipeliningAttributesINTEL: + return "FPGAInvocationPipeliningAttributesINTEL"; + case Capability::FPGABufferLocationINTEL: + return "FPGABufferLocationINTEL"; + case Capability::ArbitraryPrecisionFixedPointINTEL: + return "ArbitraryPrecisionFixedPointINTEL"; + case Capability::USMStorageClassesINTEL: + return "USMStorageClassesINTEL"; + case Capability::RuntimeAlignedAttributeINTEL: + return "RuntimeAlignedAttributeINTEL"; + case Capability::IOPipesINTEL: + return "IOPipesINTEL"; + case Capability::BlockingPipesINTEL: + return "BlockingPipesINTEL"; + case Capability::FPGARegINTEL: + return "FPGARegINTEL"; + case Capability::DotProductInputAll: + return "DotProductInputAll"; + case Capability::DotProductInput4x8Bit: + return "DotProductInput4x8Bit"; + case Capability::DotProductInput4x8BitPacked: + return "DotProductInput4x8BitPacked"; + case Capability::DotProduct: + return "DotProduct"; + case Capability::RayCullMaskKHR: + return "RayCullMaskKHR"; + case Capability::CooperativeMatrixKHR: + return "CooperativeMatrixKHR"; + case Capability::ReplicatedCompositesEXT: + return "ReplicatedCompositesEXT"; + case Capability::BitInstructions: + return "BitInstructions"; + case Capability::GroupNonUniformRotateKHR: + return "GroupNonUniformRotateKHR"; + case Capability::FloatControls2: + return "FloatControls2"; + case Capability::AtomicFloat32AddEXT: + return "AtomicFloat32AddEXT"; + case Capability::AtomicFloat64AddEXT: + return "AtomicFloat64AddEXT"; + case Capability::LongCompositesINTEL: + return "LongCompositesINTEL"; + case Capability::OptNoneEXT: + return "OptNoneEXT"; + case Capability::AtomicFloat16AddEXT: + return "AtomicFloat16AddEXT"; + case Capability::DebugInfoModuleINTEL: + return "DebugInfoModuleINTEL"; + case Capability::BFloat16ConversionINTEL: + return "BFloat16ConversionINTEL"; + case Capability::SplitBarrierINTEL: + return "SplitBarrierINTEL"; + case Capability::ArithmeticFenceEXT: + return "ArithmeticFenceEXT"; + case Capability::FPGAClusterAttributesV2INTEL: + return "FPGAClusterAttributesV2INTEL"; + case Capability::FPGAKernelAttributesv2INTEL: + return "FPGAKernelAttributesv2INTEL"; + case Capability::FPMaxErrorINTEL: + return "FPMaxErrorINTEL"; + case Capability::FPGALatencyControlINTEL: + return "FPGALatencyControlINTEL"; + case Capability::FPGAArgumentInterfacesINTEL: + return "FPGAArgumentInterfacesINTEL"; + case Capability::GlobalVariableHostAccessINTEL: + return "GlobalVariableHostAccessINTEL"; + case Capability::GlobalVariableFPGADecorationsINTEL: + return "GlobalVariableFPGADecorationsINTEL"; + case Capability::SubgroupBufferPrefetchINTEL: + return "SubgroupBufferPrefetchINTEL"; + case Capability::GroupUniformArithmeticKHR: + return "GroupUniformArithmeticKHR"; + case Capability::MaskedGatherScatterINTEL: + return "MaskedGatherScatterINTEL"; + case Capability::CacheControlsINTEL: + return "CacheControlsINTEL"; + case Capability::RegisterLimitsINTEL: + return "RegisterLimitsINTEL"; + } + return "unknown"; +} +auto to_string(RayQueryIntersection e) -> char const * { + switch (e) { + case RayQueryIntersection::RayQueryCandidateIntersectionKHR: + return "RayQueryCandidateIntersectionKHR"; + case RayQueryIntersection::RayQueryCommittedIntersectionKHR: + return "RayQueryCommittedIntersectionKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCommittedIntersectionType e) -> char const * { + switch (e) { + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR: + return "RayQueryCommittedIntersectionNoneKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR: + return "RayQueryCommittedIntersectionTriangleKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionGeneratedKHR: + return "RayQueryCommittedIntersectionGeneratedKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCandidateIntersectionType e) -> char const * { + switch (e) { + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR: + return "RayQueryCandidateIntersectionTriangleKHR"; + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR: + return "RayQueryCandidateIntersectionAABBKHR"; + } + return "unknown"; +} +auto to_string(PackedVectorFormat e) -> char const * { + switch (e) { + case PackedVectorFormat::PackedVectorFormat4x8Bit: + return "PackedVectorFormat4x8Bit"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixOperands e) -> char const * { + switch (e) { + case CooperativeMatrixOperands::NoneKHR: + return "NoneKHR"; + case CooperativeMatrixOperands::MatrixASignedComponentsKHR: + return "MatrixASignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixBSignedComponentsKHR: + return "MatrixBSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixCSignedComponentsKHR: + return "MatrixCSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixResultSignedComponentsKHR: + return "MatrixResultSignedComponentsKHR"; + case CooperativeMatrixOperands::SaturatingAccumulationKHR: + return "SaturatingAccumulationKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixLayout e) -> char const * { + switch (e) { + case CooperativeMatrixLayout::RowMajorKHR: + return "RowMajorKHR"; + case CooperativeMatrixLayout::ColumnMajorKHR: + return "ColumnMajorKHR"; + case CooperativeMatrixLayout::RowBlockedInterleavedARM: + return "RowBlockedInterleavedARM"; + case CooperativeMatrixLayout::ColumnBlockedInterleavedARM: + return "ColumnBlockedInterleavedARM"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixUse e) -> char const * { + switch (e) { + case CooperativeMatrixUse::MatrixAKHR: + return "MatrixAKHR"; + case CooperativeMatrixUse::MatrixBKHR: + return "MatrixBKHR"; + case CooperativeMatrixUse::MatrixAccumulatorKHR: + return "MatrixAccumulatorKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixReduce e) -> char const * { + switch (e) { + case CooperativeMatrixReduce::Row: + return "Row"; + case CooperativeMatrixReduce::Column: + return "Column"; + case CooperativeMatrixReduce::CooperativeMatrixReduce2x2: + return "CooperativeMatrixReduce2x2"; + } + return "unknown"; +} +auto to_string(TensorClampMode e) -> char const * { + switch (e) { + case TensorClampMode::Undefined: + return "Undefined"; + case TensorClampMode::Constant: + return "Constant"; + case TensorClampMode::ClampToEdge: + return "ClampToEdge"; + case TensorClampMode::Repeat: + return "Repeat"; + case TensorClampMode::RepeatMirrored: + return "RepeatMirrored"; + } + return "unknown"; +} +auto to_string(TensorAddressingOperands e) -> char const * { + switch (e) { + case TensorAddressingOperands::None: + return "None"; + case TensorAddressingOperands::TensorView: + return "TensorView"; + case TensorAddressingOperands::DecodeFunc: + return "DecodeFunc"; + } + return "unknown"; +} +auto to_string(InitializationModeQualifier e) -> char const * { + switch (e) { + case InitializationModeQualifier::InitOnDeviceReprogramINTEL: + return "InitOnDeviceReprogramINTEL"; + case InitializationModeQualifier::InitOnDeviceResetINTEL: + return "InitOnDeviceResetINTEL"; + } + return "unknown"; +} +auto to_string(LoadCacheControl e) -> char const * { + switch (e) { + case LoadCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case LoadCacheControl::CachedINTEL: + return "CachedINTEL"; + case LoadCacheControl::StreamingINTEL: + return "StreamingINTEL"; + case LoadCacheControl::InvalidateAfterReadINTEL: + return "InvalidateAfterReadINTEL"; + case LoadCacheControl::ConstCachedINTEL: + return "ConstCachedINTEL"; + } + return "unknown"; +} +auto to_string(StoreCacheControl e) -> char const * { + switch (e) { + case StoreCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case StoreCacheControl::WriteThroughINTEL: + return "WriteThroughINTEL"; + case StoreCacheControl::WriteBackINTEL: + return "WriteBackINTEL"; + case StoreCacheControl::StreamingINTEL: + return "StreamingINTEL"; + } + return "unknown"; +} +auto to_string(NamedMaximumNumberOfRegisters e) -> char const * { + switch (e) { + case NamedMaximumNumberOfRegisters::AutoINTEL: + return "AutoINTEL"; + } + return "unknown"; +} +auto to_string(FPEncoding e) -> char const * { + switch (e) {} + return "unknown"; +} + +} // namespace tinytc::spv + +#endif // GENERATED_NAMES_2024114_HPP diff --git a/src/spv/names.hpp b/src/spv/names.hpp new file mode 100644 index 00000000..8bc5f3b9 --- /dev/null +++ b/src/spv/names.hpp @@ -0,0 +1,69 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_NAMES_2024114_HPP +#define GENERATED_NAMES_2024114_HPP + +namespace tinytc::spv { + +auto to_string(Op op) -> char const *; +auto to_string(ImageOperands e) -> char const *; +auto to_string(FPFastMathMode e) -> char const *; +auto to_string(SelectionControl e) -> char const *; +auto to_string(LoopControl e) -> char const *; +auto to_string(FunctionControl e) -> char const *; +auto to_string(MemorySemantics e) -> char const *; +auto to_string(MemoryAccess e) -> char const *; +auto to_string(KernelProfilingInfo e) -> char const *; +auto to_string(RayFlags e) -> char const *; +auto to_string(FragmentShadingRate e) -> char const *; +auto to_string(RawAccessChainOperands e) -> char const *; +auto to_string(SourceLanguage e) -> char const *; +auto to_string(ExecutionModel e) -> char const *; +auto to_string(AddressingModel e) -> char const *; +auto to_string(MemoryModel e) -> char const *; +auto to_string(ExecutionMode e) -> char const *; +auto to_string(StorageClass e) -> char const *; +auto to_string(Dim e) -> char const *; +auto to_string(SamplerAddressingMode e) -> char const *; +auto to_string(SamplerFilterMode e) -> char const *; +auto to_string(ImageFormat e) -> char const *; +auto to_string(ImageChannelOrder e) -> char const *; +auto to_string(ImageChannelDataType e) -> char const *; +auto to_string(FPRoundingMode e) -> char const *; +auto to_string(FPDenormMode e) -> char const *; +auto to_string(QuantizationModes e) -> char const *; +auto to_string(FPOperationMode e) -> char const *; +auto to_string(OverflowModes e) -> char const *; +auto to_string(LinkageType e) -> char const *; +auto to_string(AccessQualifier e) -> char const *; +auto to_string(HostAccessQualifier e) -> char const *; +auto to_string(FunctionParameterAttribute e) -> char const *; +auto to_string(Decoration e) -> char const *; +auto to_string(BuiltIn e) -> char const *; +auto to_string(Scope e) -> char const *; +auto to_string(GroupOperation e) -> char const *; +auto to_string(KernelEnqueueFlags e) -> char const *; +auto to_string(Capability e) -> char const *; +auto to_string(RayQueryIntersection e) -> char const *; +auto to_string(RayQueryCommittedIntersectionType e) -> char const *; +auto to_string(RayQueryCandidateIntersectionType e) -> char const *; +auto to_string(PackedVectorFormat e) -> char const *; +auto to_string(CooperativeMatrixOperands e) -> char const *; +auto to_string(CooperativeMatrixLayout e) -> char const *; +auto to_string(CooperativeMatrixUse e) -> char const *; +auto to_string(CooperativeMatrixReduce e) -> char const *; +auto to_string(TensorClampMode e) -> char const *; +auto to_string(TensorAddressingOperands e) -> char const *; +auto to_string(InitializationModeQualifier e) -> char const *; +auto to_string(LoadCacheControl e) -> char const *; +auto to_string(StoreCacheControl e) -> char const *; +auto to_string(NamedMaximumNumberOfRegisters e) -> char const *; +auto to_string(FPEncoding e) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_NAMES_2024114_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp new file mode 100644 index 00000000..323dea09 --- /dev/null +++ b/src/spv/pass/dump_asm.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/dump_asm.hpp" +#include "spv/module.hpp" +#include "support/casting.hpp" + +#include +#include + +namespace tinytc::spv { + +dump_asm_pass::dump_asm_pass(std::ostream &os) : os_(&os) {} + +auto dump_asm_pass::declare(spv_inst const *in) -> std::int64_t { + auto s = slot_map_.find(in); + if (s == slot_map_.end()) { + const auto slot = slot_++; + slot_map_[in] = slot; + return slot; + } + return s->second; +} + +void dump_asm_pass::pre_visit(spv_inst const &in) { + auto const num_digits = [](std::int64_t number) { + std::int64_t d = 1; + while (number /= 10) { + ++d; + } + return d; + }; + *os_ << std::endl; + if (in.has_result_id()) { + const auto slot = declare(&in); + + for (int i = 0; i < rhs_indent - 4 - num_digits(slot); ++i) { + *os_ << ' '; + } + *os_ << "%" << slot << " = "; + } else { + for (int i = 0; i < rhs_indent; ++i) { + *os_ << ' '; + } + } + *os_ << "Op" << to_string(in.opcode()); +} + +void dump_asm_pass::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](std::pair const &a) { + *os_ << " \"" << a.first << '"'; + this->operator()(a.second); + }}, + da); +} +void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { + std::visit(overloaded{[&](std::int32_t const &a) { *os_ << " " << a; }, + [&](std::array const &a) { + for (auto const &s : a) { + *os_ << " " << s; + } + }}, + ea); +} +void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](auto const &l) { *os_ << " " << l; }}, l); +} +void dump_asm_pass::operator()(LiteralInteger const &l) { *os_ << " " << l; } +void dump_asm_pass::operator()(LiteralString const &l) { *os_ << " \"" << l << '"'; } + +void dump_asm_pass::operator()(spv_inst *const &in) { + if (auto s = slot_map_.find(in); s != slot_map_.end()) { + *os_ << " %" << s->second; + } else if (isa(*in)) { + *os_ << " %" << declare(in); + } else { + throw status::spirv_forbidden_forward_declaration; + } +} + +void dump_asm_pass::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { *os_ << " " << l; }}, p.first); + this->operator()(p.second); +} + +void dump_asm_pass::run_on_module(mod const &m) { + auto const visit_section = [&](section s) { + for (auto const &i : m.insts(s)) { + visit(*this, i); + } + }; + visit_section(section::capability); + visit_section(section::memory_model); + visit_section(section::entry_point); + visit_section(section::execution_mode); + visit_section(section::decoration); + visit_section(section::type); + visit_section(section::function); + *os_ << std::endl; +} + +} // namespace tinytc::spv diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp new file mode 100644 index 00000000..238ea0f8 --- /dev/null +++ b/src/spv/pass/dump_asm.hpp @@ -0,0 +1,57 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_ASM_20241029_HPP +#define DUMP_ASM_20241029_HPP + +#include "spv/instructions.hpp" +#include "spv/names.hpp" +#include "spv/visit.hpp" + +#include +#include +#include + +namespace tinytc::spv { + +class mod; + +class dump_asm_pass : public default_visitor { + public: + using default_visitor::operator(); + constexpr static int rhs_indent = 15; + + dump_asm_pass(std::ostream &os); + + void pre_visit(spv_inst const &in); + + template + requires requires(T const &e) { to_string(e); } + void operator()(T const &e) { + *os_ << " " << to_string(e); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(spv_inst *const &in); + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void run_on_module(mod const &m); + + private: + auto declare(spv_inst const *in) -> std::int64_t; + + std::ostream *os_; + + std::int64_t slot_ = 0; + std::unordered_map slot_map_; +}; + +} // namespace tinytc::spv + +#endif // DUMP_ASM_20241029_HPP diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp new file mode 100644 index 00000000..0e3c9c62 --- /dev/null +++ b/src/spv/visit.hpp @@ -0,0 +1,2915 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_VISIT_2024114_HPP +#define GENERATED_VISIT_2024114_HPP + +namespace tinytc::spv { + +template struct overloaded : Ts... { + using Ts::operator()...; +}; +template overloaded(Ts...) -> overloaded; + +template auto visit(Visitor &&visitor, spv_inst const &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); + case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); + case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} +template class default_visitor { + public: + auto pre_visit(spv_inst const &) {} + auto operator()(OpNop const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpUndef const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpSourceContinued const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSource const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpSourceExtension const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpName const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpMemberName const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpString const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpLine const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpExtension const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpExtInstImport const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpExtInst const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpMemoryModel const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpEntryPoint const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto const &op : in.op3()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpExecutionMode const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpCapability const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypeVoid const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeBool const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeInt const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeFloat const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + } + auto operator()(OpTypeVector const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeMatrix const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeImage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + } + auto operator()(OpTypeSampler const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeSampledImage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypeArray const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeRuntimeArray const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypeStruct const &in) { + static_cast(this)->pre_visit(in); + for (auto const &op : in.op0()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpTypeOpaque const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypePointer const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeFunction const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpTypeEvent const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeDeviceEvent const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeReserveId const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypeQueue const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpTypePipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypeForwardPointer const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpConstantTrue const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpConstantFalse const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpConstant const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConstantComposite const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + for (auto const &op : in.op0()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpConstantSampler const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpConstantNull const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpFunction const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFunctionParameter const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpFunctionEnd const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpFunctionCall const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpVariable const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + } + auto operator()(OpImageTexelPointer const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpLoad const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + } + auto operator()(OpStore const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpCopyMemory const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpCopyMemorySized const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + } + auto operator()(OpAccessChain const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpInBoundsAccessChain const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpPtrAccessChain const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpArrayLength const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGenericPtrMemSemantics const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpInBoundsPtrAccessChain const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpDecorate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpMemberDecorate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpDecorationGroup const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpGroupDecorate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpGroupMemberDecorate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpVectorExtractDynamic const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpVectorInsertDynamic const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpVectorShuffle const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpCompositeConstruct const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + for (auto const &op : in.op0()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpCompositeExtract const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + for (auto const &op : in.op1()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpCompositeInsert const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpCopyObject const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTranspose const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSampledImage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpImageSampleImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageSampleExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSampleDrefImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSampleDrefExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpImageSampleProjImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageSampleProjExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSampleProjDrefImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSampleProjDrefExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpImageFetch const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageGather const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageDrefGather const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageRead const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageWrite const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpImageQueryFormat const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpImageQueryOrder const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpImageQuerySizeLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpImageQuerySize const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpImageQueryLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpImageQueryLevels const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpImageQuerySamples const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertFToU const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertFToS const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertSToF const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertUToF const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpUConvert const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSConvert const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpFConvert const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpQuantizeToF16 const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertPtrToU const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSatConvertSToU const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSatConvertUToS const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpConvertUToPtr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpPtrCastToGeneric const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpGenericCastToPtr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpGenericCastToPtrExplicit const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpBitcast const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSNegate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpFNegate const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpIAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpISub const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFSub const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpIMul const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFMul const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUDiv const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSDiv const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFDiv const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUMod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSRem const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSMod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFRem const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFMod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpVectorTimesScalar const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpMatrixTimesScalar const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpVectorTimesMatrix const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpMatrixTimesVector const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpMatrixTimesMatrix const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpOuterProduct const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpDot const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpIAddCarry const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpISubBorrow const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUMulExtended const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSMulExtended const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpAny const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpAll const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpIsNan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpIsInf const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpIsFinite const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpIsNormal const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSignBitSet const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpLessOrGreater const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpOrdered const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUnordered const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLogicalEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLogicalNotEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLogicalOr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLogicalAnd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLogicalNot const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSelect const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpIEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpINotEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUGreaterThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSGreaterThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpUGreaterThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSGreaterThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpULessThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSLessThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpULessThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpSLessThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdNotEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordNotEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdLessThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordLessThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdGreaterThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordGreaterThan const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdLessThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordLessThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFOrdGreaterThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpFUnordGreaterThanEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpShiftRightLogical const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpShiftRightArithmetic const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpShiftLeftLogical const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpBitwiseOr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpBitwiseXor const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpBitwiseAnd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpNot const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpBitFieldInsert const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpBitFieldSExtract const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpBitFieldUExtract const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpBitReverse const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpBitCount const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdx const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdy const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpFwidth const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdxFine const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdyFine const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpFwidthFine const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdxCoarse const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpDPdyCoarse const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpFwidthCoarse const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpEmitVertex const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpEndPrimitive const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpEmitStreamVertex const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpEndStreamPrimitive const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpControlBarrier const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpMemoryBarrier const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpAtomicLoad const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpAtomicStore const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicExchange const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicCompareExchange const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + } + auto operator()(OpAtomicCompareExchangeWeak const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + } + auto operator()(OpAtomicIIncrement const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpAtomicIDecrement const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpAtomicIAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicISub const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicSMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicUMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicSMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicUMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicAnd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicOr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicXor const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpPhi const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + for (auto const &op : in.op0()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpLoopMerge const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpSelectionMerge const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLabel const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpBranch const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpBranchConditional const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto const &op : in.op3()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpSwitch const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto const &op : in.op2()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpKill const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpReturn const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpReturnValue const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpUnreachable const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpLifetimeStart const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpLifetimeStop const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupAsyncCopy const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + } + auto operator()(OpGroupWaitEvents const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupAll const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupAny const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupBroadcast const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupIAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupFAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupFMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupUMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupSMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupFMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupUMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupSMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpReadPipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpWritePipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpReservedReadPipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + } + auto operator()(OpReservedWritePipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + } + auto operator()(OpReserveReadPipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpReserveWritePipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpCommitReadPipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpCommitWritePipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpIsValidReserveId const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpGetNumPipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGetMaxPipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupReserveReadPipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGroupReserveWritePipePackets const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGroupCommitReadPipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGroupCommitWritePipe const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpEnqueueMarker const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpEnqueueKernel const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + static_cast(this)->operator()(in.op7()); + static_cast(this)->operator()(in.op8()); + static_cast(this)->operator()(in.op9()); + for (auto const &op : in.op10()) { + static_cast(this)->operator()(op); + } + } + auto operator()(OpGetKernelNDrangeSubGroupCount const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGetKernelNDrangeMaxSubGroupSize const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGetKernelWorkGroupSize const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpGetKernelPreferredWorkGroupSizeMultiple const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpRetainEvent const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpReleaseEvent const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpCreateUserEvent const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpIsValidEvent const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpSetUserEventStatus const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpCaptureEventProfilingInfo const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGetDefaultQueue const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + } + auto operator()(OpBuildNDRange const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSparseSampleImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageSparseSampleExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSparseSampleDrefImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSparseSampleDrefExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpImageSparseSampleProjImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageSparseSampleProjExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSparseSampleProjDrefImplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSparseSampleProjDrefExplicitLod const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpImageSparseFetch const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpImageSparseGather const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSparseDrefGather const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpImageSparseTexelsResident const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpNoLine const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpAtomicFlagTestAndSet const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpAtomicFlagClear const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpImageSparseRead const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + } + auto operator()(OpSizeOf const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpTypePipeStorage const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpConstantPipeStorage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpCreatePipeFromPipeStorage const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpGetKernelLocalSizeForSubgroupCount const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpGetKernelMaxNumSubgroups const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpTypeNamedBarrier const &in) { static_cast(this)->pre_visit(in); } + auto operator()(OpNamedBarrierInitialize const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpMemoryNamedBarrier const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpModuleProcessed const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpExecutionModeId const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpDecorateId const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformElect const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpGroupNonUniformAll const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformAny const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformAllEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformBroadcast const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformBroadcastFirst const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformBallot const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformInverseBallot const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformBallotBitExtract const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformBallotBitCount const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformBallotFindLSB const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformBallotFindMSB const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpGroupNonUniformShuffle const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformShuffleXor const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformShuffleUp const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformShuffleDown const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformIAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformFAdd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformIMul const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformFMul const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformSMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformUMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformFMin const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformSMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformUMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformFMax const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformBitwiseAnd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformBitwiseOr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformBitwiseXor const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformLogicalAnd const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformLogicalOr const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformLogicalXor const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpGroupNonUniformQuadBroadcast const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpGroupNonUniformQuadSwap const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + } + auto operator()(OpCopyLogical const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } + auto operator()(OpPtrEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpPtrNotEqual const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpPtrDiff const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + } + auto operator()(OpTypeCooperativeMatrixKHR const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + } + auto operator()(OpCooperativeMatrixLoadKHR const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpCooperativeMatrixStoreKHR const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + } + auto operator()(OpCooperativeMatrixMulAddKHR const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + } + auto operator()(OpCooperativeMatrixLengthKHR const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + } +}; + +} // namespace tinytc::spv + +#endif // GENERATED_VISIT_2024114_HPP diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 75b5af72..360dcc6f 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -3,10 +3,12 @@ #include "argparser.hpp" #include "argparser_common.hpp" +#include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" #include +#include #include #include #include @@ -14,14 +16,31 @@ using namespace tinytc; +enum class generator { opencl, spirv }; + int main(int argc, char **argv) { char const *filename = nullptr; auto info = core_info{}; tinytc_core_feature_flags_t core_features = 0; std::int32_t opt_level = 2; auto flags = cmd::optflag_states{}; + auto gen = generator::opencl; bool help = false; + auto const convert_string_to_generator = [](char const *str, generator &val) { + switch (fnv1a(str, std::strlen(str))) { + case "opencl"_fnv1a: + val = generator::opencl; + break; + case "spirv"_fnv1a: + val = generator::spirv; + break; + default: + return cmd::parser_status::invalid_argument; + }; + return cmd::parser_status::success; + }; + auto parser = cmd::arg_parser{}; try { info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); @@ -38,6 +57,8 @@ int main(int argc, char **argv) { } return cmd::parser_status::success; }); + parser.set_short_opt('g', &gen, "Code generation backend (opencl or spirv)") + .converter(convert_string_to_generator); parser.set_short_opt('h', &help, "Show help"); parser.set_long_opt("help", &help, "Show help"); parser.add_positional_arg("file-name", &filename, @@ -78,8 +99,14 @@ int main(int argc, char **argv) { p = parse_file(filename, ctx); } - auto src = compile_to_opencl(std::move(p), info); - std::cout << src.get_code(); + switch (gen) { + case generator::opencl: + std::cout << compile_to_opencl(std::move(p), info).get_code(); + break; + case generator::spirv: + compile_to_spirv(std::move(p), info); + break; + } } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return 1; diff --git a/tools/spirvgen/filter.json b/tools/spirvgen/filter.json new file mode 100644 index 00000000..78f52aef --- /dev/null +++ b/tools/spirvgen/filter.json @@ -0,0 +1,11 @@ +{ + "copyright" : [ + "Copyright (C) 2024 Intel Corporation", + "SPDX-License-Identifier: BSD-3-Clause" + ], + "include" : [ + [0, 47], + [53, 999], + [4456, 4460] + ] +} diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py new file mode 100755 index 00000000..160aeb21 --- /dev/null +++ b/tools/spirvgen/spirvgen.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# +# Very simple and stupid script to generate SPIR-V classes +# + +import argparse +import datetime +import json +import os +import shutil +import subprocess + +spv_enums = 'enums.hpp' +spv_names = 'names.hpp' +spv_names_cpp = 'names.cpp' +spv_names_cpp_includes = [spv_names, spv_enums] +spv_ops = 'instructions.hpp' +spv_visitor = 'visit.hpp' +spv_ops_includes = [ + spv_enums, 'error.hpp', 'support/ilist_base.hpp', None, '', + '', '', '', '', '', '' +] + +enumerant_subs = { + '1D': 'Dim1D', + '2D': 'Dim2D', + '3D': 'Dim3D', + '2x2': 'CooperativeMatrixReduce2x2' +} + +spv_inst_class = """ +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) : opcode_{opcode}, has_result_id_{has_result_id} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + + inline auto opcode() const -> Op { return opcode_; } + inline auto has_result_id() const -> bool { return has_result_id_; } + + private: + Op opcode_; + bool has_result_id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber + = std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst*; +using IdRef = spv_inst*; +using IdScope = spv_inst*; +using IdMemorySemantics = spv_inst*; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef + = std::pair, spv_inst*>; +using PairIdRefLiteralInteger = std::pair; +""" + + +def get_opcode_name(instruction): + return instruction['opname'][2:] + + +def get_class_name(instruction): + return instruction['opname'] + + +def generate_enums(f, grammar): + print('enum class Op {', file=f) + for inst in grammar['instructions']: + print(f'{get_opcode_name(inst)} = {inst["opcode"]},', file=f) + print('};', file=f) + + for opkind in grammar['operand_kinds']: + category = opkind['category'] + if category != 'BitEnum' and category != 'ValueEnum': + continue + print(f'enum class {opkind["kind"]} {{', file=f) + for enumerant in opkind['enumerants']: + name = enumerant["enumerant"] + print(f'{enumerant_subs.get(name, name)} = {enumerant["value"]},', + file=f) + print('};', file=f) + + +def generate_names(f, grammar): + print('auto to_string(Op op) -> char const*;', file=f) + + for opkind in grammar['operand_kinds']: + category = opkind['category'] + if category != 'BitEnum' and category != 'ValueEnum': + continue + print(f'auto to_string({opkind["kind"]} e) -> char const*;', file=f) + + +def generate_names_cpp(f, grammar): + print('auto to_string(Op op) -> char const* { switch(op) {', file=f) + for inst in grammar['instructions']: + name = get_opcode_name(inst) + print(f'case Op::{name}: return "{name}";', file=f) + print('} return "unknown";}', file=f) + + for opkind in grammar['operand_kinds']: + category = opkind['category'] + if category != 'BitEnum' and category != 'ValueEnum': + continue + print( + f'auto to_string({opkind["kind"]} e) -> char const* {{ switch(e) {{', + file=f) + for enumerant in opkind['enumerants']: + name = enumerant["enumerant"] + name = enumerant_subs.get(name, name) + print(f'case {opkind["kind"]}::{name}: return "{name}";', file=f) + print('} return "unknown";}', file=f) + + +def get_kind(operand): + kind = operand['kind'] + quant = operand.get('quantifier') + if quant: + if quant == '?': + return f'std::optional<{kind}>' + elif quant == '*': + return f'std::vector<{kind}>' + else: + raise NotImplementedError + return kind + + +def has_result_id(instruction): + for operand in instruction.get('operands', []): + if operand['kind'] == 'IdResult': + return True + return False + + +class Operand: + + def __init__(self, name, kind, quantifier): + self.name = name + self.kind = kind + self.quantifier = quantifier + + +def get_operands(instruction): + operands = [] + opno = 0 + for num, operand in enumerate(instruction.get('operands', [])): + if operand['kind'] == 'IdResult': + pass + elif operand['kind'] == 'IdResultType': + operands.append(Operand('type', get_kind(operand), '')) + else: + operands.append( + Operand(f'op{opno}', get_kind(operand), + operand.get('quantifier', ''))) + opno = opno + 1 + return operands + + +def generate_op_classes(f, grammar): + print(spv_inst_class, file=f) + + for instruction in grammar['instructions']: + operands = get_operands(instruction) + + print(f'class {get_class_name(instruction)} : public spv_inst {{', + file=f) + print(f'public:', file=f) + print( + f'inline static bool classof(spv_inst const& s) {{ return s.opcode() == Op::{get_opcode_name(instruction)};}}', + file=f) + if 'capabilities' in instruction: + caps = instruction['capabilities'] + cap_str = ','.join([f'Capability::{cap}' for cap in caps]) + print( + f'constexpr static std::array required_capabilities = {{{cap_str}}};', + file=f) + f.write(f'{get_class_name(instruction)}(') + f.write(','.join([f'{o.kind} {o.name}' for o in operands])) + f.write(') : ') + initializer_list = [ + f'spv_inst{{Op::{get_opcode_name(instruction)}, {"true" if has_result_id(instruction) else "false"}}}' + ] + initializer_list += [ + f'{o.name}_(std::move({o.name}))' for o in operands + ] + f.write(','.join(initializer_list)) + f.write('{}') + for o in operands: + print( + f'inline auto {o.name}() const -> {o.kind} const& {{ return {o.name}_; }}', + file=f) + print(f'private:', file=f) + for o in operands: + print(f'{o.kind} {o.name}_;', file=f) + print('};', file=f) + + +def generate_visitor(f, grammar): + format_call = lambda op: f'static_cast(this)->operator()({op});' + + print("""template struct overloaded : Ts... { + using Ts::operator()...; +}; +template overloaded(Ts...) -> overloaded; + + template auto visit(Visitor&& visitor, spv_inst const& inst) { + switch (inst.opcode()) {""", + file=f) + for instruction in grammar['instructions']: + print(f"""case Op::{get_opcode_name(instruction)}: + return visitor(static_cast<{get_class_name(instruction)} const&>(inst));""", + file=f) + print("""} + throw internal_compiler_error(); +}""", file=f) + + print('template class default_visitor { public:', + file=f) + print('auto pre_visit(spv_inst const&) {}', file=f) + for instruction in grammar['instructions']: + print( + f"""auto operator()({get_class_name(instruction)} const& in) {{""", + file=f) + print(f'static_cast(this)->pre_visit(in);', file=f) + for o in get_operands(instruction): + if o.quantifier == '*': + print(f"""for (auto const& op : in.{o.name}()) {{ + {format_call('op')} +}} +""", + file=f) + elif o.quantifier == '?': + print(f"""if (in.{o.name}()) {{ + {format_call(f'*in.{o.name}()')} +}} +""", + file=f) + else: + print(format_call(f'in.{o.name}()'), file=f) + print('}', file=f) + print('};', file=f) + + +def print_includes(f, includes): + for include in includes: + if include: + if include[0] != '<' and include[0] != '"': + print(f'#include "{include}"', file=f) + else: + print(f'#include {include}', file=f) + else: + print('', file=f) + + +def generate_header(args, filename, grammar, generator, includes=[]): + filename = os.path.join(args.o, filename) + with open(filename, 'w') as f: + now = datetime.datetime.now() + basename = os.path.splitext(os.path.basename(filename))[0].upper() + headerguard_name = f'GENERATED_{basename}_{now.year}{now.month}{now.day}_HPP' + + print(f"""// Copyright (C) {now.year} Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef {headerguard_name} +#define {headerguard_name} + +""", + file=f) + print_includes(f, includes) + print(""" +namespace tinytc::spv { +""", file=f) + + generator(f, grammar) + + print(f""" +}} + +#endif // {headerguard_name} +""", file=f) + + subprocess.call([args.c, '-i', filename]) + + +def generate_cpp(args, filename, grammar, generator, includes=[]): + filename = os.path.join(args.o, filename) + with open(filename, 'w') as f: + now = datetime.datetime.now() + print(f"""// Copyright (C) {now.year} Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +""", + file=f) + print_includes(f, includes) + print(""" +namespace tinytc::spv { +""", file=f) + + generator(f, grammar) + + print(f""" +}} + +#endif // {headerguard_name} +""", file=f) + + subprocess.call([args.c, '-i', filename]) + + +def filter_grammar(grammar, filt): + filtered_instructions = [] + for instruction in grammar['instructions']: + opcode = instruction['opcode'] + for i in filt['include']: + if i[0] <= opcode and opcode <= i[1]: + filtered_instructions.append(instruction) + grammar['instructions'] = filtered_instructions + return grammar + + +def patch_grammar(grammar): + for instruction in grammar['instructions']: + if instruction['opname'] == 'OpDecorate': + if instruction['operands'][-1]['kind'] == 'Decoration': + instruction['operands'].append({'kind': 'DecorationAttr'}) + elif instruction['opname'] == 'OpExecutionMode': + if instruction['operands'][-1]['kind'] == 'ExecutionMode': + instruction['operands'].append({'kind': 'ExecutionModeAttr'}) + return grammar + + +if __name__ == '__main__': + script_dir = os.path.dirname(os.path.realpath(__file__)) + parser = argparse.ArgumentParser() + parser.add_argument('-c', + help='clang-format binary', + default='clang-format'), + parser.add_argument('-f', + help='Filter JSON file', + default=os.path.join(script_dir, 'filter.json')), + parser.add_argument('-o', help='output directory', default=''), + parser.add_argument( + 'grammar', + help='spirv.core.grammar.json file from SPIRV-Headers project') + args = parser.parse_args() + + if shutil.which(args.c): + grammar = dict() + filt = dict() + with open(args.grammar) as f: + grammar = json.load(f) + with open(args.f) as f: + filt = json.load(f) + + grammar = filter_grammar(grammar, filt) + grammar = patch_grammar(grammar) + generate_header(args, spv_enums, grammar, generate_enums) + generate_header(args, spv_names, grammar, generate_names) + generate_header(args, spv_names_cpp, grammar, generate_names_cpp, + spv_names_cpp_includes) + generate_header(args, spv_ops, grammar, generate_op_classes, + spv_ops_includes) + generate_header(args, spv_visitor, grammar, generate_visitor) + else: + print(f'Could not find clang-format: {args.c}') From 7432503fcfd5d0c40baeb3a78290d72d7200648d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 5 Nov 2024 15:44:43 +0100 Subject: [PATCH 085/297] Convert to SPIR-V: Arith Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 4 + include/tinytc/types.h | 4 +- include/tinytc/types.hpp | 2 + src/CMakeLists.txt | 1 + src/error.cpp | 4 + src/parser/parse_context.cpp | 9 + src/parser/parse_context.hpp | 3 + src/parser/parser_impl.yy | 1 + src/pass/convert_to_spirv.cpp | 385 ++++++++++++++++++++++++++----- src/spv/enums.hpp | 6 +- src/spv/instructions.hpp | 6 +- src/spv/module.cpp | 16 +- src/spv/module.hpp | 23 +- src/spv/names.cpp | 5 - src/spv/names.hpp | 8 +- src/spv/opencl.std.cpp | 341 +++++++++++++++++++++++++++ src/spv/opencl.std.hpp | 183 +++++++++++++++ src/spv/pass/dump_asm.cpp | 45 +++- src/spv/pass/dump_asm.hpp | 4 +- src/spv/visit.hpp | 6 +- test/spv/arith.ir | 91 ++++++++ test/spv/arith_unary.ir | 60 +++++ test/spv/unique_function_type.ir | 15 ++ tools/spirvgen/ext_opencl.py | 67 ++++++ tools/spirvgen/gen.py | 78 +++++++ tools/spirvgen/spirvgen.py | 84 +------ 26 files changed, 1275 insertions(+), 176 deletions(-) create mode 100644 src/spv/opencl.std.cpp create mode 100644 src/spv/opencl.std.hpp create mode 100644 test/spv/arith.ir create mode 100644 test/spv/arith_unary.ir create mode 100644 test/spv/unique_function_type.ir create mode 100755 tools/spirvgen/ext_opencl.py create mode 100644 tools/spirvgen/gen.py diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index fc107514..cf544d01 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -120,6 +120,10 @@ is always a multiple of the subgroup size. The subgroup size attribute enforces a particular subgroup device supported by the device. +Restrictions +------------ + +Arguments must not have coopmatrix type. Regions ======= diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 2305db1f..81f9d50a 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -41,6 +41,7 @@ typedef enum { tinytc_status_unsupported_device = 0xe, ///< Unsupported device tinytc_status_invalid_core_info = 0xf, ///< Invalid core info object tinytc_status_unknown_pass_name = 0x10, ///< Invalid compiler pass name + tinytc_status_not_implemented = 0x11, ///< Not implemented // IR errors tinytc_status_ir_out_of_bounds = 0x100, ///< Out of bounds access tinytc_status_ir_invalid_shape = 0x101, ///< Invalid tensor shape @@ -83,7 +84,8 @@ typedef enum { tinytc_status_ir_incompatible_scalar_types = 0x124, ///< Incompatible scalar types // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = - 0x1000, ///< Forward declaration of id is forbidden + 0x1000, ///< Forward declaration of id is forbidden + tinytc_status_spirv_undefined_value = 0x1001, ///< Undefined value // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 766a6fb8..6640d065 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -51,6 +51,7 @@ enum class status { unsupported_device = tinytc_status_unsupported_device, invalid_core_info = tinytc_status_invalid_core_info, unknown_pass_name = tinytc_status_unknown_pass_name, + not_implemented = tinytc_status_not_implemented, // IR errors ir_out_of_bounds = tinytc_status_ir_out_of_bounds, ir_invalid_shape = tinytc_status_ir_invalid_shape, @@ -90,6 +91,7 @@ enum class status { ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, + spirv_undefined_value = tinytc_status_spirv_undefined_value, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 89f863f3..70a480e5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -65,6 +65,7 @@ set(SOURCES scalar_type.cpp spv/module.cpp spv/names.cpp + spv/opencl.std.cpp spv/pass/dump_asm.cpp source.cpp tiling.cpp diff --git a/src/error.cpp b/src/error.cpp index 55deb3be..6524963f 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -117,6 +117,8 @@ char const *tinytc_error_string(tinytc_status_t status) { "is empty)"; case tinytc_status_unknown_pass_name: return "Unknown compiler pass name"; + case tinytc_status_not_implemented: + return "Not implemented"; // IR case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; @@ -196,6 +198,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Scalar types violate compatibility rules"; case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; + case tinytc_status_spirv_undefined_value: + return "Undefined SPIR-V value"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 09c643cf..0da2aff7 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -29,6 +29,15 @@ void parse_context::pop_region() { regions_.pop(); } auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } auto parse_context::has_regions() -> bool { return !regions_.empty(); } +void parse_context::add_global_name(std::string const &name, location const &l) { + if (auto other = global_names_.find(name); other != global_names_.end()) { + auto oss = std::ostringstream{}; + oss << "Identifier @" << name << " was already used at " << other->second; + throw parser::syntax_error(l, std::move(oss).str()); + } + global_names_[name] = l; +} + void parse_context::val(std::variant const &id, tinytc_value &val, location const &l) { const auto handle_val = diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 4c511d01..afbb146a 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -43,11 +43,14 @@ class parse_context { auto top_region() -> tinytc_region_t; auto has_regions() -> bool; + void add_global_name(std::string const &name, location const &l); + private: compiler_context compiler_ctx_; std::vector> unnamed_id_map_; std::vector> named_id_map_; std::stack regions_; + std::unordered_map global_names_; prog program_; }; diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index e56ce9f9..733bd4d7 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -257,6 +257,7 @@ func: auto loc = @FUNC; loc.end = @RPAREN.end; try { + ctx.add_global_name($GLOBAL_IDENTIFIER, loc); auto func_node = std::make_unique($GLOBAL_IDENTIFIER, $parameters.second, loc); for (auto &attr : $attributes) { diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 2e5e5179..9450ef70 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -7,28 +7,40 @@ #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" +#include "scalar_type.hpp" #include "spv/instructions.hpp" +#include "spv/opencl.std.hpp" +#include "support/fnv1a.hpp" #include "support/visit.hpp" +#include #include +#include namespace tinytc { class spirv_converter { public: - inline spirv_converter(spv::mod &mod, tinytc_compiler_context_t ctx) : mod_(&mod), ctx_(ctx) {} + inline spirv_converter(::tinytc_core_info const *info, spv::mod &mod, + tinytc_compiler_context_t ctx) + : info_(info), mod_(&mod), ctx_(ctx) {} auto operator()(data_type_node const &ty) -> spv::spv_inst *; // Instruction nodes void operator()(inst_node const &in); void operator()(arith_inst const &in); + void operator()(arith_unary_inst const &in); + void operator()(constant_inst const &in); void run_on_program(program_node const &p); private: auto declare(value_node const &v, spv::spv_inst *in); auto val(value_node const &v) -> spv::spv_inst *; + auto multi_declare(value_node const &v, std::vector insts); + auto multi_val(value_node const &v) -> std::vector &; + auto declare_function_type(std::vector params) -> spv::spv_inst *; void run_on_region(region_node const &fn); void run_on_function(function_node const &fn); template auto add_to(Args &&...args) -> T * { @@ -41,10 +53,16 @@ class spirv_converter { return add_to(std::forward(args)...); } + ::tinytc_core_info const *info_; spv::mod *mod_; tinytc_compiler_context_t ctx_; std::unordered_map spv_tys_; std::unordered_map vals_; + std::unordered_map> multi_vals_; + std::unordered_set capabilities_; + std::unordered_multimap function_tys_; + spv::spv_inst *opencl_ext_ = nullptr; + core_config core_cfg_ = {}; }; auto spirv_converter::declare(value_node const &v, spv::spv_inst *in) { vals_[&v] = in; } @@ -52,7 +70,35 @@ auto spirv_converter::val(value_node const &v) -> spv::spv_inst * { if (auto it = vals_.find(&v); it != vals_.end()) { return it->second; } - throw status::internal_compiler_error; + throw compilation_error(v.loc(), status::spirv_undefined_value); +} +auto spirv_converter::multi_declare(value_node const &v, std::vector insts) { + multi_vals_[&v] = std::move(insts); +} +auto spirv_converter::multi_val(value_node const &v) -> std::vector & { + if (auto it = multi_vals_.find(&v); it != multi_vals_.end()) { + return it->second; + } + throw compilation_error(v.loc(), status::spirv_undefined_value); +} +auto spirv_converter::declare_function_type(std::vector params) + -> spv::spv_inst * { + auto map_key = fnv1a0(); + for (auto const &p : params) { + map_key = fnv1a_step(map_key, p); + } + auto range = function_tys_.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (std::equal(params.begin(), params.end(), it->second->op1().begin(), + it->second->op1().end())) { + return it->second; + } + } + auto void_ty = visit(*this, *void_data_type::get(ctx_)); + return function_tys_ + .emplace(map_key, add_to( + void_ty, std::move(params))) + ->second; } auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { @@ -61,44 +107,53 @@ auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { auto spv_ty = visit( overloaded{ [&](void_data_type const &) -> spv::spv_inst * { - return add_to(); + return add_to(); }, [&](scalar_data_type const &ty) -> spv::spv_inst * { switch (ty.ty()) { case scalar_type::i1: - return add_to(); + return add_to(); case scalar_type::i8: - add_to(spv::Capability::Int8); - return add_to(8, 1); + capabilities_.insert(spv::Capability::Int8); + return add_to(8, 1); case scalar_type::i16: - add_to(spv::Capability::Int16); - return add_to(16, 1); + capabilities_.insert(spv::Capability::Int16); + return add_to(16, 1); case scalar_type::i32: + return add_to(32, 1); case scalar_type::i64: - case scalar_type::index: - return add_to(size(ty.ty()) * 8, 1); + capabilities_.insert(spv::Capability::Int64); + return add_to(64, 1); + case scalar_type::index: { + const auto sz = size(ty.ty()); + if (sz == 8) { + capabilities_.insert(spv::Capability::Int64); + } + return add_to(sz * 8, 1); + } case scalar_type::f32: case scalar_type::f64: - return add_to(size(ty.ty()) * 8, - std::nullopt); - case scalar_type::c32: + return add_to( + size(ty.ty()) * 8, std::nullopt); + case scalar_type::c32: { + auto float_ty = + visit(*this, *scalar_data_type::get(ctx_, scalar_type::f32)); + return add_to(float_ty, 2); + } case scalar_type::c64: { - auto float_ty = add_to( - size(ty.ty()) * 8 / 2, std::nullopt); - return add_to(float_ty, 2); + auto float_ty = + visit(*this, *scalar_data_type::get(ctx_, scalar_type::f64)); + return add_to(float_ty, 2); } } throw status::internal_compiler_error; }, [&](coopmatrix_data_type const &ty) -> spv::spv_inst * { - // @todo - throw status::internal_compiler_error; - return nullptr; + return visit(*this, *ty.ty()); }, [](auto const &) -> spv::spv_inst * { // @todo - throw status::internal_compiler_error; - return nullptr; + throw status::not_implemented; }}, ty); spv_tys_[&ty] = spv_ty; @@ -107,9 +162,9 @@ auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { return it->second; } -void spirv_converter::operator()(inst_node const &) { +void spirv_converter::operator()(inst_node const &in) { // @todo - throw status::internal_compiler_error; + throw compilation_error(in.loc(), status::not_implemented); } void spirv_converter::operator()(arith_inst const &in) { @@ -125,7 +180,7 @@ void spirv_converter::operator()(arith_inst const &in) { default: break; } - throw status::ir_i1_unsupported; + throw compilation_error(in.loc(), status::ir_i1_unsupported); }; auto const make_int = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { @@ -151,7 +206,7 @@ void spirv_converter::operator()(arith_inst const &in) { case arithmetic::xor_: return add(ty, a, b); } - throw status::internal_compiler_error; + throw compilation_error(in.loc(), status::internal_compiler_error); }; auto const make_float_complex = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { @@ -169,7 +224,7 @@ void spirv_converter::operator()(arith_inst const &in) { default: break; } - throw status::ir_fp_unsupported; + throw compilation_error(in.loc(), status::ir_fp_unsupported); }; auto const make = [&](scalar_type sty, arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { @@ -188,26 +243,229 @@ void spirv_converter::operator()(arith_inst const &in) { case scalar_type::c64: return make_float_complex(op, ty, a, b); } - throw status::internal_compiler_error; + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto ty = visit(*this, *in.result(0).ty()); + + if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make(st->ty(), in.operation(), ty, av, bv)); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + auto &bv = multi_val(in.b()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i], bv[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void spirv_converter::operator()(arith_unary_inst const &in) { + auto const make_boolean = [&](arithmetic_unary op, spv::spv_inst *ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (op) { + case arithmetic_unary::not_: + return add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::ir_i1_unsupported); + }; + auto const make_int = [&](arithmetic_unary op, spv::spv_inst *ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return add(ty, opencl_ext_, + static_cast(spv::OpenCLEntrypoint::s_abs), + std::vector{a}); + case arithmetic_unary::neg: + return add(ty, a); + case arithmetic_unary::not_: + return add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make_float = [&](arithmetic_unary op, spv::spv_inst *ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return add(ty, opencl_ext_, + static_cast(spv::OpenCLEntrypoint::fabs), + std::vector{a}); + case arithmetic_unary::neg: + return add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make_complex = [&](arithmetic_unary op, scalar_type sty, spv::spv_inst *ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (op) { + case arithmetic_unary::abs: { + auto spv_a_ty = visit(*this, *scalar_data_type::get(ctx_, sty)); + auto a2 = add(spv_a_ty, a, a); + auto a2_0 = add(ty, a2, std::vector{0}); + auto a2_1 = add(ty, a2, std::vector{1}); + auto a2_0p1 = add(ty, a2_0, a2_1); + return add(ty, opencl_ext_, + static_cast(spv::OpenCLEntrypoint::sqrt), + std::vector{a2_0p1}); + } + case arithmetic_unary::neg: + return add(ty, a); + case arithmetic_unary::conj: { + auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(sty))); + auto a_im = + add(spv_float_ty, a, std::vector{1}); + auto neg_a_im = add(spv_float_ty, a_im); + return add(ty, neg_a_im, a, + std::vector{1}); + } + case arithmetic_unary::im: + return add(ty, a, std::vector{1}); + case arithmetic_unary::re: + return add(ty, a, std::vector{0}); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make = [&](scalar_type sty, arithmetic_unary op, spv::spv_inst *ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (sty) { + case scalar_type::i1: + return make_boolean(op, ty, a); + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a); + case scalar_type::f32: + case scalar_type::f64: + return make_float(op, ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return make_complex(op, sty, ty, a); + } + } + throw compilation_error(in.loc(), status::internal_compiler_error); }; auto ty = visit(*this, *in.result(0).ty()); - auto av = val(in.a()); - auto bv = val(in.b()); + + if (auto st = dyn_cast(in.a().ty()); st) { + auto av = val(in.a()); + declare(in.result(0), make(st->ty(), in.operation(), ty, av)); + } else if (auto ct = dyn_cast(in.a().ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void spirv_converter::operator()(constant_inst const &in) { + auto const make = [&](scalar_type sty, spv::spv_inst *spv_ty, + constant_inst::value_type const &val) -> spv::spv_inst * { + auto const add_constant_bool = [this, &spv_ty](bool val) -> spv::spv_inst * { + if (val) { + return add_to(spv_ty); + } + return add_to(spv_ty); + }; + auto const add_constant = [this, &spv_ty](auto val) -> spv::spv_inst * { + return add_to(spv_ty, val); + }; + auto const add_constant_complex = [this, &spv_ty](spv::spv_inst *spv_float_ty, auto re, + auto im) -> spv::spv_inst * { + auto c_re = add_to(spv_float_ty, re); + auto c_im = add_to(spv_float_ty, im); + return add_to( + spv_ty, std::vector{c_re, c_im}); + }; + const auto visitor = overloaded{ + [&](std::int64_t i) -> spv::spv_inst * { + switch (sty) { + case scalar_type::i1: + return add_constant_bool(i != 0); + case scalar_type::i8: + return add_constant(static_cast(i)); + case scalar_type::i16: + return add_constant(static_cast(i)); + case scalar_type::i32: + return add_constant(static_cast(i)); + case scalar_type::i64: + case scalar_type::index: + return add_constant(i); + default: + return nullptr; + } + }, + [&](double d) -> spv::spv_inst * { + switch (sty) { + case scalar_type::f32: + return add_constant(static_cast(d)); + case scalar_type::f64: + return add_constant(d); + default: + return nullptr; + } + }, + [&](std::complex d) -> spv::spv_inst * { + switch (sty) { + case scalar_type::c32: { + auto spv_float_ty = + visit(*this, *scalar_data_type::get(ctx_, scalar_type::f32)); + return add_constant_complex(spv_float_ty, static_cast(d.real()), + static_cast(d.imag())); + } + case scalar_type::c64: { + auto spv_float_ty = + visit(*this, *scalar_data_type::get(ctx_, scalar_type::f64)); + return add_constant_complex(spv_float_ty, d.real(), d.imag()); + } + default: + return nullptr; + } + }, + }; + auto cst = std::visit(visitor, val); + if (cst == nullptr) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + return cst; + }; + + auto spv_ty = visit(*this, *in.result(0).ty()); if (auto st = dyn_cast(in.result(0).ty()); st) { - make(st->ty(), in.operation(), ty, av, bv); + declare(in.result(0), make(st->ty(), spv_ty, in.value())); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { - // auto clinst = std::vector{}; - // auto const len = ct->length(core_cfg_.subgroup_size); - // clinst.reserve(len + 1); - // clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - // const auto sty = ct->component_ty(); - // for (std::int64_t i = 0; i < len; ++i) { - // auto op = make(a.operation(), av[i], bv[i], sty); - // clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); - //} - // return clinst; + auto const length = ct->length(core_cfg_.subgroup_size); + auto cst = make(ct->component_ty(), spv_ty, in.value()); + + multi_declare(in.result(0), std::vector(length, cst)); } else { throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); } @@ -218,19 +476,29 @@ void spirv_converter::run_on_region(region_node const ®) { for (auto const &i : reg) { visit(*this, i); } + add(); } void spirv_converter::run_on_function(function_node const &fn) { - // Function type - auto void_ty = visit(*this, *void_data_type::get(ctx_)); - auto params = std::vector{}; - params.reserve(fn.num_params()); - for (auto const &p : fn.params()) { - params.push_back(visit(*this, *p.ty())); + auto const subgroup_size = fn.subgroup_size(); + try { + core_cfg_ = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); } - auto fun_ty = add(void_ty, std::move(params)); + + // Function type + auto fun_ty = declare_function_type([&] { + auto params = std::vector{}; + params.reserve(fn.num_params()); + for (auto const &p : fn.params()) { + params.push_back(visit(*this, *p.ty())); + } + return params; + }()); // Function + auto void_ty = visit(*this, *void_data_type::get(ctx_)); auto fun = add(void_ty, spv::FunctionControl::None, fun_ty); for (auto const &p : fn.params()) { declare(p, add(visit(*this, *p.ty()))); @@ -250,38 +518,37 @@ void spirv_converter::run_on_function(function_node const &fn) { std::array{work_group_size[0], work_group_size[1], 1}}); add_to( fun, spv::ExecutionMode::SubgroupSize, spv::ExecutionModeAttr{fn.subgroup_size()}); - - // Function decoration - auto linkage_decoration = - spv::DecorationAttr{std::make_pair(std::string{fn.name()}, spv::LinkageType::Export)}; - add_to(fun, spv::Decoration::LinkageAttributes, - std::move(linkage_decoration)); } void spirv_converter::run_on_program(program_node const &p) { - add_to(spv::Capability::Addresses); - add_to(spv::Capability::Kernel); - add_to(spv::Capability::Linkage); - add_to(spv::Capability::SubgroupDispatch); + capabilities_.clear(); + capabilities_.insert(spv::Capability::Addresses); + capabilities_.insert(spv::Capability::Kernel); + capabilities_.insert(spv::Capability::SubgroupDispatch); + opencl_ext_ = add_to(spv::OpenCLExt); add_to(spv::AddressingModel::Physical64, spv::MemoryModel::OpenCL); for (auto const &fn : p) { run_on_function(fn); } + + for (auto const &cap : capabilities_) { + add_to(cap); + } } convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) : info_(std::move(info)) { if (info_ == nullptr) { - throw std::invalid_argument("info must not be nullptr"); + throw status::invalid_arguments; } } auto convert_to_spirv_pass::run_on_program(program_node const &p) -> std::unique_ptr { auto m = std::make_unique(); - spirv_converter(*m, p.context()).run_on_program(p); + spirv_converter(info_, *m, p.context()).run_on_program(p); return m; } diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 457a9ee4..52d9a706 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_2024114_HPP -#define GENERATED_ENUMS_2024114_HPP +#ifndef GENERATED_ENUMS_2024115_HPP +#define GENERATED_ENUMS_2024115_HPP namespace tinytc::spv { @@ -1422,4 +1422,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_2024114_HPP +#endif // GENERATED_ENUMS_2024115_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index 8890fd69..9a5d6e3c 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_2024114_HPP -#define GENERATED_INSTRUCTIONS_2024114_HPP +#ifndef GENERATED_INSTRUCTIONS_2024115_HPP +#define GENERATED_INSTRUCTIONS_2024115_HPP #include "enums.hpp" #include "error.hpp" @@ -5587,4 +5587,4 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_2024114_HPP +#endif // GENERATED_INSTRUCTIONS_2024115_HPP diff --git a/src/spv/module.cpp b/src/spv/module.cpp index 63abaf83..953968bd 100644 --- a/src/spv/module.cpp +++ b/src/spv/module.cpp @@ -10,7 +10,21 @@ void ilist_callbacks::node_removed(spv::spv_inst *node) { delete } // namespace tinytc namespace tinytc::spv { -mod::mod() {} +mod::mod(std::int32_t major_version, std::int32_t minor_version) + : major_version_{major_version}, minor_version_{minor_version} {} mod::~mod() {} + +auto mod::bound() const -> std::int32_t { + std::int32_t bnd = 0; + for (auto const &sec : insts_) { + for (auto const &i : sec) { + if (i.has_result_id()) { + ++bnd; + } + } + } + return bnd; +} + } // namespace tinytc::spv diff --git a/src/spv/module.hpp b/src/spv/module.hpp index c0d14d60..ab0bb0ff 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -26,31 +26,38 @@ namespace spv { enum class section { capability = 0, - memory_model = 1, - entry_point = 2, - execution_mode = 3, - decoration = 4, - type = 5, - function = 6 + ext_inst = 1, + memory_model = 2, + entry_point = 3, + execution_mode = 4, + decoration = 5, + type_const_var = 6, + function = 7 }; -inline constexpr std::size_t num_module_sections = 7; +inline constexpr std::size_t num_module_sections = 8; class mod final { public: using iterator = ilist::iterator; using const_iterator = ilist::const_iterator; - mod(); + mod(std::int32_t major_version = 1, std::int32_t minor_version = 6); ~mod(); + auto bound() const -> std::int32_t; + inline auto insts(section s) -> ilist & { return insts_[static_cast(s)]; } inline auto insts(section s) const -> ilist const & { return insts_[static_cast(s)]; } inline auto empty(section s) const -> bool { return insts_[static_cast(s)].empty(); } + inline auto major_version() const -> std::int32_t { return major_version_; } + inline auto minor_version() const -> std::int32_t { return minor_version_; } + private: std::array, num_module_sections> insts_; + std::int32_t major_version_, minor_version_; }; } // namespace spv diff --git a/src/spv/names.cpp b/src/spv/names.cpp index 259e9096..ad48e41e 100644 --- a/src/spv/names.cpp +++ b/src/spv/names.cpp @@ -4,9 +4,6 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_2024114_HPP -#define GENERATED_NAMES_2024114_HPP - #include "names.hpp" #include "enums.hpp" @@ -2889,5 +2886,3 @@ auto to_string(FPEncoding e) -> char const * { } } // namespace tinytc::spv - -#endif // GENERATED_NAMES_2024114_HPP diff --git a/src/spv/names.hpp b/src/spv/names.hpp index 8bc5f3b9..3c48c560 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,10 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_2024114_HPP -#define GENERATED_NAMES_2024114_HPP +#ifndef GENERATED_NAMES_2024115_HPP +#define GENERATED_NAMES_2024115_HPP + +#include "enums.hpp" namespace tinytc::spv { @@ -66,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_2024114_HPP +#endif // GENERATED_NAMES_2024115_HPP diff --git a/src/spv/opencl.std.cpp b/src/spv/opencl.std.cpp new file mode 100644 index 00000000..1284448b --- /dev/null +++ b/src/spv/opencl.std.cpp @@ -0,0 +1,341 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "opencl.std.hpp" + +namespace tinytc::spv { + +auto to_string(OpenCLEntrypoint ep) -> char const * { + switch (ep) { + case OpenCLEntrypoint::acos: + return "acos"; + case OpenCLEntrypoint::acosh: + return "acosh"; + case OpenCLEntrypoint::acospi: + return "acospi"; + case OpenCLEntrypoint::asin: + return "asin"; + case OpenCLEntrypoint::asinh: + return "asinh"; + case OpenCLEntrypoint::asinpi: + return "asinpi"; + case OpenCLEntrypoint::atan: + return "atan"; + case OpenCLEntrypoint::atan2: + return "atan2"; + case OpenCLEntrypoint::atanh: + return "atanh"; + case OpenCLEntrypoint::atanpi: + return "atanpi"; + case OpenCLEntrypoint::atan2pi: + return "atan2pi"; + case OpenCLEntrypoint::cbrt: + return "cbrt"; + case OpenCLEntrypoint::ceil: + return "ceil"; + case OpenCLEntrypoint::copysign: + return "copysign"; + case OpenCLEntrypoint::cos: + return "cos"; + case OpenCLEntrypoint::cosh: + return "cosh"; + case OpenCLEntrypoint::cospi: + return "cospi"; + case OpenCLEntrypoint::erfc: + return "erfc"; + case OpenCLEntrypoint::erf: + return "erf"; + case OpenCLEntrypoint::exp: + return "exp"; + case OpenCLEntrypoint::exp2: + return "exp2"; + case OpenCLEntrypoint::exp10: + return "exp10"; + case OpenCLEntrypoint::expm1: + return "expm1"; + case OpenCLEntrypoint::fabs: + return "fabs"; + case OpenCLEntrypoint::fdim: + return "fdim"; + case OpenCLEntrypoint::floor: + return "floor"; + case OpenCLEntrypoint::fma: + return "fma"; + case OpenCLEntrypoint::fmax: + return "fmax"; + case OpenCLEntrypoint::fmin: + return "fmin"; + case OpenCLEntrypoint::fmod: + return "fmod"; + case OpenCLEntrypoint::fract: + return "fract"; + case OpenCLEntrypoint::frexp: + return "frexp"; + case OpenCLEntrypoint::hypot: + return "hypot"; + case OpenCLEntrypoint::ilogb: + return "ilogb"; + case OpenCLEntrypoint::ldexp: + return "ldexp"; + case OpenCLEntrypoint::lgamma: + return "lgamma"; + case OpenCLEntrypoint::lgamma_r: + return "lgamma_r"; + case OpenCLEntrypoint::log: + return "log"; + case OpenCLEntrypoint::log2: + return "log2"; + case OpenCLEntrypoint::log10: + return "log10"; + case OpenCLEntrypoint::log1p: + return "log1p"; + case OpenCLEntrypoint::logb: + return "logb"; + case OpenCLEntrypoint::mad: + return "mad"; + case OpenCLEntrypoint::maxmag: + return "maxmag"; + case OpenCLEntrypoint::minmag: + return "minmag"; + case OpenCLEntrypoint::modf: + return "modf"; + case OpenCLEntrypoint::nan: + return "nan"; + case OpenCLEntrypoint::nextafter: + return "nextafter"; + case OpenCLEntrypoint::pow: + return "pow"; + case OpenCLEntrypoint::pown: + return "pown"; + case OpenCLEntrypoint::powr: + return "powr"; + case OpenCLEntrypoint::remainder: + return "remainder"; + case OpenCLEntrypoint::remquo: + return "remquo"; + case OpenCLEntrypoint::rint: + return "rint"; + case OpenCLEntrypoint::rootn: + return "rootn"; + case OpenCLEntrypoint::round: + return "round"; + case OpenCLEntrypoint::rsqrt: + return "rsqrt"; + case OpenCLEntrypoint::sin: + return "sin"; + case OpenCLEntrypoint::sincos: + return "sincos"; + case OpenCLEntrypoint::sinh: + return "sinh"; + case OpenCLEntrypoint::sinpi: + return "sinpi"; + case OpenCLEntrypoint::sqrt: + return "sqrt"; + case OpenCLEntrypoint::tan: + return "tan"; + case OpenCLEntrypoint::tanh: + return "tanh"; + case OpenCLEntrypoint::tanpi: + return "tanpi"; + case OpenCLEntrypoint::tgamma: + return "tgamma"; + case OpenCLEntrypoint::trunc: + return "trunc"; + case OpenCLEntrypoint::half_cos: + return "half_cos"; + case OpenCLEntrypoint::half_divide: + return "half_divide"; + case OpenCLEntrypoint::half_exp: + return "half_exp"; + case OpenCLEntrypoint::half_exp2: + return "half_exp2"; + case OpenCLEntrypoint::half_exp10: + return "half_exp10"; + case OpenCLEntrypoint::half_log: + return "half_log"; + case OpenCLEntrypoint::half_log2: + return "half_log2"; + case OpenCLEntrypoint::half_log10: + return "half_log10"; + case OpenCLEntrypoint::half_powr: + return "half_powr"; + case OpenCLEntrypoint::half_recip: + return "half_recip"; + case OpenCLEntrypoint::half_rsqrt: + return "half_rsqrt"; + case OpenCLEntrypoint::half_sin: + return "half_sin"; + case OpenCLEntrypoint::half_sqrt: + return "half_sqrt"; + case OpenCLEntrypoint::half_tan: + return "half_tan"; + case OpenCLEntrypoint::native_cos: + return "native_cos"; + case OpenCLEntrypoint::native_divide: + return "native_divide"; + case OpenCLEntrypoint::native_exp: + return "native_exp"; + case OpenCLEntrypoint::native_exp2: + return "native_exp2"; + case OpenCLEntrypoint::native_exp10: + return "native_exp10"; + case OpenCLEntrypoint::native_log: + return "native_log"; + case OpenCLEntrypoint::native_log2: + return "native_log2"; + case OpenCLEntrypoint::native_log10: + return "native_log10"; + case OpenCLEntrypoint::native_powr: + return "native_powr"; + case OpenCLEntrypoint::native_recip: + return "native_recip"; + case OpenCLEntrypoint::native_rsqrt: + return "native_rsqrt"; + case OpenCLEntrypoint::native_sin: + return "native_sin"; + case OpenCLEntrypoint::native_sqrt: + return "native_sqrt"; + case OpenCLEntrypoint::native_tan: + return "native_tan"; + case OpenCLEntrypoint::s_abs: + return "s_abs"; + case OpenCLEntrypoint::s_abs_diff: + return "s_abs_diff"; + case OpenCLEntrypoint::s_add_sat: + return "s_add_sat"; + case OpenCLEntrypoint::u_add_sat: + return "u_add_sat"; + case OpenCLEntrypoint::s_hadd: + return "s_hadd"; + case OpenCLEntrypoint::u_hadd: + return "u_hadd"; + case OpenCLEntrypoint::s_rhadd: + return "s_rhadd"; + case OpenCLEntrypoint::u_rhadd: + return "u_rhadd"; + case OpenCLEntrypoint::s_clamp: + return "s_clamp"; + case OpenCLEntrypoint::u_clamp: + return "u_clamp"; + case OpenCLEntrypoint::clz: + return "clz"; + case OpenCLEntrypoint::ctz: + return "ctz"; + case OpenCLEntrypoint::s_mad_hi: + return "s_mad_hi"; + case OpenCLEntrypoint::u_mad_sat: + return "u_mad_sat"; + case OpenCLEntrypoint::s_mad_sat: + return "s_mad_sat"; + case OpenCLEntrypoint::s_max: + return "s_max"; + case OpenCLEntrypoint::u_max: + return "u_max"; + case OpenCLEntrypoint::s_min: + return "s_min"; + case OpenCLEntrypoint::u_min: + return "u_min"; + case OpenCLEntrypoint::s_mul_hi: + return "s_mul_hi"; + case OpenCLEntrypoint::rotate: + return "rotate"; + case OpenCLEntrypoint::s_sub_sat: + return "s_sub_sat"; + case OpenCLEntrypoint::u_sub_sat: + return "u_sub_sat"; + case OpenCLEntrypoint::u_upsample: + return "u_upsample"; + case OpenCLEntrypoint::s_upsample: + return "s_upsample"; + case OpenCLEntrypoint::popcount: + return "popcount"; + case OpenCLEntrypoint::s_mad24: + return "s_mad24"; + case OpenCLEntrypoint::u_mad24: + return "u_mad24"; + case OpenCLEntrypoint::s_mul24: + return "s_mul24"; + case OpenCLEntrypoint::u_mul24: + return "u_mul24"; + case OpenCLEntrypoint::u_abs: + return "u_abs"; + case OpenCLEntrypoint::u_abs_diff: + return "u_abs_diff"; + case OpenCLEntrypoint::u_mul_hi: + return "u_mul_hi"; + case OpenCLEntrypoint::u_mad_hi: + return "u_mad_hi"; + case OpenCLEntrypoint::fclamp: + return "fclamp"; + case OpenCLEntrypoint::degrees: + return "degrees"; + case OpenCLEntrypoint::fmax_common: + return "fmax_common"; + case OpenCLEntrypoint::fmin_common: + return "fmin_common"; + case OpenCLEntrypoint::mix: + return "mix"; + case OpenCLEntrypoint::radians: + return "radians"; + case OpenCLEntrypoint::step: + return "step"; + case OpenCLEntrypoint::smoothstep: + return "smoothstep"; + case OpenCLEntrypoint::sign: + return "sign"; + case OpenCLEntrypoint::cross: + return "cross"; + case OpenCLEntrypoint::distance: + return "distance"; + case OpenCLEntrypoint::length: + return "length"; + case OpenCLEntrypoint::normalize: + return "normalize"; + case OpenCLEntrypoint::fast_distance: + return "fast_distance"; + case OpenCLEntrypoint::fast_length: + return "fast_length"; + case OpenCLEntrypoint::fast_normalize: + return "fast_normalize"; + case OpenCLEntrypoint::bitselect: + return "bitselect"; + case OpenCLEntrypoint::select: + return "select"; + case OpenCLEntrypoint::vloadn: + return "vloadn"; + case OpenCLEntrypoint::vstoren: + return "vstoren"; + case OpenCLEntrypoint::vload_half: + return "vload_half"; + case OpenCLEntrypoint::vload_halfn: + return "vload_halfn"; + case OpenCLEntrypoint::vstore_half: + return "vstore_half"; + case OpenCLEntrypoint::vstore_half_r: + return "vstore_half_r"; + case OpenCLEntrypoint::vstore_halfn: + return "vstore_halfn"; + case OpenCLEntrypoint::vstore_halfn_r: + return "vstore_halfn_r"; + case OpenCLEntrypoint::vloada_halfn: + return "vloada_halfn"; + case OpenCLEntrypoint::vstorea_halfn: + return "vstorea_halfn"; + case OpenCLEntrypoint::vstorea_halfn_r: + return "vstorea_halfn_r"; + case OpenCLEntrypoint::shuffle: + return "shuffle"; + case OpenCLEntrypoint::shuffle2: + return "shuffle2"; + case OpenCLEntrypoint::printf: + return "printf"; + case OpenCLEntrypoint::prefetch: + return "prefetch"; + } + return "unknown"; +} + +} // namespace tinytc::spv diff --git a/src/spv/opencl.std.hpp b/src/spv/opencl.std.hpp new file mode 100644 index 00000000..b662101f --- /dev/null +++ b/src/spv/opencl.std.hpp @@ -0,0 +1,183 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_OPENCL_STD_2024115_HPP +#define GENERATED_OPENCL_STD_2024115_HPP + +namespace tinytc::spv { + +constexpr char const *OpenCLExt = "OpenCL.std"; + +enum class OpenCLEntrypoint { + acos = 0, + acosh = 1, + acospi = 2, + asin = 3, + asinh = 4, + asinpi = 5, + atan = 6, + atan2 = 7, + atanh = 8, + atanpi = 9, + atan2pi = 10, + cbrt = 11, + ceil = 12, + copysign = 13, + cos = 14, + cosh = 15, + cospi = 16, + erfc = 17, + erf = 18, + exp = 19, + exp2 = 20, + exp10 = 21, + expm1 = 22, + fabs = 23, + fdim = 24, + floor = 25, + fma = 26, + fmax = 27, + fmin = 28, + fmod = 29, + fract = 30, + frexp = 31, + hypot = 32, + ilogb = 33, + ldexp = 34, + lgamma = 35, + lgamma_r = 36, + log = 37, + log2 = 38, + log10 = 39, + log1p = 40, + logb = 41, + mad = 42, + maxmag = 43, + minmag = 44, + modf = 45, + nan = 46, + nextafter = 47, + pow = 48, + pown = 49, + powr = 50, + remainder = 51, + remquo = 52, + rint = 53, + rootn = 54, + round = 55, + rsqrt = 56, + sin = 57, + sincos = 58, + sinh = 59, + sinpi = 60, + sqrt = 61, + tan = 62, + tanh = 63, + tanpi = 64, + tgamma = 65, + trunc = 66, + half_cos = 67, + half_divide = 68, + half_exp = 69, + half_exp2 = 70, + half_exp10 = 71, + half_log = 72, + half_log2 = 73, + half_log10 = 74, + half_powr = 75, + half_recip = 76, + half_rsqrt = 77, + half_sin = 78, + half_sqrt = 79, + half_tan = 80, + native_cos = 81, + native_divide = 82, + native_exp = 83, + native_exp2 = 84, + native_exp10 = 85, + native_log = 86, + native_log2 = 87, + native_log10 = 88, + native_powr = 89, + native_recip = 90, + native_rsqrt = 91, + native_sin = 92, + native_sqrt = 93, + native_tan = 94, + s_abs = 141, + s_abs_diff = 142, + s_add_sat = 143, + u_add_sat = 144, + s_hadd = 145, + u_hadd = 146, + s_rhadd = 147, + u_rhadd = 148, + s_clamp = 149, + u_clamp = 150, + clz = 151, + ctz = 152, + s_mad_hi = 153, + u_mad_sat = 154, + s_mad_sat = 155, + s_max = 156, + u_max = 157, + s_min = 158, + u_min = 159, + s_mul_hi = 160, + rotate = 161, + s_sub_sat = 162, + u_sub_sat = 163, + u_upsample = 164, + s_upsample = 165, + popcount = 166, + s_mad24 = 167, + u_mad24 = 168, + s_mul24 = 169, + u_mul24 = 170, + u_abs = 201, + u_abs_diff = 202, + u_mul_hi = 203, + u_mad_hi = 204, + fclamp = 95, + degrees = 96, + fmax_common = 97, + fmin_common = 98, + mix = 99, + radians = 100, + step = 101, + smoothstep = 102, + sign = 103, + cross = 104, + distance = 105, + length = 106, + normalize = 107, + fast_distance = 108, + fast_length = 109, + fast_normalize = 110, + bitselect = 186, + select = 187, + vloadn = 171, + vstoren = 172, + vload_half = 173, + vload_halfn = 174, + vstore_half = 175, + vstore_half_r = 176, + vstore_halfn = 177, + vstore_halfn_r = 178, + vloada_halfn = 179, + vstorea_halfn = 180, + vstorea_halfn_r = 181, + shuffle = 182, + shuffle2 = 183, + printf = 184, + prefetch = 185, +}; + +auto to_string(OpenCLEntrypoint op) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_OPENCL_STD_2024115_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 323dea09..5b23dfb5 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -3,6 +3,7 @@ #include "spv/pass/dump_asm.hpp" #include "spv/module.hpp" +#include "spv/opencl.std.hpp" #include "support/casting.hpp" #include @@ -68,16 +69,6 @@ void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { void dump_asm_pass::operator()(LiteralInteger const &l) { *os_ << " " << l; } void dump_asm_pass::operator()(LiteralString const &l) { *os_ << " \"" << l << '"'; } -void dump_asm_pass::operator()(spv_inst *const &in) { - if (auto s = slot_map_.find(in); s != slot_map_.end()) { - *os_ << " %" << s->second; - } else if (isa(*in)) { - *os_ << " %" << declare(in); - } else { - throw status::spirv_forbidden_forward_declaration; - } -} - void dump_asm_pass::operator()(PairIdRefIdRef const &p) { this->operator()(p.first); this->operator()(p.second); @@ -91,18 +82,50 @@ void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { this->operator()(p.second); } +void dump_asm_pass::operator()(spv_inst *const &in) { + if (auto s = slot_map_.find(in); s != slot_map_.end()) { + *os_ << " %" << s->second; + } else if (isa(*in)) { + *os_ << " %" << declare(in); + } else { + throw status::spirv_forbidden_forward_declaration; + } +} +auto dump_asm_pass::operator()(OpExtInst const &in) { + pre_visit(in); + this->operator()(in.type()); + this->operator()(in.op0()); + + if (auto extimport = dyn_cast(in.op0()); + extimport && extimport->op0() == OpenCLExt) { + this->operator()(static_cast(in.op1())); + } else { + this->operator()(in.op1()); + } + + for (auto const &op : in.op2()) { + this->operator()(op); + } +} + void dump_asm_pass::run_on_module(mod const &m) { auto const visit_section = [&](section s) { for (auto const &i : m.insts(s)) { visit(*this, i); } }; + *os_ << "; SPIR-V" << std::endl + << "; Version " << m.major_version() << '.' << m.minor_version() << std::endl + << "; Generator: Tiny Tensor Compiler" << std::endl + << "; Bound: " << m.bound() << std::endl + << "; Schema: 0"; visit_section(section::capability); + visit_section(section::ext_inst); visit_section(section::memory_model); visit_section(section::entry_point); visit_section(section::execution_mode); visit_section(section::decoration); - visit_section(section::type); + visit_section(section::type_const_var); visit_section(section::function); *os_ << std::endl; } diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp index 238ea0f8..eeb9763e 100644 --- a/src/spv/pass/dump_asm.hpp +++ b/src/spv/pass/dump_asm.hpp @@ -36,11 +36,13 @@ class dump_asm_pass : public default_visitor { void operator()(LiteralInteger const &l); void operator()(LiteralString const &l); - void operator()(spv_inst *const &in); void operator()(PairIdRefIdRef const &p); void operator()(PairIdRefLiteralInteger const &p); void operator()(PairLiteralIntegerIdRef const &p); + void operator()(spv_inst *const &in); + auto operator()(OpExtInst const &in); + void run_on_module(mod const &m); private: diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index 0e3c9c62..5c56e858 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_2024114_HPP -#define GENERATED_VISIT_2024114_HPP +#ifndef GENERATED_VISIT_2024115_HPP +#define GENERATED_VISIT_2024115_HPP namespace tinytc::spv { @@ -2912,4 +2912,4 @@ template class default_visitor { } // namespace tinytc::spv -#endif // GENERATED_VISIT_2024114_HPP +#endif // GENERATED_VISIT_2024115_HPP diff --git a/test/spv/arith.ir b/test/spv/arith.ir new file mode 100644 index 00000000..4028517d --- /dev/null +++ b/test/spv/arith.ir @@ -0,0 +1,91 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpCapability Int64 +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#I64:]] = OpTypeInt 64 1 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 + +func @tbool(%a: i1, %b: i1) { + %0 = arith.and %a, %b : i1 + %1 = arith.or %a, %b : i1 + %2 = arith.xor %a, %b : i1 +; CHECK: %[[#]] = OpLogicalAnd %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpLogicalOr %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpLogicalNotEqual %[[#BOOL]] %[[#]] %[[#]] +} + +func @tint(%a: i64, %b: i64) { + %0 = arith.add %a, %b : i64 + %1 = arith.sub %a, %b : i64 + %2 = arith.mul %a, %b : i64 + %3 = arith.div %a, %b : i64 + %4 = arith.rem %a, %b : i64 + %5 = arith.shl %a, %b : i64 + %6 = arith.shr %a, %b : i64 + %7 = arith.and %a, %b : i64 + %8 = arith.or %a, %b : i64 + %9 = arith.xor %a, %b : i64 +; CHECK: %[[#]] = OpIAdd %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpISub %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpIMul %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSDiv %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSRem %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpShiftLeftLogical %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpShiftRightArithmetic %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpBitwiseAnd %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpBitwiseOr %[[#I64]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpBitwiseXor %[[#I64]] %[[#]] %[[#]] +} + +func @tfloat(%a: f32, %b: f32) { + %0 = arith.add %a, %b : f32 + %1 = arith.sub %a, %b : f32 + %2 = arith.mul %a, %b : f32 + %3 = arith.div %a, %b : f32 + %4 = arith.rem %a, %b : f32 +; CHECK: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFRem %[[#F32]] %[[#]] %[[#]] +} + +func @tcomplex(%a: c32, %b: c32) { + %0 = arith.add %a, %b : c32 + %1 = arith.sub %a, %b : c32 + %2 = arith.mul %a, %b : c32 + %3 = arith.div %a, %b : c32 +; CHECK: %[[#]] = OpFAdd %[[#C32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#C32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#C32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#C32]] %[[#]] %[[#]] +} + +func @tfloatcoopmatrix() subgroup_size(16) { + %0 = constant 1.0 -> coopmatrix + %1 = constant 2.0 -> coopmatrix + %2 = arith.add %0, %1 : coopmatrix + %3 = arith.sub %0, %1 : coopmatrix + %4 = arith.mul %0, %1 : coopmatrix + %5 = arith.div %0, %1 : coopmatrix +; CHECK: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F32]] %[[#]] %[[#]] +} diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir new file mode 100644 index 00000000..eb1f6ac5 --- /dev/null +++ b/test/spv/arith_unary.ir @@ -0,0 +1,60 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpCapability Int64 +; CHECK: %[[#EXT:]] = OpExtInstImport "OpenCL.std" +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#I64:]] = OpTypeInt 64 1 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 + +func @tbool(%a: i1) { + %0 = arith.not %a : i1 +; CHECK: OpLogicalNot %[[#BOOL]] %[[#]] +} + +func @tint(%a: i64) { + %0 = arith.abs %a : i64 + %1 = arith.neg %a : i64 + %2 = arith.not %a : i64 +; CHECK: OpExtInst %[[#I64]] %[[#EXT]] s_abs %[[#]] +; CHECK-NEXT: OpSNegate %[[#I64]] %[[#]] +; CHECK-NEXT: OpNot %[[#I64]] %[[#]] +} + +func @tfloat(%a: f32) { + %0 = arith.abs %a : f32 + %1 = arith.neg %a : f32 +; CHECK: OpExtInst %[[#F32]] %[[#EXT]] fabs %[[#]] +; CHECK-NEXT: OpFNegate %[[#F32]] %[[#]] +} + +func @tcomplex(%a: c32) { + %0 = arith.abs %a : c32 + %1 = arith.neg %a : c32 + %2 = arith.conj %a : c32 + %3 = arith.im %a : c32 + %4 = arith.re %a : c32 +; CHECK: %[[#A2:]] = OpFMul %[[#C32]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#A2_0:]] = OpCompositeExtract %[[#F32]] %[[#A2]] 0 +; CHECK-NEXT: %[[#A2_1:]] = OpCompositeExtract %[[#F32]] %[[#A2]] 1 +; CHECK-NEXT: %[[#A2_0p1:]] = OpFAdd %[[#F32]] %[[#A2_0]] %[[#A2_1]] +; CHECK-NEXT: %[[#]] = OpExtInst %[[#F32]] %[[#EXT]] sqrt %[[#A2_0p1]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#C32]] %[[#]] +; CHECK-NEXT: %[[#A_IM:]] = OpCompositeExtract %[[#F32]] %[[#OPERAND:]] 1 +; CHECK-NEXT: %[[#NEG_A_IM:]] = OpFNegate %[[#F32]] %[[#A_IM]] +; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C32]] %[[#NEG_A_IM]] %[[#OPERAND]] 1 +; CHECK-NEXT: %[[#]] = OpCompositeExtract %[[#F32]] %[[#]] 1 +; CHECK-NEXT: %[[#]] = OpCompositeExtract %[[#F32]] %[[#]] 0 +} + +func @tfloatcoopmatrix() subgroup_size(16) { + %0 = constant 1.0 -> coopmatrix + %2 = arith.neg %0 : coopmatrix +; CHECK: %[[#]] = OpFNegate %[[#F32]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#]] +} diff --git a/test/spv/unique_function_type.ir b/test/spv/unique_function_type.ir new file mode 100644 index 00000000..e6bf62a5 --- /dev/null +++ b/test/spv/unique_function_type.ir @@ -0,0 +1,15 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv < %s | filecheck %s +func @f1() {} +func @f2() {} +func @f3(%a: i32, %b: f32) {} +func @f4(%a: i32, %b: f32) {} +func @f5(%a: f32, %b: f32) {} + +; CHECK: %[[#]] = OpFunction %[[#]] None %[[#TYPE0:]] +; CHECK: %[[#]] = OpFunction %[[#]] None %[[#TYPE0]] +; CHECK: %[[#]] = OpFunction %[[#]] None %[[#TYPE1:]] +; CHECK: %[[#]] = OpFunction %[[#]] None %[[#TYPE1]] +; CHECK: %[[#]] = OpFunction %[[#]] None %[[#]] diff --git a/tools/spirvgen/ext_opencl.py b/tools/spirvgen/ext_opencl.py new file mode 100755 index 00000000..40bff3a4 --- /dev/null +++ b/tools/spirvgen/ext_opencl.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# +# Very simple and stupid script to generate SPIR-V classes +# + +import argparse +import json +import os +import shutil + +from gen import generate_cpp, generate_header + +spv_ext = 'opencl.std.hpp' +spv_ext_name = 'OpenCL.std' +spv_ext_entrypoint = 'OpenCLEntrypoint' +spv_ext_includes = [] +spv_ext_cpp = 'opencl.std.cpp' +spv_ext_cpp_includes = [spv_ext] + + +def generate_enums(f, grammar): + print(f'constexpr char const* OpenCLExt = "{spv_ext_name}";', file=f) + print(file=f) + print(f'enum class {spv_ext_entrypoint} {{', file=f) + for inst in grammar['instructions']: + print(f'{inst["opname"]} = {inst["opcode"]},', file=f) + print('};', file=f) + print(file=f) + print(f'auto to_string({spv_ext_entrypoint} op) -> char const*;', file=f) + + +def generate_enums_cpp(f, grammar): + print( + f'auto to_string({spv_ext_entrypoint} ep) -> char const* {{ switch(ep) {{', + file=f) + for inst in grammar['instructions']: + print( + f'case {spv_ext_entrypoint}::{inst["opname"]}: return "{inst["opname"]}";', + file=f) + print('} return "unknown";}', file=f) + + +if __name__ == '__main__': + script_dir = os.path.dirname(os.path.realpath(__file__)) + parser = argparse.ArgumentParser() + parser.add_argument('-c', + help='clang-format binary', + default='clang-format'), + parser.add_argument('-o', help='output directory', default=''), + parser.add_argument( + 'grammar', + help= + 'extinst.opencl.std.100.grammar.json file from SPIRV-Headers project') + args = parser.parse_args() + + if shutil.which(args.c): + with open(args.grammar) as f: + grammar = json.load(f) + generate_header(args, spv_ext, grammar, generate_enums, + spv_ext_includes) + generate_cpp(args, spv_ext_cpp, grammar, generate_enums_cpp, + spv_ext_cpp_includes) + else: + print(f'Could not find clang-format: {args.c}') diff --git a/tools/spirvgen/gen.py b/tools/spirvgen/gen.py new file mode 100644 index 00000000..ed18ccc3 --- /dev/null +++ b/tools/spirvgen/gen.py @@ -0,0 +1,78 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +import datetime +import os +import subprocess + +def print_includes(f, includes): + for include in includes: + if include: + if include[0] != '<' and include[0] != '"': + print(f'#include "{include}"', file=f) + else: + print(f'#include {include}', file=f) + else: + print('', file=f) + + +def generate_header(args, filename, grammar, generator, includes=[]): + filename = os.path.join(args.o, filename) + with open(filename, 'w') as f: + now = datetime.datetime.now() + basename = os.path.splitext(os.path.basename(filename))[0].upper() + basename = basename.replace('.', '_') + headerguard_name = f'GENERATED_{basename}_{now.year}{now.month}{now.day}_HPP' + + print(f"""// Copyright (C) {now.year} Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef {headerguard_name} +#define {headerguard_name} + +""", + file=f) + print_includes(f, includes) + print(""" +namespace tinytc::spv { +""", file=f) + + generator(f, grammar) + + print(f""" +}} + +#endif // {headerguard_name} +""", file=f) + + subprocess.call([args.c, '-i', filename]) + + +def generate_cpp(args, filename, grammar, generator, includes=[]): + filename = os.path.join(args.o, filename) + with open(filename, 'w') as f: + now = datetime.datetime.now() + print(f"""// Copyright (C) {now.year} Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +""", + file=f) + print_includes(f, includes) + print(""" +namespace tinytc::spv { +""", file=f) + + generator(f, grammar) + + print(f""" +}} +""", file=f) + + subprocess.call([args.c, '-i', filename]) + diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 160aeb21..894d5526 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -7,14 +7,15 @@ # import argparse -import datetime import json import os import shutil -import subprocess + +from gen import generate_cpp, generate_header spv_enums = 'enums.hpp' spv_names = 'names.hpp' +spv_names_includes = [spv_enums] spv_names_cpp = 'names.cpp' spv_names_cpp_includes = [spv_names, spv_enums] spv_ops = 'instructions.hpp' @@ -254,79 +255,6 @@ def generate_visitor(f, grammar): print('};', file=f) -def print_includes(f, includes): - for include in includes: - if include: - if include[0] != '<' and include[0] != '"': - print(f'#include "{include}"', file=f) - else: - print(f'#include {include}', file=f) - else: - print('', file=f) - - -def generate_header(args, filename, grammar, generator, includes=[]): - filename = os.path.join(args.o, filename) - with open(filename, 'w') as f: - now = datetime.datetime.now() - basename = os.path.splitext(os.path.basename(filename))[0].upper() - headerguard_name = f'GENERATED_{basename}_{now.year}{now.month}{now.day}_HPP' - - print(f"""// Copyright (C) {now.year} Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -// This file is generated -// Do not edit manually - -#ifndef {headerguard_name} -#define {headerguard_name} - -""", - file=f) - print_includes(f, includes) - print(""" -namespace tinytc::spv { -""", file=f) - - generator(f, grammar) - - print(f""" -}} - -#endif // {headerguard_name} -""", file=f) - - subprocess.call([args.c, '-i', filename]) - - -def generate_cpp(args, filename, grammar, generator, includes=[]): - filename = os.path.join(args.o, filename) - with open(filename, 'w') as f: - now = datetime.datetime.now() - print(f"""// Copyright (C) {now.year} Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -// This file is generated -// Do not edit manually - -""", - file=f) - print_includes(f, includes) - print(""" -namespace tinytc::spv { -""", file=f) - - generator(f, grammar) - - print(f""" -}} - -#endif // {headerguard_name} -""", file=f) - - subprocess.call([args.c, '-i', filename]) - - def filter_grammar(grammar, filt): filtered_instructions = [] for instruction in grammar['instructions']: @@ -375,9 +303,9 @@ def patch_grammar(grammar): grammar = filter_grammar(grammar, filt) grammar = patch_grammar(grammar) generate_header(args, spv_enums, grammar, generate_enums) - generate_header(args, spv_names, grammar, generate_names) - generate_header(args, spv_names_cpp, grammar, generate_names_cpp, - spv_names_cpp_includes) + generate_header(args, spv_names, grammar, generate_names, spv_names_includes) + generate_cpp(args, spv_names_cpp, grammar, generate_names_cpp, + spv_names_cpp_includes) generate_header(args, spv_ops, grammar, generate_op_classes, spv_ops_includes) generate_header(args, spv_visitor, grammar, generate_visitor) From 80bf7d850b9cae9c550ab2ac028d97419582d2f1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 6 Nov 2024 12:16:12 +0100 Subject: [PATCH 086/297] Make bool a separate non-scalar type Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 49 +++++++++ docs/api/builder_capi.yaml | 2 + docs/api/builder_cxxapi.rst | 42 ++++++++ docs/api/builder_cxxapi.yaml | 2 + docs/api/core_capi.rst | 7 ++ docs/api/core_cxxapi.rst | 7 ++ docs/manual/tensor-ir.rst | 85 ++++++++------- include/tinytc/tinytc.h | 26 +++++ include/tinytc/tinytc.hpp | 34 +++++- include/tinytc/types.h | 87 +++++++-------- include/tinytc/types.hpp | 7 +- src/codegen_tools.cpp | 11 ++ src/codegen_tools.hpp | 1 + src/compiler_context_cache.cpp | 1 + src/compiler_context_cache.hpp | 2 +- src/data_type.cpp | 8 ++ src/error.cpp | 11 +- src/inst.cpp | 25 ++++- src/node/data_type_node.cpp | 4 + src/node/data_type_node.hpp | 16 ++- src/node/inst_node.cpp | 146 ++++++++++++++------------ src/node/inst_node.hpp | 2 +- src/parser/lexer.re | 8 +- src/parser/parser_impl.yy | 24 ++++- src/pass/constant_folding.hpp | 87 +++++++++------ src/pass/convert_to_opencl.cpp | 58 +++++++--- src/pass/convert_to_opencl.hpp | 1 + src/pass/dead_code_elimination.cpp | 4 +- src/pass/dump_ir.cpp | 2 + src/pass/dump_ir.hpp | 1 + src/pass/lower_linalg.cpp | 2 +- src/recipe.cpp | 2 - src/scalar_type.cpp | 7 -- test/codegen/if.ir | 2 +- test/codegen/scalar_arithmetic.ir | 8 +- test/generator.cpp | 11 -- test/opt/constant-propagation-safe.ir | 29 +++++ test/opt/constant-propagation.ir | 70 ++++++------ test/opt/dead-code-elimination.ir | 8 +- test/spv/arith.ir | 8 +- test/spv/arith_unary.ir | 4 +- 41 files changed, 627 insertions(+), 284 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 5a50545c..233af067 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -16,6 +16,8 @@ Common * :ref:`tinytc_arithmetic_unary_t` + * :ref:`tinytc_checked_flag_t` + * :ref:`tinytc_cmp_condition_t` * :ref:`tinytc_matrix_use_t` @@ -26,6 +28,8 @@ Common * :ref:`tinytc_transpose_t` + * :ref:`tinytc_work_group_operation_t` + * Definitions * :ref:`TINYTC_DYNAMIC` @@ -38,6 +42,8 @@ Common * :ref:`tinytc_arithmetic_unary_to_string` + * :ref:`tinytc_checked_flag_to_string` + * :ref:`tinytc_cmp_condition_to_string` * :ref:`tinytc_matrix_use_to_string` @@ -50,6 +56,8 @@ Common * :ref:`tinytc_transpose_to_string` + * :ref:`tinytc_work_group_operation_to_string` + * Structures * :ref:`tinytc_position` @@ -104,6 +112,11 @@ tinytc_arithmetic_unary_t .. doxygenenum:: tinytc_arithmetic_unary_t +tinytc_checked_flag_t +..................... + +.. doxygenenum:: tinytc_checked_flag_t + tinytc_cmp_condition_t ...................... @@ -129,6 +142,11 @@ tinytc_transpose_t .. doxygenenum:: tinytc_transpose_t +tinytc_work_group_operation_t +............................. + +.. doxygenenum:: tinytc_work_group_operation_t + Common Definitions ------------------ @@ -155,6 +173,11 @@ tinytc_arithmetic_unary_to_string .. doxygenfunction:: tinytc_arithmetic_unary_to_string +tinytc_checked_flag_to_string +............................. + +.. doxygenfunction:: tinytc_checked_flag_to_string + tinytc_cmp_condition_to_string .............................. @@ -185,6 +208,11 @@ tinytc_transpose_to_string .. doxygenfunction:: tinytc_transpose_to_string +tinytc_work_group_operation_to_string +..................................... + +.. doxygenfunction:: tinytc_work_group_operation_to_string + Common Structures ----------------- @@ -276,6 +304,8 @@ Data Type * Functions + * :ref:`tinytc_boolean_type_get` + * :ref:`tinytc_coopmatrix_type_get` * :ref:`tinytc_group_type_get` @@ -287,6 +317,11 @@ Data Type Data Type Functions ------------------- +tinytc_boolean_type_get +....................... + +.. doxygenfunction:: tinytc_boolean_type_get + tinytc_coopmatrix_type_get .......................... @@ -367,6 +402,8 @@ Instruction * :ref:`tinytc_cmp_inst_create` + * :ref:`tinytc_constant_inst_create_boolean` + * :ref:`tinytc_constant_inst_create_complex` * :ref:`tinytc_constant_inst_create_float` @@ -427,6 +464,8 @@ Instruction * :ref:`tinytc_sum_inst_create` + * :ref:`tinytc_work_group_inst_create` + * :ref:`tinytc_yield_inst_create` * :ref:`tinytc_inst_get_regions` @@ -468,6 +507,11 @@ tinytc_cmp_inst_create .. doxygenfunction:: tinytc_cmp_inst_create +tinytc_constant_inst_create_boolean +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_boolean + tinytc_constant_inst_create_complex ................................... @@ -618,6 +662,11 @@ tinytc_sum_inst_create .. doxygenfunction:: tinytc_sum_inst_create +tinytc_work_group_inst_create +............................. + +.. doxygenfunction:: tinytc_work_group_inst_create + tinytc_yield_inst_create ........................ diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 95f35ff7..50abd572 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -47,6 +47,7 @@ Builder C-API: - const_tinytc_value_t Data Type: function: + - tinytc_boolean_type_get - tinytc_coopmatrix_type_get - tinytc_group_type_get - tinytc_memref_type_get @@ -66,6 +67,7 @@ Builder C-API: - tinytc_arith_unary_inst_create - tinytc_cast_inst_create - tinytc_cmp_inst_create + - tinytc_constant_inst_create_boolean - tinytc_constant_inst_create_complex - tinytc_constant_inst_create_float - tinytc_constant_inst_create_int diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 9c10938f..2187115a 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -26,6 +26,8 @@ Common * :ref:`transpose` + * :ref:`work_group_operation` + * Functions * :ref:`is_dynamic_value` @@ -36,6 +38,8 @@ Common * :ref:`to_string(arithmetic_unary)` + * :ref:`to_string(checked_flag)` + * :ref:`to_string(cmp_condition)` * :ref:`to_string(matrix_use)` @@ -46,6 +50,8 @@ Common * :ref:`to_string(transpose)` + * :ref:`to_string(work_group_operation)` + * :ref:`size` * Classes @@ -105,6 +111,11 @@ transpose .. doxygenenum:: tinytc::transpose +work_group_operation +.................... + +.. doxygenenum:: tinytc::work_group_operation + Common Functions ---------------- @@ -128,6 +139,11 @@ to_string(arithmetic_unary) .. doxygenfunction:: tinytc::to_string(arithmetic_unary) +to_string(checked_flag) +....................... + +.. doxygenfunction:: tinytc::to_string(checked_flag) + to_string(cmp_condition) ........................ @@ -153,6 +169,11 @@ to_string(transpose) .. doxygenfunction:: tinytc::to_string(transpose) +to_string(work_group_operation) +............................... + +.. doxygenfunction:: tinytc::to_string(work_group_operation) + size .... @@ -192,6 +213,8 @@ Data Type * Functions + * :ref:`get_boolean` + * :ref:`get_coopmatrix` * :ref:`get_group` @@ -215,6 +238,11 @@ Data Type Data Type Functions ------------------- +get_boolean +........... + +.. doxygenfunction:: tinytc::get_boolean + get_coopmatrix .............. @@ -303,6 +331,8 @@ Instruction * :ref:`make_cmp` + * :ref:`make_constant(bool,data_type,location const&)` + * :ref:`make_constant(std::complex\,data_type,location const&)` * :ref:`make_constant(double,data_type,location const&)` @@ -365,6 +395,8 @@ Instruction * :ref:`make_sum` + * :ref:`make_work_group` + * :ref:`make_yield` * Classes @@ -404,6 +436,11 @@ make_cmp .. doxygenfunction:: tinytc::make_cmp +make_constant(bool,data_type,location const&) +............................................. + +.. doxygenfunction:: tinytc::make_constant(bool,data_type,location const&) + make_constant(std::complex,data_type,location const&) ............................................................. @@ -559,6 +596,11 @@ make_sum .. doxygenfunction:: tinytc::make_sum +make_work_group +............... + +.. doxygenfunction:: tinytc::make_work_group + make_yield .......... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index 34db5874..c79be440 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -34,6 +34,7 @@ Builder C++-API: - tinytc::dynamic Data Type: function: + - tinytc::get_boolean - tinytc::get_coopmatrix - tinytc::get_group - tinytc::get_memref @@ -57,6 +58,7 @@ Builder C++-API: - tinytc::make_arith(arithmetic_unary,value,location const&) - tinytc::make_cast - tinytc::make_cmp + - tinytc::make_constant(bool,data_type,location const&) - tinytc::make_constant(std::complex,data_type,location const&) - tinytc::make_constant(double,data_type,location const&) - tinytc::make_constant(std::int32_t,data_type,location const&) diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 8579e308..fe81e9cc 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -263,6 +263,8 @@ Compiler * :ref:`tinytc_prog_compile_to_opencl` + * :ref:`tinytc_prog_compile_to_spirv` + Compiler Enumerations --------------------- @@ -294,6 +296,11 @@ tinytc_prog_compile_to_opencl .. doxygenfunction:: tinytc_prog_compile_to_opencl +tinytc_prog_compile_to_spirv +............................ + +.. doxygenfunction:: tinytc_prog_compile_to_spirv + Compiler Context ================ diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 72513bc9..5038b0af 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -162,6 +162,8 @@ Compiler * :ref:`compile_to_opencl` + * :ref:`compile_to_spirv` + Compiler Functions ------------------ @@ -180,6 +182,11 @@ compile_to_opencl .. doxygenfunction:: tinytc::compile_to_opencl +compile_to_spirv +................ + +.. doxygenfunction:: tinytc::compile_to_spirv + Compiler Context ================ diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index cf544d01..1d892bb7 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -68,8 +68,9 @@ Constants .. code:: abnf - constant = complex-constant / floating-constant / integer-constant - integer-constant = "true" / "false" / [sign] 1*DIGIT + constant = boolean-constant / integer-constant / floating-constant / complex-constant + boolean-constant = "true" / "false" + integer-constant = [sign] 1*DIGIT sign = "-" / "+" floating-constant = [sign] *DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] mantissa-dec = *DIGIT "." 1*DIGIT / 1*DIGIT "." @@ -142,16 +143,25 @@ Types .. code:: abnf - type = void-type / scalar-type / memref-type / group-type + type = void-type / boolean-type / scalar-type / memref-type / group-type void-type = "void" +Boolean type +------------ + +.. code:: abnf + + boolean-type = "bool" + +Boolean type that only has two states (true or false). + Scalar types ------------ .. code:: abnf scalar-type = integer-type / floating-type / complex-type - integer-type = "i" ("1" / "8" / "16" / "32" / "64") / "index" + integer-type = "i" ("8" / "16" / "32" / "64") / "index" floating-type = "f" ("32" / "64") complex-type = "c" ("32" / "64") @@ -162,7 +172,7 @@ e.g. "f64" are double precision floating point numbers. The "index" type is an integer type whose width is platform-specific. Scalar types are ordered as -:math:`i1 \prec \text{i8} \prec \text{i16} \prec \text{i32} \prec \text{i64} \prec \text{f32} \prec \text{f64} \prec \text{c32} \prec \text{c64}`. +:math:`\text{i8} \prec \text{i16} \prec \text{i32} \prec \text{i64} \prec \text{f32} \prec \text{f64} \prec \text{c32} \prec \text{c64}`. A scalar type :math:`\alpha` is called *compatible to* a scalar type :math:`\beta` if :math:`\alpha \preceq \beta`. If an arithmetic operation involves mixed types :math:`\alpha` and :math:`\beta` and @@ -289,8 +299,6 @@ The supported matrix shapes may depend on data type, matrix use, and target hard An argument to any instruction that has coopmatrix type **must** be dynamically uniform. -Having i1 as component type of a coopmatrix is forbidden. - Instructions ============ @@ -639,7 +647,7 @@ Arithmetic (binary) ".or" / ".xor" value-instruction =/ "arith" arith-binary-type local-identifier "," local-identifier - ":" (scalar-type / coopmatrix-type) + ":" (boolean-type / scalar-type / coopmatrix-type) Overview ~~~~~~~~ @@ -650,22 +658,21 @@ Arithmetic on cooperative matrices is done component-wise. The following table shows the operations' description and the types that are allowed for the operation. The backslash "\\" is used to exclude types from the list of allowed types. -Boolean arithmetic is only allowed for .and, .or, and .xor. - -==== ============================= ========== ====================================================== -Op Allowed type i1 allowed Description -==== ============================= ========== ====================================================== -.add scalar-type / coopmatrix-type No Sum of operands -.sub scalar-type / coopmatrix-type No Difference of operands -.mul scalar-type / coopmatrix-type No Product of operands -.div scalar-type / coopmatrix-type No Quotient of operands -.rem scalar-type \\ complex-type No Remainder from the division of operands -.shl integer-type No Left shift first operand by second operand -.shr integer-type No Arithmetic right shift first operand by second operand -.and integer-type Yes Bitwise and -.or integer-type Yes Bitwise or -.xor integer-type Yes Bitwise xor -==== ============================= ========== ====================================================== + +==== ============================= ====================================================== +Op Allowed type Description +==== ============================= ====================================================== +.add scalar-type / coopmatrix-type Sum of operands +.sub scalar-type / coopmatrix-type Difference of operands +.mul scalar-type / coopmatrix-type Product of operands +.div scalar-type / coopmatrix-type Quotient of operands +.rem scalar-type \\ complex-type Remainder from the division of operands +.shl integer-type Left shift first operand by second operand +.shr integer-type Arithmetic right shift first operand by second operand +.and boolean-type / integer-type Bitwise and +.or boolean-type / integer-type Bitwise or +.xor boolean-type / integer-type Bitwise xor +==== ============================= ====================================================== Arithmetic (unary) .................. @@ -686,18 +693,17 @@ for ".abs", ".im", and ".re", and the returned value has the same type as the op for ".neg" and ".conj". The following table shows the operations' description and the types that are allowed for the operation. -Boolean arithmetic is only allowed for .neg. - -===== ============================= ========== ============================= -Op Allowed type i1 allowed Description -===== ============================= ========== ============================= -.abs scalar-type No Compute absolute value -.neg scalar-type / coopmatrix-type No Negation -.not integer-type Yes Bitwise not -.conj complex-type No Complex conjugate -.im complex-type No Extract imaginary part -.re complex-type No Extract real part -===== ============================= ========== ============================= + +===== ============================= ============================= +Op Allowed type Description +===== ============================= ============================= +.abs scalar-type Compute absolute value +.neg scalar-type / coopmatrix-type Negation +.not boolean-type / integer-type Bitwise not +.conj complex-type Complex conjugate +.im complex-type Extract imaginary part +.re complex-type Extract real part +===== ============================= ============================= Barrier ....... @@ -740,6 +746,7 @@ Overview Cast scalar values or cooperative matrices. The shape and the use the coopmatrix types must match. + Casts from complex types to non-complex types are forbidden. The following table summarizes the casts and the mapping to SPIR-V (the casts are done component-wise for coopmatrix types): @@ -770,7 +777,7 @@ Overview ~~~~~~~~ Scalar comparison. -Both operands must have the same scalar type and the returned value is boolean. +Both operands must have the same scalar type and the returned value has boolean type. The following table shows the comparisons' description and the types that are allowed for the comparison. The backslash "\\" is used to exclude types from the list of allowed types. @@ -791,7 +798,7 @@ Constant .. code:: abnf - value-instruction =/ "constant" constant "->" (scalar-type / coopmatrix-type) + value-instruction =/ "constant" constant "->" (boolean-type / scalar-type / coopmatrix-type) Overview ~~~~~~~~ @@ -1165,7 +1172,7 @@ Overview An if statement. Both regions are *mixed regions*. -The condition must be of bool type. +The condition must have boolean type. Returns ~~~~~~~ diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 3714318d..aaeca1d7 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -49,6 +49,17 @@ TINYTC_EXPORT size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty); ///////// Data type //////// //////////////////////////// +/** + * @brief Get boolean data type + * + * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_boolean_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx); + /** * @brief Get scalar data type * @@ -251,6 +262,21 @@ TINYTC_EXPORT tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_value_t b, const tinytc_location_t *loc); +/** + * @brief Create boolean constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, + tinytc_bool_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + /** * @brief Create complex constant instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 8d4f331c..6ce71fb2 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -81,10 +81,6 @@ inline std::size_t size(scalar_type ty) { */ template struct to_scalar_type; //! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i1; ///< value -}; -//! to_scalar_type specialization template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::i8; ///< value }; @@ -548,6 +544,21 @@ inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } //! Alias for tinytc_data_type_t using data_type = tinytc_data_type_t; +/** + * @brief Get the boolean data type + * + * Cf. \ref tinytc_boolean_type_get + * + * @param ctx Compiler context + * + * @return Data type + */ +inline data_type get_boolean(compiler_context const &ctx) { + tinytc_data_type_t bt; + CHECK_STATUS(tinytc_boolean_type_get(&bt, ctx.get())); + return bt; +} + /** * @brief Get a scalar data type * @@ -912,6 +923,21 @@ inline inst make_cmp(cmp_condition cond, value a, value b, location const &loc = return inst(instr); } +/** + * @brief Make boolean constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(bool value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_boolean(&instr, value, ty, &loc), loc); + return inst(instr); +} + /** * @brief Make complex constant * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 81f9d50a..a67d2223 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -49,39 +49,43 @@ typedef enum { tinytc_status_ir_shape_stride_mismatch = 0x103, ///< Mismatch of shape and stride tinytc_status_ir_scalar_mismatch = 0x104, ///< Mismatch of scalar types tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices - tinytc_status_ir_expected_scalar = 0x106, ///< Expected a value of scalar type - tinytc_status_ir_expected_index = 0x107, ///< Expected a value of index type - tinytc_status_ir_expected_coopmatrix = 0x108, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_boolean = 0x106, ///< Expected a value of boolean type + tinytc_status_ir_expected_scalar = 0x107, ///< Expected a value of scalar type + tinytc_status_ir_expected_index = 0x108, ///< Expected a value of index type + tinytc_status_ir_expected_coopmatrix = 0x109, ///< Expected a value of coopmatrix type tinytc_status_ir_expected_coopmatrix_or_scalar = - 0x109, ///< Expected a value of coopmatrix or scalar type - tinytc_status_ir_expected_memref = 0x10a, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x10b, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x10c, ///< Expected a value of memref or group type - tinytc_status_ir_expected_matrix = 0x10d, ///< Expected a marix - tinytc_status_ir_expected_vector_or_matrix = 0x10e, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x10f, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x110, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x111, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x112, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x113, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x114, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x115, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x116, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x117, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x118, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x119, ///< Expected global address space - tinytc_status_ir_invalid_offset = 0x11a, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x11b, ///< Instruction does not support int type - tinytc_status_ir_i1_unsupported = 0x11c, ///< Instruction does not support i1 type - tinytc_status_ir_complex_unsupported = 0x11d, ///< Instruction does not support complex type + 0x10a, ///< Expected a value of coopmatrix or scalar type + tinytc_status_ir_expected_coopmatrix_scalar_or_boolean = + 0x10b, ///< Expected a value of coopmatrix, scalar type, or boolean + tinytc_status_ir_expected_memref = 0x10c, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x10d, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x10e, ///< Expected a value of memref or group type + tinytc_status_ir_expected_matrix = 0x10f, ///< Expected a marix + tinytc_status_ir_expected_vector_or_matrix = 0x110, ///< Expected a vector or marix + tinytc_status_ir_unexpected_yield = 0x111, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x112, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x113, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x114, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x115, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x116, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x117, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x118, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x119, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x11a, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x11b, ///< Expected global address space + tinytc_status_ir_invalid_offset = 0x11c, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x11d, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x11e, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x11f, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x11e, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x11f, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x120, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x121, ///< Mismatch of init values and returned values - tinytc_status_ir_invalid_matrix_use = 0x122, ///< Invalid matrix use - tinytc_status_ir_unsupported_coopmatrix_shape = 0x123, ///< Unsupported coopmatrix shape - tinytc_status_ir_incompatible_scalar_types = 0x124, ///< Incompatible scalar types + 0x120, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x121, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x122, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x123, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x124, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x125, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x126, ///< Incompatible scalar types + tinytc_status_ir_constant_mismatch = 0x127, ///< Constant mismatch // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden @@ -239,18 +243,17 @@ typedef enum { //! Scalar types typedef enum { - tinytc_scalar_type_i1 = 0, ///< Signed 1 bit integer (boolean) - tinytc_scalar_type_i8 = 1, ///< Signed 8 bit integer - tinytc_scalar_type_i16 = 2, ///< Signed 16 bit integer - tinytc_scalar_type_i32 = 3, ///< Signed 32 bit integer - tinytc_scalar_type_i64 = 4, ///< Signed 64 bit integer - tinytc_scalar_type_index = 5, ///< Integer type for indices - tinytc_scalar_type_f32 = 6, ///< Single precision floating point (32 bit) - tinytc_scalar_type_f64 = 7, ///< Double precision floating point (64 bit) - tinytc_scalar_type_c32 = 8, ///< Single precision complex (2x32 bit) - tinytc_scalar_type_c64 = 9 ///< Double precision complex (2x64 bit) + tinytc_scalar_type_i8 = 0, ///< Signed 8 bit integer + tinytc_scalar_type_i16 = 1, ///< Signed 16 bit integer + tinytc_scalar_type_i32 = 2, ///< Signed 32 bit integer + tinytc_scalar_type_i64 = 3, ///< Signed 64 bit integer + tinytc_scalar_type_index = 4, ///< Integer type for indices + tinytc_scalar_type_f32 = 5, ///< Single precision floating point (32 bit) + tinytc_scalar_type_f64 = 6, ///< Double precision floating point (64 bit) + tinytc_scalar_type_c32 = 7, ///< Single precision complex (2x32 bit) + tinytc_scalar_type_c64 = 8 ///< Double precision complex (2x64 bit) } tinytc_scalar_type_t; -#define TINYTC_NUMBER_OF_SCALAR_TYPES 10 // @todo Keep up to date with tinytc_scalar_type_t +#define TINYTC_NUMBER_OF_SCALAR_TYPES 9 // @todo Keep up to date with tinytc_scalar_type_t //! Arithmetic operations typedef enum { diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 6640d065..f03f1ec0 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -59,10 +59,13 @@ enum class status { ir_shape_stride_mismatch = tinytc_status_ir_shape_stride_mismatch, ir_scalar_mismatch = tinytc_status_ir_scalar_mismatch, ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, + ir_expected_boolean = tinytc_status_ir_expected_boolean, ir_expected_scalar = tinytc_status_ir_expected_scalar, ir_expected_index = tinytc_status_ir_expected_index, ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, + ir_expected_coopmatrix_scalar_or_boolean = + tinytc_status_ir_expected_coopmatrix_scalar_or_boolean, ir_expected_memref = tinytc_status_ir_expected_memref, ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, @@ -81,7 +84,7 @@ enum class status { ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, ir_invalid_offset = tinytc_status_ir_invalid_offset, ir_int_unsupported = tinytc_status_ir_int_unsupported, - ir_i1_unsupported = tinytc_status_ir_i1_unsupported, + ir_boolean_unsupported = tinytc_status_ir_boolean_unsupported, ir_complex_unsupported = tinytc_status_ir_complex_unsupported, ir_coopmatrix_unsupported = tinytc_status_ir_coopmatrix_unsupported, ir_forbidden_cast = tinytc_status_ir_forbidden_cast, @@ -90,6 +93,7 @@ enum class status { ir_invalid_matrix_use = tinytc_status_ir_invalid_matrix_use, ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, + ir_constant_mismatch = tinytc_status_ir_constant_mismatch, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, ze_result_not_ready = tinytc_status_ze_result_not_ready, @@ -219,7 +223,6 @@ enum class status { //! Scalar types enum class scalar_type { - i1 = tinytc_scalar_type_i1, ///< Signed 1 bit integer (boolean) i8 = tinytc_scalar_type_i8, ///< Signed 8 bit integer i16 = tinytc_scalar_type_i16, ///< Signed 16 bit integer i32 = tinytc_scalar_type_i32, ///< Signed 32 bit integer diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 7a32425f..5a3b626f 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -680,6 +680,17 @@ auto instant_constant_fold_add(region_builder &bb, inst i) -> value { return bb.add(std::move(i)); } +auto get_bool_constant(tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto *ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci->value())) { + return std::get(ci->value()); + } + } + } + return std::nullopt; +} + auto get_int_constant(tinytc_value_t val) -> std::optional { if (auto i = val->defining_inst(); i) { if (auto *ci = dyn_cast(i); ci) { diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index f8a87770..26255e27 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -154,6 +154,7 @@ void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value b array_view index_list, location const &loc); auto instant_constant_fold_add(region_builder &bb, inst i) -> value; +auto get_bool_constant(tinytc_value_t val) -> std::optional; auto get_int_constant(tinytc_value_t val) -> std::optional; } // namespace tinytc diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp index 0f0447ab..e6f2c369 100644 --- a/src/compiler_context_cache.cpp +++ b/src/compiler_context_cache.cpp @@ -9,6 +9,7 @@ namespace tinytc { compiler_context_cache::compiler_context_cache(tinytc_compiler_context_t ctx) { + bool_ty = std::unique_ptr(new boolean_data_type(ctx)); void_ty = std::unique_ptr(new void_data_type(ctx)); for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { scalar_tys[i] = diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp index 4d61c22c..dc426506 100644 --- a/src/compiler_context_cache.hpp +++ b/src/compiler_context_cache.hpp @@ -35,7 +35,7 @@ class compiler_context_cache { compiler_context_cache(compiler_context_cache const &) = delete; compiler_context_cache &operator=(compiler_context_cache const &) = delete; - std::unique_ptr void_ty; + std::unique_ptr void_ty, bool_ty; std::array, TINYTC_NUMBER_OF_SCALAR_TYPES> scalar_tys; std::unordered_multimap memref_tys; std::unordered_multimap coopmatrix_tys; diff --git a/src/data_type.cpp b/src/data_type.cpp index 3b9058cd..c41e8d94 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -29,6 +29,14 @@ char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u) { return "unknown"; } +tinytc_status_t tinytc_boolean_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx) { + if (dt == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + + return exception_to_status_code([&] { *dt = boolean_data_type::get(ctx); }); +} + tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, tinytc_scalar_type_t type) { if (dt == nullptr || ctx == nullptr) { diff --git a/src/error.cpp b/src/error.cpp index 6524963f..e9535fd7 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -132,6 +132,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Scalar type mismatch"; case tinytc_status_ir_invalid_number_of_indices: return "Number of indices must match memref order or must be 1 for group types"; + case tinytc_status_ir_expected_boolean: + return "Expected boolean type"; case tinytc_status_ir_expected_scalar: return "Expected scalar type"; case tinytc_status_ir_expected_index: @@ -140,6 +142,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Expected coopmatrix type"; case tinytc_status_ir_expected_coopmatrix_or_scalar: return "Expected coopmatrix type or scalar type"; + case tinytc_status_ir_expected_coopmatrix_scalar_or_boolean: + return "Expected coopmatrix type, scalar type, or boolean type"; case tinytc_status_ir_expected_memref: return "Expected memref type"; case tinytc_status_ir_expected_memref_or_scalar: @@ -177,8 +181,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Offset must be non-negative or dynamic"; case tinytc_status_ir_int_unsupported: return "int type unsupported by instruction"; - case tinytc_status_ir_i1_unsupported: - return "i1 type unsupported by instruction"; + case tinytc_status_ir_boolean_unsupported: + return "boolean type unsupported by instruction"; case tinytc_status_ir_complex_unsupported: return "complex type unsupported by instruction"; case tinytc_status_ir_coopmatrix_unsupported: @@ -196,6 +200,9 @@ char const *tinytc_error_string(tinytc_status_t status) { "target architecture"; case tinytc_status_ir_incompatible_scalar_types: return "Scalar types violate compatibility rules"; + case tinytc_status_ir_constant_mismatch: + return "Type of constant does not match type of returned value"; + // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; case tinytc_status_spirv_undefined_value: diff --git a/src/inst.cpp b/src/inst.cpp index 9101168a..2260c573 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -188,6 +188,17 @@ tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_conditio }); } +tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, tinytc_bool_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(value != 0, ty, get_optional(loc)).release(); + }); +} + tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, double value_re, double value_im, tinytc_data_type_t ty, const tinytc_location_t *loc) { @@ -227,6 +238,12 @@ tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_dat return tinytc_status_invalid_arguments; } + if (const auto *bt = dyn_cast(ty); bt != nullptr) { + return exception_to_status_code([&] { + *instr = std::make_unique(true, ty, get_optional(loc)).release(); + }); + } + scalar_type sty; if (const auto *st = dyn_cast(ty); st != nullptr) { sty = st->ty(); @@ -238,7 +255,6 @@ tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_dat return exception_to_status_code([&] { switch (sty) { - case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: @@ -266,6 +282,12 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da return tinytc_status_invalid_arguments; } + if (const auto *bt = dyn_cast(ty); bt != nullptr) { + return exception_to_status_code([&] { + *instr = std::make_unique(false, ty, get_optional(loc)).release(); + }); + } + scalar_type sty; if (const auto *st = dyn_cast(ty); st != nullptr) { sty = st->ty(); @@ -277,7 +299,6 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da return exception_to_status_code([&] { switch (sty) { - case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index f6e1bec3..bdc8d28e 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -18,6 +18,10 @@ namespace tinytc { +auto boolean_data_type::get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t { + return ctx->cache()->bool_ty.get(); +} + auto coopmatrix_data_type::get(tinytc_data_type_t ty, std::int64_t rows, std::int64_t cols, matrix_use use, location const &lc) -> tinytc_data_type_t { auto ctx = ty->context(); diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index ae0095cb..024aadb9 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -15,10 +15,10 @@ #include namespace tinytc { -enum class DTK { coopmatrix, group, memref, scalar, void_ }; +enum class DTK { bool_, coopmatrix, group, memref, scalar, void_ }; using data_type_nodes = - type_list; + type_list; } // namespace tinytc struct tinytc_data_type { @@ -40,6 +40,16 @@ namespace tinytc { using data_type_node = ::tinytc_data_type; +class boolean_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::bool_; } + static auto get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t; + + protected: + inline boolean_data_type(tinytc_compiler_context_t ctx) : data_type_node(DTK::bool_, ctx) {} + friend class compiler_context_cache; +}; + class coopmatrix_data_type : public data_type_node { public: inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::coopmatrix; } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 7cf3a4a7..c3c854ae 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -234,7 +234,21 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b op(op_b, b0); loc(lc); - if (isa(*a().ty())) { + if (isa(*a().ty())) { + auto const inst_supports_bool = [&] { + switch (operation) { + case arithmetic::and_: + case arithmetic::or_: + case arithmetic::xor_: + return true; + default: + return false; + } + }(); + if (!inst_supports_bool) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + } else if (isa(*a().ty())) { if (!isa(*b().ty())) { throw compilation_error(loc(), status::ir_expected_coopmatrix); } @@ -259,7 +273,6 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b if (a_ty != b_ty) { throw compilation_error(loc(), status::ir_scalar_mismatch); } - bool inst_supports_i1 = true; bool inst_supports_fp = true; bool inst_supports_complex = true; switch (operation) { @@ -267,10 +280,8 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b case arithmetic::sub: case arithmetic::mul: case arithmetic::div: - inst_supports_i1 = false; break; case arithmetic::rem: - inst_supports_i1 = false; inst_supports_complex = false; break; case arithmetic::and_: @@ -281,14 +292,10 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b break; case arithmetic::shl: case arithmetic::shr: - inst_supports_i1 = false; inst_supports_fp = false; inst_supports_complex = false; break; } - if (!inst_supports_i1 && a_ty == scalar_type::i1) { - throw compilation_error(loc(), status::ir_i1_unsupported); - } if (!inst_supports_fp && is_floating_type(a_ty)) { throw compilation_error(loc(), status::ir_fp_unsupported); } @@ -306,60 +313,60 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 op(op_a, a0); loc(lc); - tinytc_data_type_t to_ty = nullptr; - - if (isa(*a().ty())) { - if (operation_ != arithmetic_unary::neg) { - throw compilation_error(loc(), status::ir_coopmatrix_unsupported); - } - to_ty = a().ty(); - } else { - auto a_ty = get_scalar_type(loc(), a()); - to_ty = a_ty; - - bool inst_supports_i1 = true; - bool inst_supports_int = true; - bool inst_supports_fp = true; - bool inst_supports_complex = true; - switch (operation_) { - case arithmetic_unary::abs: - case arithmetic_unary::neg: - inst_supports_i1 = false; - break; - case arithmetic_unary::not_: - inst_supports_fp = false; - inst_supports_complex = false; - break; - case arithmetic_unary::conj: - case arithmetic_unary::im: - case arithmetic_unary::re: - inst_supports_i1 = false; - inst_supports_int = false; - inst_supports_fp = false; - break; + tinytc_data_type_t to_ty = [&]() -> tinytc_data_type_t { + if (isa(*a().ty())) { + if (operation_ != arithmetic_unary::not_) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + return a().ty(); + } else if (isa(*a().ty())) { + if (operation_ != arithmetic_unary::neg) { + throw compilation_error(loc(), status::ir_coopmatrix_unsupported); + } + return a().ty(); + } else { + auto a_ty = get_scalar_type(loc(), a()); + tinytc_data_type_t to_ty = a_ty; + + bool inst_supports_int = true; + bool inst_supports_fp = true; + bool inst_supports_complex = true; + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::neg: + break; + case arithmetic_unary::not_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic_unary::conj: + case arithmetic_unary::im: + case arithmetic_unary::re: + inst_supports_int = false; + inst_supports_fp = false; + break; + } + if (!inst_supports_int && is_integer_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_int_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(a_ty->ty())) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::im: + case arithmetic_unary::re: + to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); + break; + default: + break; + } + return to_ty; } - if (!inst_supports_i1 && a_ty->ty() == scalar_type::i1) { - throw compilation_error(loc(), status::ir_i1_unsupported); - } - if (!inst_supports_int && is_integer_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_int_unsupported); - } - if (!inst_supports_fp && is_floating_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - if (!inst_supports_complex && is_complex_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_complex_unsupported); - } - switch (operation_) { - case arithmetic_unary::abs: - case arithmetic_unary::im: - case arithmetic_unary::re: - to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); - break; - default: - break; - } - } + }(); result(0) = value_node{to_ty, this, lc}; } @@ -428,7 +435,7 @@ compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t throw compilation_error(loc(), status::ir_complex_unsupported); } - auto result_ty = scalar_data_type::get(at->context(), scalar_type::i1); + auto result_ty = boolean_data_type::get(at->context()); result(0) = value_node{result_ty, this, lc}; } @@ -442,16 +449,20 @@ constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, loc (is_complex_type(ty) && std::holds_alternative>(val)); }; - if (auto st = dyn_cast(ty); st) { + if (auto bt = dyn_cast(ty); bt) { + if (!std::holds_alternative(value_)) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else if (auto st = dyn_cast(ty); st) { if (!type_ok(value_, st->ty())) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), status::ir_constant_mismatch); } } else if (auto ct = dyn_cast(ty); ct) { if (!type_ok(value_, ct->component_ty())) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), status::ir_constant_mismatch); } } else { - throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } result(0) = value_node{ty, this, lc}; @@ -826,6 +837,9 @@ if_inst::if_inst(tinytc_value_t condition, array_view return : standard_inst{IK::if_, 1, static_cast(return_types.size())} { op(0, condition); loc(lc); + if (!isa(*condition->ty())) { + throw compilation_error(loc(), status::ir_expected_boolean); + } for (std::size_t i = 0; i < return_types.size(); ++i) { if (!isa(*return_types[i]) && !isa(*return_types[i])) { diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 0840c99d..63040706 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -452,7 +452,7 @@ class compare_inst : public standard_inst<2, 1> { class constant_inst : public standard_inst<0, 1> { public: - using value_type = std::variant>; + using value_type = std::variant>; inline static bool classof(inst_node const &i) { return i.type_id() == IK::constant; } constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc = {}); diff --git a/src/parser/lexer.re b/src/parser/lexer.re index fdb80368..c77e2991 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -41,7 +41,7 @@ lex: local_named_identifier = "%" named_identifier; global_identifier = "@" (unnamed_identifier | named_identifier); - integer_type = "i" ("1" | "8" | "16" | "32" | "64") | "index"; + integer_type = "i" ("8" | "16" | "32" | "64") | "index"; floating_type = ("f" | "c") ("32" | "64"); digit = [0-9]; @@ -103,8 +103,8 @@ lex: ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } // constants - "true" { adv_loc(); return parser::make_INTEGER_CONSTANT(1, loc_); } - "false" { adv_loc(); return parser::make_INTEGER_CONSTANT(0, loc_); } + "true" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(true, loc_); } + "false" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(false, loc_); } integer_constant { adv_loc(); auto i = lex_integer_constant(b, YYCURSOR); @@ -117,6 +117,7 @@ lex: } // types + "bool" { return parser::make_BOOLEAN(loc_); } integer_type { adv_loc(); auto t = lex_integer_type(b, YYCURSOR); @@ -280,7 +281,6 @@ scalar_type lexer::lex_integer_type(char const *s, char const *) { re2c:yyfill:enable = 0; re2c:define:YYCURSOR = s; - "i1" { return scalar_type::i1; } "i8" { return scalar_type::i8; } "i16" { return scalar_type::i16; } "i32" { return scalar_type::i32; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 733bd4d7..e4186854 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -101,6 +101,7 @@ GLOBAL "global" LOCAL_ATTR ".local" GLOBAL_ATTR ".global" + BOOLEAN "bool" COOPMATRIX "coopmatrix" MEMREF "memref" GROUP "group" @@ -143,6 +144,7 @@ ; %token > LOCAL_IDENTIFIER %token GLOBAL_IDENTIFIER +%token BOOLEAN_CONSTANT %token INTEGER_CONSTANT %token FLOATING_CONSTANT %token INTEGER_TYPE @@ -162,6 +164,7 @@ %nterm >> attributes %nterm > attribute %nterm data_type +%nterm boolean_type %nterm scalar_type %nterm coopmatrix_type %nterm memref_type @@ -330,10 +333,15 @@ attribute: data_type: - scalar_type + boolean_type | coopmatrix_type - | memref_type | group_type + | memref_type + | scalar_type +; + +boolean_type: + BOOLEAN { $$ = get_boolean(ctx.cctx()); } ; scalar_type: @@ -905,6 +913,16 @@ constant_inst: YYERROR; } } + | CONSTANT BOOLEAN_CONSTANT RETURNS data_type { + try { + $$ = inst { + std::make_unique($BOOLEAN_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + error(e.loc(), e.what()); + YYERROR; + } + } ; cooperative_matrix_load_inst: @@ -1109,7 +1127,7 @@ group_size_inst: if_inst: IF var[condition] optional_returned_values { - check_type($condition, get_scalar(ctx.cctx(), scalar_type::i1), @condition, @condition); + check_type($condition, get_boolean(ctx.cctx()), @condition, @condition); try { auto loc = @IF; loc.end = @optional_returned_values.end; diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp index 6eb25f5f..69d303dd 100644 --- a/src/pass/constant_folding.hpp +++ b/src/pass/constant_folding.hpp @@ -53,27 +53,31 @@ struct compute_unary_op { data_type ty; location const &loc; + auto operator()(bool a) -> fold_result { + bool val = false; + switch (operation) { + case arithmetic_unary::not_: + val = !a; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return make_constant(val, ty, loc); + } + template - requires(std::is_integral_v) + requires(std::is_integral_v && !std::is_same_v) auto operator()(T a) -> fold_result { T val = 0; switch (operation) { case arithmetic_unary::abs: - if constexpr (std::is_same_v) { - val = a; - } else { - val = a < 0 ? -a : a; - } + val = a < 0 ? -a : a; break; case arithmetic_unary::neg: val = -a; break; case arithmetic_unary::not_: - if constexpr (std::is_same_v) { - val = !a; - } else { - val = ~a; - } + val = ~a; break; default: throw compilation_error(loc, status::ir_int_unsupported); @@ -156,8 +160,26 @@ struct compute_binary_op { data_type ty; location const &loc; + auto operator()(bool a, bool b) -> fold_result { + bool val = false; + switch (operation) { + case arithmetic::and_: + val = a && b; + break; + case arithmetic::or_: + val = a || b; + break; + case arithmetic::xor_: + val = a != b; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return make_constant(val, ty, loc); + } + template - requires(std::is_integral_v) + requires(std::is_integral_v && !std::is_same_v) auto operator()(T a, T b) -> fold_result { T val = 0; switch (operation) { @@ -168,11 +190,7 @@ struct compute_binary_op { val = a - b; break; case arithmetic::mul: - if constexpr (std::is_same_v) { - val = a && b; - } else { - val = a * b; - } + val = a * b; break; case arithmetic::div: val = a / b; @@ -181,18 +199,10 @@ struct compute_binary_op { val = a % b; break; case arithmetic::shl: - if constexpr (std::is_same_v) { - throw compilation_error(loc, status::ir_i1_unsupported); - } else { - val = a << b; - } + val = a << b; break; case arithmetic::shr: - if constexpr (std::is_same_v) { - throw compilation_error(loc, status::ir_i1_unsupported); - } else { - val = a >> b; - } + val = a >> b; break; case arithmetic::and_: val = a & b; @@ -251,8 +261,27 @@ struct compute_binop_identities { bool is_second_operand; location const &loc; + auto operator()(bool a) -> fold_result { + switch (operation) { + case arithmetic::and_: + if (!a) { + return make_constant(false, operand.ty(), loc); + } + break; + case arithmetic::or_: + case arithmetic::xor_: + if (!a) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + template - requires(std::is_integral_v) + requires(std::is_integral_v && !std::is_same_v) auto operator()(T a) -> fold_result { switch (operation) { case arithmetic::add: @@ -424,8 +453,6 @@ template auto value_cast(U const &u) { return value_cas template auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> fold_result { switch (to_ty->ty()) { - case scalar_type::i1: - return make_constant(value_cast(A), to_ty, loc); case scalar_type::i8: return make_constant(value_cast(A), to_ty, loc); case scalar_type::i16: diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index db55b17d..9e354ab8 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -176,6 +176,9 @@ auto convert_to_opencl_pass::get_scalar_type(value_node const &v) -> scalar_type clir::data_type convert_to_opencl_pass::operator()(void_data_type const &) { return clir::builtin_type::void_t; } +clir::data_type convert_to_opencl_pass::operator()(boolean_data_type const &) { + return clir::builtin_type::bool_t; +} clir::data_type convert_to_opencl_pass::operator()(coopmatrix_data_type const &ct) { return array_of(to_clir_ty(ct.component_ty()), ct.length(core_cfg_.subgroup_size)); } @@ -330,6 +333,18 @@ std::vector convert_to_opencl_pass::operator()(barrier_inst const &b } std::vector convert_to_opencl_pass::operator()(arith_inst const &a) { + auto const make_boolean = [](arithmetic op, clir::expr a, clir::expr b) -> clir::expr { + switch (op) { + case arithmetic::and_: + return std::move(a) && std::move(b); + case arithmetic::or_: + return std::move(a) || std::move(b); + case arithmetic::xor_: + return std::move(a) != std::move(b); + default: + return nullptr; + } + }; auto const make = [](arithmetic op, clir::expr a, clir::expr b, scalar_type sty) -> clir::expr { switch (op) { case arithmetic::add: @@ -350,14 +365,8 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) case arithmetic::shr: return std::move(a) >> std::move(b); case arithmetic::and_: - if (sty == scalar_type::i1) { - return std::move(a) && std::move(b); - } return std::move(a) & std::move(b); case arithmetic::or_: - if (sty == scalar_type::i1) { - return std::move(a) || std::move(b); - } return std::move(a) | std::move(b); case arithmetic::xor_: return std::move(a) ^ std::move(b); @@ -369,7 +378,13 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) auto lhs_ty = visit(*this, *a.result()->ty()); auto av = val(a.a()); auto bv = val(a.b()); - if (auto st = dyn_cast(a.result(0).ty()); st) { + if (isa(*a.result(0).ty())) { + auto op = make_boolean(a.operation(), av, bv); + if (!bool(op)) { + throw compilation_error(a.loc(), status::ir_boolean_unsupported); + } + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; + } else if (auto st = dyn_cast(a.result(0).ty()); st) { auto op = make(a.operation(), av, bv, st->ty()); return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; } else if (auto ct = dyn_cast(a.result(0).ty()); ct) { @@ -384,7 +399,7 @@ std::vector convert_to_opencl_pass::operator()(arith_inst const &a) } return clinst; } - throw compilation_error(a.loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { @@ -402,9 +417,6 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons case arithmetic_unary::neg: return -std::move(a); case arithmetic_unary::not_: - if (sty == scalar_type::i1) { - return !std::move(a); - } return ~std::move(a); case arithmetic_unary::conj: return clir::init_vector(to_clir_ty(sty), {a.s(0), -a.s(1)}); @@ -419,7 +431,12 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons auto lhs = declare(a.result(0)); auto lhs_ty = visit(*this, *a.result()->ty()); auto av = val(a.a()); - if (auto st = dyn_cast(a.a().ty()); st) { + if (isa(*a.result(0).ty())) { + if (a.operation() != arithmetic_unary::not_) { + throw compilation_error(a.loc(), status::ir_boolean_unsupported); + } + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), !std::move(av))}; + } else if (auto st = dyn_cast(a.a().ty()); st) { auto op = make(a.operation(), av, st->ty()); return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; } else if (auto ct = dyn_cast(a.a().ty()); ct) { @@ -434,7 +451,7 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons } return clinst; } - throw compilation_error(a.loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { @@ -509,6 +526,9 @@ std::vector convert_to_opencl_pass::operator()(compare_inst const &c std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { auto const get_rhs = [&c](scalar_type ty, short ty_bits) { return std::visit(overloaded{ + [&](bool) -> clir::expr { + throw compilation_error(c.loc(), status::internal_compiler_error); + }, [&](std::int64_t i) { return clir::expr(i, ty_bits); }, [&](double d) { return clir::expr(d, ty_bits); }, [&](std::complex d) { @@ -521,7 +541,13 @@ std::vector convert_to_opencl_pass::operator()(constant_inst const & }; auto lhs = declare(c.result(0)); auto lhs_ty = visit(*this, *c.result()->ty()); - if (auto st = dyn_cast(c.result(0).ty()); st) { + if (isa(*c.result(0).ty())) { + if (!std::holds_alternative(c.value())) { + throw compilation_error(c.loc(), status::internal_compiler_error); + } + return {declaration_assignment(std::move(lhs_ty), std::move(lhs), + clir::expr(std::int8_t{std::get(c.value())}))}; + } else if (auto st = dyn_cast(c.result(0).ty()); st) { auto ty_bits = static_cast(size(st->ty()) * 8); return { declaration_assignment(std::move(lhs_ty), std::move(lhs), get_rhs(st->ty(), ty_bits))}; @@ -537,7 +563,7 @@ std::vector convert_to_opencl_pass::operator()(constant_inst const & } return clinst; } - throw compilation_error(c.loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(c.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } std::vector convert_to_opencl_pass::operator()(cooperative_matrix_load_inst const &c) { @@ -581,7 +607,7 @@ std::vector convert_to_opencl_pass::operator()(cooperative_matrix_lo auto row_in_bounds = clir::var{}; if (check_m) { auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; - clinst.emplace_back(declaration_assignment(to_clir_ty(scalar_type::i1), row_in_bounds, + clinst.emplace_back(declaration_assignment(clir::builtin_type::bool_t, row_in_bounds, m >= -pv[omode] && m < rem[omode])); } for (std::int64_t k = 0; k < rt->shape(1 - rmode); ++k) { diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index f109f44e..7ac30813 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -61,6 +61,7 @@ class convert_to_opencl_pass { /* Data type nodes */ clir::data_type operator()(void_data_type const &); + clir::data_type operator()(boolean_data_type const &ct); clir::data_type operator()(coopmatrix_data_type const &ct); clir::data_type operator()(group_data_type const &g); clir::data_type operator()(memref_data_type const &m); diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp index d0d51bd6..739ced1e 100644 --- a/src/pass/dead_code_elimination.cpp +++ b/src/pass/dead_code_elimination.cpp @@ -44,8 +44,8 @@ auto dead_code_analysis::operator()(if_inst &in) -> bool { constant_inst *cond_const = dyn_cast(in.condition().defining_inst()); if (in.num_results() == 0 && cond_const) { // If-instruction is dead if condition is constant and false - return std::holds_alternative(cond_const->value()) && - std::get(cond_const->value()) == 0; + return std::holds_alternative(cond_const->value()) && + std::get(cond_const->value()) == false; } return false; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 85d359cb..7a92b58d 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -23,6 +23,7 @@ dump_ir_pass::dump_ir_pass(std::ostream &os, int level_limit) : os_(&os), lvl_li /* Data type nodes */ void dump_ir_pass::operator()(void_data_type const &) { *os_ << "void"; } +void dump_ir_pass::operator()(boolean_data_type const &) { *os_ << "bool"; } void dump_ir_pass::operator()(coopmatrix_data_type const &ct) { *os_ << "coopmatrix<"; visit(*this, *ct.ty()); @@ -179,6 +180,7 @@ void dump_ir_pass::operator()(constant_inst const &c) { dump_val(c.result(0)); *os_ << " = constant "; std::visit(overloaded{ + [&](bool b) { *os_ << (b ? "true" : "false"); }, [&](std::int64_t i) { if (is_dynamic_value(i)) { *os_ << "?"; diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index 4e6dcf18..627059ca 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -23,6 +23,7 @@ class dump_ir_pass { /* Data type nodes */ void operator()(void_data_type const &); + void operator()(boolean_data_type const &); void operator()(coopmatrix_data_type const &ct); void operator()(group_data_type const &g); void operator()(memref_data_type const &m); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index f006f132..eef9e37b 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -96,7 +96,7 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto K0 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); c_init = compute_c(bb, k_block_size, c_zero, K0, c_init); auto needs_remainder = instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, loc)); - auto r = get_int_constant(needs_remainder); + auto r = get_bool_constant(needs_remainder); if (r) { if (*r != 0) { c_init = compute_c(bb, 1, K0, K, c_init); diff --git a/src/recipe.cpp b/src/recipe.cpp index 989745f2..222e9e27 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -37,8 +37,6 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va return is_argument_zero>(arg_size, arg_value); case scalar_type::c64: return is_argument_zero>(arg_size, arg_value); - case scalar_type::i1: - break; }; throw status::invalid_arguments; } diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index 6ee734cc..121d43fd 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -36,7 +36,6 @@ bool is_complex_type(scalar_type ty) { bool is_integer_type(scalar_type ty) { switch (ty) { - case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: @@ -74,8 +73,6 @@ clir::data_type to_clir_ty(scalar_type ty, short size, clir::address_space as, clir::type_qualifier q) { const auto base_type = [](scalar_type ty) { switch (ty) { - case scalar_type::i1: - return clir::builtin_type::bool_t; case scalar_type::i8: return clir::builtin_type::char_t; case scalar_type::i16: @@ -97,7 +94,6 @@ clir::data_type to_clir_ty(scalar_type ty, short size, clir::address_space as, }; const auto components = [](scalar_type ty) -> short { switch (ty) { - case scalar_type::i1: case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: @@ -154,8 +150,6 @@ clir::address_space to_clir_address_space(address_space as) { char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { switch (ty) { - case tinytc_scalar_type_i1: - return "i1"; case tinytc_scalar_type_i8: return "i8"; case tinytc_scalar_type_i16: @@ -179,7 +173,6 @@ char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { } size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { switch (ty) { - case tinytc_scalar_type_i1: case tinytc_scalar_type_i8: return 1; case tinytc_scalar_type_i16: diff --git a/test/codegen/if.ir b/test/codegen/if.ir index 2efba14f..dcd0b1ad 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -7,7 +7,7 @@ func @if0(%0: i32) { %c0 = constant 0 -> i32 %1 = cmp.lt %0, %c16 : i32 %2 = cmp.ge %0, %c0 : i32 - %3 = arith.and %1, %2 : i1 + %3 = arith.and %1, %2 : bool if %3 { } else { } diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir index 1198b34b..bc9ee44a 100644 --- a/test/codegen/scalar_arithmetic.ir +++ b/test/codegen/scalar_arithmetic.ir @@ -2,7 +2,7 @@ ; SPDX-License-Identifier: BSD-3-Clause ; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { +func @t1(%a: i32, %b: i32, %a1: bool, %b1: bool) { %1 = arith.add %a, %b : i32 %2 = arith.sub %a, %b : i32 %3 = arith.mul %a, %b : i32 @@ -11,13 +11,13 @@ func @t1(%a: i32, %b: i32, %a1: i1, %b1: i1) { %6 = arith.shl %a, %b : i32 %7 = arith.shr %a, %b : i32 %8 = arith.and %a, %b : i32 - %9 = arith.and %a1, %b1 : i1 + %9 = arith.and %a1, %b1 : bool %10 = arith.or %a, %b : i32 - %11 = arith.or %a1, %b1 : i1 + %11 = arith.or %a1, %b1 : bool %12 = arith.xor %a, %b : i32 %13 = arith.neg %a : i32 %14 = arith.not %a : i32 - %15 = arith.not %a1 : i1 + %15 = arith.not %a1 : bool %16 = arith.abs %a : i32 ; CHECK: int x = a + b; ; CHECK-NEXT: int x1 = a - b; diff --git a/test/generator.cpp b/test/generator.cpp index b0be7a2b..a0705c73 100644 --- a/test/generator.cpp +++ b/test/generator.cpp @@ -108,17 +108,6 @@ TEST_CASE("compatible scalar type") { } } - CHECK(compatible_type(scalar_type::i1, scalar_type::i1) == scalar_type::i1); - CHECK(compatible_type(scalar_type::i1, scalar_type::i8) == scalar_type::i8); - CHECK(compatible_type(scalar_type::i1, scalar_type::i16) == scalar_type::i16); - CHECK(compatible_type(scalar_type::i1, scalar_type::i32) == scalar_type::i32); - CHECK(compatible_type(scalar_type::i1, scalar_type::i64) == scalar_type::i64); - CHECK(compatible_type(scalar_type::i1, scalar_type::index) == scalar_type::index); - CHECK(compatible_type(scalar_type::i1, scalar_type::f32) == scalar_type::f32); - CHECK(compatible_type(scalar_type::i1, scalar_type::f64) == scalar_type::f64); - CHECK(compatible_type(scalar_type::i1, scalar_type::c32) == scalar_type::c32); - CHECK(compatible_type(scalar_type::i1, scalar_type::c64) == scalar_type::c64); - CHECK(compatible_type(scalar_type::i8, scalar_type::i8) == scalar_type::i8); CHECK(compatible_type(scalar_type::i8, scalar_type::i16) == scalar_type::i16); CHECK(compatible_type(scalar_type::i8, scalar_type::i32) == scalar_type::i32); diff --git a/test/opt/constant-propagation-safe.ir b/test/opt/constant-propagation-safe.ir index 50447cd7..1de79f15 100644 --- a/test/opt/constant-propagation-safe.ir +++ b/test/opt/constant-propagation-safe.ir @@ -108,3 +108,32 @@ func @identity_ixor(%a: i32) { ; CHECK-LABEL: func @identity_ixor({{.*}} ; CHECK: %2 = arith.add %a, %a : i32 } + +func @identity_band(%a: bool) { + %c0 = constant false -> bool + %0 = arith.and %a, %c0 : bool + %1 = arith.and %c0, %a : bool + %2 = arith.and %0, %1 : bool +; CHECK-LABEL: func @identity_band({{.*}} +; CHECK: %0 = constant false -> bool +; CHECK: %2 = constant false -> bool +; CHECK: %5 = arith.and %0, %2 : bool +} + +func @identity_bor(%a: bool) { + %c0 = constant false -> bool + %0 = arith.or %a, %c0 : bool + %1 = arith.or %c0, %a : bool + %2 = arith.and %0, %1 : bool +; CHECK-LABEL: func @identity_bor({{.*}} +; CHECK: %2 = arith.and %a, %a : bool +} + +func @identity_bxor(%a: bool) { + %c0 = constant false -> bool + %0 = arith.xor %a, %c0 : bool + %1 = arith.xor %c0, %a : bool + %2 = arith.and %0, %1 : bool +; CHECK-LABEL: func @identity_bxor({{.*}} +; CHECK: %2 = arith.and %a, %a : bool +} diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index fbd3adc8..a03c6d7a 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -56,24 +56,38 @@ func @known_loop_iter_args() { func @known_arith() { %0 = constant 1 -> i64 %1 = constant 2 -> i64 - %3 = constant -2.0 -> f32 - %4 = constant [1.0, -1.0] -> c32 - %5 = arith.not %0 : i64 - %6 = arith.add %0, %1 : i64 - %7 = arith.neg %3 : f32 - %8 = arith.add %4, %4 : c32 - %9 = arith.abs %3 : f32 + %2 = constant -2.0 -> f32 + %3 = constant [1.0, -1.0] -> c32 + %4 = constant false -> bool + %5 = constant true -> bool + %6 = arith.not %0 : i64 + %7 = arith.add %0, %1 : i64 + %8 = arith.neg %2 : f32 + %9 = arith.add %3, %3 : c32 + %10 = arith.abs %2 : f32 + %11 = arith.and %4, %5 : bool + %12 = arith.or %4, %5 : bool + %13 = arith.xor %5, %5 : bool + %14 = arith.not %4 : bool ; CHECK-LABEL: func @known_arith({{.*}} -; CHECK: %4 = constant -2 -> i64 -; CHECK-NEXT: %5 = arith.not %0 : i64 -; CHECK-NEXT: %6 = constant 3 -> i64 -; CHECK-NEXT: %7 = arith.add %0, %1 : i64 -; CHECK-NEXT: %8 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %9 = arith.neg %2 : f32 -; CHECK-NEXT: %10 = constant [0x1p+1,-0x1p+1] -> c32 -; CHECK-NEXT: %11 = arith.add %3, %3 : c32 -; CHECK-NEXT: %12 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %13 = arith.abs %2 : f32 +; CHECK: %6 = constant -2 -> i64 +; CHECK-NEXT: %7 = arith.not %0 : i64 +; CHECK-NEXT: %8 = constant 3 -> i64 +; CHECK-NEXT: %9 = arith.add %0, %1 : i64 +; CHECK-NEXT: %10 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %11 = arith.neg %2 : f32 +; CHECK-NEXT: %12 = constant [0x1p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %13 = arith.add %3, %3 : c32 +; CHECK-NEXT: %14 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %15 = arith.abs %2 : f32 +; CHECK-NEXT: %16 = constant false -> bool +; CHECK-NEXT: %17 = arith.and %4, %5 : bool +; CHECK-NEXT: %18 = constant true -> bool +; CHECK-NEXT: %19 = arith.or %4, %5 : bool +; CHECK-NEXT: %20 = constant false -> bool +; CHECK-NEXT: %21 = arith.xor %5, %5 : bool +; CHECK-NEXT: %22 = constant true -> bool +; CHECK-NEXT: %23 = arith.not %4 : bool } func @known_cast() { @@ -82,10 +96,8 @@ func @known_cast() { %0 = cast %c0 : i32 -> i16 %1 = cast %c0 : i32 -> f32 %2 = cast %c0 : i32 -> c32 - %3 = cast %c0 : i32 -> i1 - %4 = cast %c0 : i32 -> c32 - %5 = cast %c1 : c32 -> c64 - %6 = cast %3 : i1 -> c32 + %3 = cast %c0 : i32 -> c32 + %4 = cast %c1 : c32 -> c64 ; CHECK-LABEL: func @known_cast({{.*}} ; CHECK: %0 = constant -32768 -> i16 ; CHECK-NEXT: %1 = cast %c0 : i32 -> i16 @@ -93,14 +105,10 @@ func @known_cast() { ; CHECK-NEXT: %3 = cast %c0 : i32 -> f32 ; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] -> c32 ; CHECK-NEXT: %5 = cast %c0 : i32 -> c32 -; CHECK-NEXT: %6 = constant 1 -> i1 -; CHECK-NEXT: %7 = cast %c0 : i32 -> i1 -; CHECK-NEXT: %8 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %9 = cast %c0 : i32 -> c32 -; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] -> c64 -; CHECK-NEXT: %11 = cast %c1 : c32 -> c64 -; CHECK-NEXT: %12 = constant [0x1p+0,0x0p+0] -> c32 -; CHECK-NEXT: %13 = cast %6 : i1 -> c32 +; CHECK-NEXT: %6 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %7 = cast %c0 : i32 -> c32 +; CHECK-NEXT: %8 = constant [0x1.8p+1,-0x1p+1] -> c64 +; CHECK-NEXT: %9 = cast %c1 : c32 -> c64 } func @known_compare() { @@ -109,9 +117,9 @@ func @known_compare() { %2 = cmp.eq %0, %0 : f32 %3 = cmp.eq %0, %1 : f32 ; CHECK-LABEL: func @known_compare({{.*}} -; CHECK: %2 = constant 1 -> i1 +; CHECK: %2 = constant true -> bool ; CHECK-NEXT: %3 = cmp.eq %0, %0 : f32 -; CHECK-NEXT: %4 = constant 0 -> i1 +; CHECK-NEXT: %4 = constant false -> bool ; CHECK-NEXT: %5 = cmp.eq %0, %1 : f32 } diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index a7e66f66..7ce5d69d 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -3,18 +3,18 @@ ; RUN: %tinytc-opt -pdead-code-elimination < %s | filecheck %s func @dead_if(%a: memref) { - %c0 = constant 0 -> i1 + %c0 = constant false -> bool if %c0 { %c42 = constant 42.0 -> f64 store %c42, %a[] : memref } - %c1 = constant 1 -> i1 + %c1 = constant true -> bool if %c1 { %c43 = constant 43.0 -> f64 store %c43, %a[] : memref } ; CHECK-LABEL: func @dead_if({{.*}} -; CHECK-NEXT: %c1 = constant 1 -> i1 +; CHECK-NEXT: %c1 = constant true -> bool ; CHECK-NEXT: if %c1 { ; CHECK-NEXT: %c43{{.*}} ; CHECK-NEXT: store{{.*}} @@ -22,7 +22,7 @@ func @dead_if(%a: memref) { } func @dead_if_with_yield(%a: memref) { - %c0 = constant 0 -> i1 + %c0 = constant false -> bool %0 = if %c0 -> (f64) { %c42 = constant 42.0 -> f64 yield %c42 : f64 diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 4028517d..0385843e 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -9,10 +9,10 @@ ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 -func @tbool(%a: i1, %b: i1) { - %0 = arith.and %a, %b : i1 - %1 = arith.or %a, %b : i1 - %2 = arith.xor %a, %b : i1 +func @tbool(%a: bool, %b: bool) { + %0 = arith.and %a, %b : bool + %1 = arith.or %a, %b : bool + %2 = arith.xor %a, %b : bool ; CHECK: %[[#]] = OpLogicalAnd %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpLogicalOr %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpLogicalNotEqual %[[#BOOL]] %[[#]] %[[#]] diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index eb1f6ac5..cc84fe0a 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -10,8 +10,8 @@ ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 -func @tbool(%a: i1) { - %0 = arith.not %a : i1 +func @tbool(%a: bool) { + %0 = arith.not %a : bool ; CHECK: OpLogicalNot %[[#BOOL]] %[[#]] } From 44f782ddeb7eb064d4072420424f2216d89c684e Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 6 Nov 2024 12:48:07 +0100 Subject: [PATCH 087/297] SPIR-V: implement cast Signed-off-by: Carsten Uphoff --- src/pass/constant_folding.cpp | 35 ++++++-- src/pass/convert_to_spirv.cpp | 165 +++++++++++++++++++++++++++++----- test/spv/cast.ir | 50 +++++++++++ 3 files changed, 225 insertions(+), 25 deletions(-) create mode 100644 test/spv/cast.ir diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp index 6602b943..1c453c12 100644 --- a/src/pass/constant_folding.cpp +++ b/src/pass/constant_folding.cpp @@ -27,10 +27,11 @@ template class unary_op_dispatcher { unary_op_dispatcher(scalar_type sw_ty, F &&f) : switch_ty{sw_ty}, computer{std::forward(f)} {} + auto operator()(bool const &) -> fold_result { + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + } auto operator()(std::int64_t const &A) -> fold_result { switch (switch_ty) { - case scalar_type::i1: - return computer.template operator()(A); case scalar_type::i8: return computer.template operator()(A); case scalar_type::i16: @@ -81,8 +82,6 @@ template class binary_op_dispatcher { auto operator()(std::int64_t const &A, std::int64_t const &B) -> fold_result { switch (switch_ty) { - case scalar_type::i1: - return computer.template operator()(A, B); case scalar_type::i8: return computer.template operator()(A, B); case scalar_type::i16: @@ -144,6 +143,24 @@ auto constant_folding::operator()(arith_inst &in) -> fold_result { constant_inst *a_const = dyn_cast(op_a.defining_inst()); constant_inst *b_const = dyn_cast(op_b.defining_inst()); + if (isa(*op_a.ty())) { + if ((a_const && !std::holds_alternative(a_const->value())) || + (b_const && !std::holds_alternative(b_const->value()))) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + if (a_const != nullptr && b_const != nullptr) { + return compute_binary_op{in.operation(), op_a.ty(), in.loc()}( + std::get(a_const->value()), std::get(b_const->value())); + } else if (a_const != nullptr) { + return compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, + in.loc()}(std::get(a_const->value())); + } else if (b_const != nullptr) { + return compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, + in.loc()}(std::get(b_const->value())); + } + return tinytc_value_t{}; + } + auto at = dyn_cast(op_a.ty()); if (at == nullptr) { // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all @@ -151,7 +168,7 @@ auto constant_folding::operator()(arith_inst &in) -> fold_result { // constant folding on scalar types. auto ct = dyn_cast(op_a.ty()); if (ct == nullptr) { - throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } at = dyn_cast(ct->ty()); } @@ -182,6 +199,14 @@ auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { return tinytc_value_t{}; } + if (isa(*op_a.ty())) { + if (!std::holds_alternative(a_const->value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + return compute_unary_op{in.operation(), op_a.ty(), + in.loc()}(std::get(a_const->value())); + } + auto at = dyn_cast(op_a.ty()); if (at == nullptr) { // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 9450ef70..231a443a 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -31,11 +31,14 @@ class spirv_converter { void operator()(inst_node const &in); void operator()(arith_inst const &in); void operator()(arith_unary_inst const &in); + void operator()(cast_inst const &in); void operator()(constant_inst const &in); void run_on_program(program_node const &p); private: + auto get_scalar_type(value_node const &v) -> scalar_type; + auto get_coopmatrix_type(value_node const &v) -> scalar_type; auto declare(value_node const &v, spv::spv_inst *in); auto val(value_node const &v) -> spv::spv_inst *; auto multi_declare(value_node const &v, std::vector insts); @@ -65,6 +68,21 @@ class spirv_converter { core_config core_cfg_ = {}; }; +auto spirv_converter::get_scalar_type(value_node const &v) -> scalar_type { + auto st = dyn_cast(v.ty()); + if (!st) { + throw compilation_error(v.loc(), status::ir_expected_scalar); + } + return st->ty(); +} +auto spirv_converter::get_coopmatrix_type(value_node const &v) -> scalar_type { + auto ct = dyn_cast(v.ty()); + if (!ct) { + throw compilation_error(v.loc(), status::ir_expected_coopmatrix); + } + return ct->component_ty(); +} + auto spirv_converter::declare(value_node const &v, spv::spv_inst *in) { vals_[&v] = in; } auto spirv_converter::val(value_node const &v) -> spv::spv_inst * { if (auto it = vals_.find(&v); it != vals_.end()) { @@ -109,10 +127,11 @@ auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { [&](void_data_type const &) -> spv::spv_inst * { return add_to(); }, + [&](boolean_data_type const &) -> spv::spv_inst * { + return add_to(); + }, [&](scalar_data_type const &ty) -> spv::spv_inst * { switch (ty.ty()) { - case scalar_type::i1: - return add_to(); case scalar_type::i8: capabilities_.insert(spv::Capability::Int8); return add_to(8, 1); @@ -180,7 +199,7 @@ void spirv_converter::operator()(arith_inst const &in) { default: break; } - throw compilation_error(in.loc(), status::ir_i1_unsupported); + throw compilation_error(in.loc(), status::ir_boolean_unsupported); }; auto const make_int = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { @@ -229,8 +248,6 @@ void spirv_converter::operator()(arith_inst const &in) { auto const make = [&](scalar_type sty, arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { switch (sty) { - case scalar_type::i1: - return make_boolean(op, ty, a, b); case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: @@ -248,7 +265,11 @@ void spirv_converter::operator()(arith_inst const &in) { auto ty = visit(*this, *in.result(0).ty()); - if (auto st = dyn_cast(in.result(0).ty()); st) { + if (isa(*in.result(0).ty())) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make_boolean(in.operation(), ty, av, bv)); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { auto av = val(in.a()); auto bv = val(in.b()); declare(in.result(0), make(st->ty(), in.operation(), ty, av, bv)); @@ -278,7 +299,7 @@ void spirv_converter::operator()(arith_unary_inst const &in) { default: break; } - throw compilation_error(in.loc(), status::ir_i1_unsupported); + throw compilation_error(in.loc(), status::ir_boolean_unsupported); }; auto const make_int = [&](arithmetic_unary op, spv::spv_inst *ty, spv::spv_inst *a) -> spv::spv_inst * { @@ -345,8 +366,6 @@ void spirv_converter::operator()(arith_unary_inst const &in) { auto const make = [&](scalar_type sty, arithmetic_unary op, spv::spv_inst *ty, spv::spv_inst *a) -> spv::spv_inst * { switch (sty) { - case scalar_type::i1: - return make_boolean(op, ty, a); case scalar_type::i8: case scalar_type::i16: case scalar_type::i32: @@ -365,8 +384,10 @@ void spirv_converter::operator()(arith_unary_inst const &in) { }; auto ty = visit(*this, *in.result(0).ty()); - - if (auto st = dyn_cast(in.a().ty()); st) { + if (isa(*in.a().ty())) { + auto av = val(in.a()); + declare(in.result(0), make_boolean(in.operation(), ty, av)); + } else if (auto st = dyn_cast(in.a().ty()); st) { auto av = val(in.a()); declare(in.result(0), make(st->ty(), in.operation(), ty, av)); } else if (auto ct = dyn_cast(in.a().ty()); ct) { @@ -385,15 +406,109 @@ void spirv_converter::operator()(arith_unary_inst const &in) { } } +void spirv_converter::operator()(cast_inst const &in) { + auto const cast_from_int = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); + auto c0 = add(spv_to_ty); + auto re = add(spv_float_ty, a); + return add(spv_to_ty, re, c0, + std::vector{0}); + } + } + throw compilation_error(in.loc(), status::ir_forbidden_cast); + }; + auto const cast_from_float = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); + auto c0 = add(spv_to_ty); + auto re = add(spv_float_ty, a); + return add(spv_to_ty, re, c0, + std::vector{0}); + } + } + throw compilation_error(in.loc(), status::ir_forbidden_cast); + }; + auto const cast_from_complex = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (to_ty) { + case scalar_type::c32: + case scalar_type::c64: + return add(spv_to_ty, a); + default: + throw compilation_error(in.loc(), status::ir_forbidden_cast); + } + }; + auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv::spv_inst *spv_to_ty, + spv::spv_inst *a) -> spv::spv_inst * { + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return cast_from_int(to_ty, spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return cast_from_float(to_ty, spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return cast_from_complex(to_ty, spv_to_ty, a); + } + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto spv_to_ty = visit(*this, *in.result(0).ty()); + + if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make(st->ty(), a_ty, spv_to_ty, av)); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + auto a_ty = get_coopmatrix_type(in.a()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), a_ty, spv_to_ty, av[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + void spirv_converter::operator()(constant_inst const &in) { auto const make = [&](scalar_type sty, spv::spv_inst *spv_ty, constant_inst::value_type const &val) -> spv::spv_inst * { - auto const add_constant_bool = [this, &spv_ty](bool val) -> spv::spv_inst * { - if (val) { - return add_to(spv_ty); - } - return add_to(spv_ty); - }; auto const add_constant = [this, &spv_ty](auto val) -> spv::spv_inst * { return add_to(spv_ty, val); }; @@ -405,10 +520,9 @@ void spirv_converter::operator()(constant_inst const &in) { spv_ty, std::vector{c_re, c_im}); }; const auto visitor = overloaded{ + [&](bool) -> spv::spv_inst * { return nullptr; }, [&](std::int64_t i) -> spv::spv_inst * { switch (sty) { - case scalar_type::i1: - return add_constant_bool(i != 0); case scalar_type::i8: return add_constant(static_cast(i)); case scalar_type::i16: @@ -459,7 +573,18 @@ void spirv_converter::operator()(constant_inst const &in) { auto spv_ty = visit(*this, *in.result(0).ty()); - if (auto st = dyn_cast(in.result(0).ty()); st) { + if (isa(*in.result(0).ty())) { + if (!std::holds_alternative(in.value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + auto cst = [&](bool b) -> spv::spv_inst * { + if (b) { + return add_to(spv_ty); + } + return add_to(spv_ty); + }(std::get(in.value())); + declare(in.result(0), cst); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { declare(in.result(0), make(st->ty(), spv_ty, in.value())); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { auto const length = ct->length(core_cfg_.subgroup_size); diff --git a/test/spv/cast.ir b/test/spv/cast.ir new file mode 100644 index 00000000..a336d243 --- /dev/null +++ b/test/spv/cast.ir @@ -0,0 +1,50 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpCapability Int8 +; CHECK: OpCapability Int64 +; CHECK: %[[#I64:]] = OpTypeInt 64 1 +; CHECK: %[[#I8:]] = OpTypeInt 8 1 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F64:]] = OpTypeFloat 64 +; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 + +func @tint(%a: i64) { + %0 = cast %a : i64 -> i64 + %1 = cast %a : i64 -> i8 + %2 = cast %a : i64 -> f32 + %3 = cast %a : i64 -> c64 +; CHECK: %[[#]] = OpSConvert %[[#I64]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSConvert %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpConvertSToF %[[#F32]] %[[#]] +; CHECK-NEXT: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] +; CHECK-NEXT: %[[#I64_TO_F64:]] = OpConvertSToF %[[#F64]] %[[#]] +; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#I64_TO_F64]] %[[#C64_NULL]] 0 +} + +func @tfloat(%a: f32) { + %1 = cast %a : f32 -> i8 + %2 = cast %a : f32 -> f64 + %3 = cast %a : f32 -> c64 +; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFConvert %[[#F64]] %[[#]] +; CHECK-NEXT: %[[#C64_NULL_2:]] = OpConstantNull %[[#C64]] +; CHECK-NEXT: %[[#F32_TO_F64:]] = OpFConvert %[[#F64]] %[[#]] +; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F32_TO_F64]] %[[#C64_NULL_2]] 0 +} + +func @tcomplex(%a: c32) { + %1 = cast %a : c32 -> c64 +; CHECK: %[[#]] = OpFConvert %[[#C64]] %[[#]] +} + +func @tfloatcoopmatrix() subgroup_size(16) { + %0 = constant 1.0 -> coopmatrix + %2 = cast %0 : coopmatrix -> coopmatrix +; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +} From 386e33b28d54fa84df15a54981691fb3089b20be Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 6 Nov 2024 14:56:25 +0100 Subject: [PATCH 088/297] Add spirv-val tests Signed-off-by: Carsten Uphoff --- cmake/FindSPIRVTools.cmake | 47 +++++++++++++++++++++++++++++++++++ src/pass/convert_to_spirv.cpp | 18 ++++++++------ test/CMakeLists.txt | 20 ++++++++++++++- test/spv/arith.ir | 2 +- test/spv/arith_unary.ir | 2 +- test/spv/cast.ir | 10 ++++---- 6 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 cmake/FindSPIRVTools.cmake diff --git a/cmake/FindSPIRVTools.cmake b/cmake/FindSPIRVTools.cmake new file mode 100644 index 00000000..b1f4de66 --- /dev/null +++ b/cmake/FindSPIRVTools.cmake @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +# Try to find SPIR-V Tools + +# The following definitions are added on success +# +# SPIRVTools_FOUND - SPIR-V Tools was found +# SPIRVTools_SPIRV_VAL - spirv-val executable +# SPIRVTools_VERSION - SPIR-V Tools version +# +# The followings hints may be passed in the environment: +# +# RE2C_ROOT +# + +if(SPIRVTools_SPIRV_VAL) + set(SPIRVTools_FOUND TRUE) +else() + find_program(SPIRVTools_SPIRV_VAL NAMES spirv-val + HINTS + ENV SPIRVTools_ROOT + ENV PATH + ) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(SPIRVTools DEFAULT_MSG SPIRVTools_SPIRV_VAL) + + execute_process(COMMAND ${SPIRVTools_SPIRV_VAL} --version + OUTPUT_VARIABLE SPIRVTools_version_output + ERROR_VARIABLE SPIRVTools_version_error + RESULT_VARIABLE SPIRVTools_version_result + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT ${SPIRVTools_version_result} EQUAL 0) + set(SPIRVTools_message "Command \"{SPIRVTools_SPIRV_VAL} --version\" failed:\n${SPIRVTools_version_output}\n${SPIRVTools_version_error}") + if(SPIRVTools_FIND_REQUIRED) + message(SEND_ERROR ${SPIRVTools_message}) + else() + message(${SPIRVTools_message}) + endif() + else() + string(REGEX REPLACE "SPIRV-Tools ([v0-9\.]+) .*" "\\1" SPIRVTools_VERSION "${SPIRVTools_version_output}") + endif() + + mark_as_advanced(SPIRVTools_SPIRV_VAL) +endif() diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 231a443a..0234217c 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -134,24 +134,25 @@ auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { switch (ty.ty()) { case scalar_type::i8: capabilities_.insert(spv::Capability::Int8); - return add_to(8, 1); + return add_to(8, 0); case scalar_type::i16: capabilities_.insert(spv::Capability::Int16); - return add_to(16, 1); + return add_to(16, 0); case scalar_type::i32: - return add_to(32, 1); + return add_to(32, 0); case scalar_type::i64: capabilities_.insert(spv::Capability::Int64); - return add_to(64, 1); + return add_to(64, 0); case scalar_type::index: { const auto sz = size(ty.ty()); if (sz == 8) { capabilities_.insert(spv::Capability::Int64); } - return add_to(sz * 8, 1); + return add_to(sz * 8, 0); } case scalar_type::f32: case scalar_type::f64: + capabilities_.insert(spv::Capability::Float64); return add_to( size(ty.ty()) * 8, std::nullopt); case scalar_type::c32: { @@ -422,7 +423,7 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::c32: case scalar_type::c64: { auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto c0 = add(spv_to_ty); + auto c0 = add_to(spv_to_ty); auto re = add(spv_float_ty, a); return add(spv_to_ty, re, c0, std::vector{0}); @@ -445,7 +446,7 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::c32: case scalar_type::c64: { auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto c0 = add(spv_to_ty); + auto c0 = add_to(spv_to_ty); auto re = add(spv_float_ty, a); return add(spv_to_ty, re, c0, std::vector{0}); @@ -465,6 +466,9 @@ void spirv_converter::operator()(cast_inst const &in) { }; auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv::spv_inst *spv_to_ty, spv::spv_inst *a) -> spv::spv_inst * { + if (a_ty == to_ty) { + return add(spv_to_ty, a); + } switch (a_ty) { case scalar_type::i8: case scalar_type::i16: diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 344dd152..9a1e9068 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,6 +5,8 @@ include(CTest) include(CommonOptions) include(${PROJECT_SOURCE_DIR}/external/doctest/cmake/doctest.cmake) +find_package(SPIRVTools) + add_library(test-lib STATIC main.cpp) target_include_directories(test-lib PUBLIC ${PROJECT_SOURCE_DIR}/external) target_compile_features(test-lib PUBLIC cxx_std_20) @@ -42,11 +44,27 @@ file(GENERATE ) set(LIT_COMMAND lit "${CMAKE_CURRENT_BINARY_DIR}" -v) -add_test(lit-check ${LIT_COMMAND}) +add_test(NAME lit-check COMMAND ${LIT_COMMAND}) set_tests_properties(lit-check PROPERTIES LABELS "lit") add_custom_target(lit-check COMMAND ${LIT_COMMAND}) add_dependencies(lit-check tinytc-oc tinytc-opt) +if(SPIRVTools_FOUND) + set(SPIRV_VAL_SOURCES + spv/arith.ir + spv/arith_unary.ir + spv/cast.ir + spv/unique_function_type.ir + ) + foreach(SOURCE IN LISTS SPIRV_VAL_SOURCES) + get_filename_component(TEST_NAME ${SOURCE} NAME_WE) + set(CHECK_COMMAND $ -O0 -gspirv ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE} | spirv-as - -o - | ${SPIRVTools_SPIRV_VAL} -) + list(JOIN CHECK_COMMAND " " CHECK_COMMAND_STR) + add_test(NAME spirv-val-${TEST_NAME} COMMAND bash -c "${CHECK_COMMAND_STR}") + add_custom_target(spirv-val-${TEST_NAME} COMMAND ${CHECK_COMMAND}) + add_dependencies(spirv-val-${TEST_NAME} tinytc-oc) + endforeach() +endif() if(BUILD_OPENCL) add_subdirectory(cl) diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 0385843e..21b4f9cd 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -5,7 +5,7 @@ ; CHECK: OpCapability Int64 ; CHECK: %[[#BOOL:]] = OpTypeBool -; CHECK: %[[#I64:]] = OpTypeInt 64 1 +; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index cc84fe0a..30640db4 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -6,7 +6,7 @@ ; CHECK: OpCapability Int64 ; CHECK: %[[#EXT:]] = OpExtInstImport "OpenCL.std" ; CHECK: %[[#BOOL:]] = OpTypeBool -; CHECK: %[[#I64:]] = OpTypeInt 64 1 +; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 diff --git a/test/spv/cast.ir b/test/spv/cast.ir index a336d243..1da2dc3b 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -5,21 +5,22 @@ ; CHECK: OpCapability Int8 ; CHECK: OpCapability Int64 -; CHECK: %[[#I64:]] = OpTypeInt 64 1 -; CHECK: %[[#I8:]] = OpTypeInt 8 1 +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I8:]] = OpTypeInt 8 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F64:]] = OpTypeFloat 64 ; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 +; CHECK: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] +; CHECK: %[[#C64_NULL_2:]] = OpConstantNull %[[#C64]] func @tint(%a: i64) { %0 = cast %a : i64 -> i64 %1 = cast %a : i64 -> i8 %2 = cast %a : i64 -> f32 %3 = cast %a : i64 -> c64 -; CHECK: %[[#]] = OpSConvert %[[#I64]] %[[#]] +; CHECK: %[[#]] = OpCopyObject %[[#I64]] %[[#]] ; CHECK-NEXT: %[[#]] = OpSConvert %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpConvertSToF %[[#F32]] %[[#]] -; CHECK-NEXT: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] ; CHECK-NEXT: %[[#I64_TO_F64:]] = OpConvertSToF %[[#F64]] %[[#]] ; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#I64_TO_F64]] %[[#C64_NULL]] 0 } @@ -30,7 +31,6 @@ func @tfloat(%a: f32) { %3 = cast %a : f32 -> c64 ; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFConvert %[[#F64]] %[[#]] -; CHECK-NEXT: %[[#C64_NULL_2:]] = OpConstantNull %[[#C64]] ; CHECK-NEXT: %[[#F32_TO_F64:]] = OpFConvert %[[#F64]] %[[#]] ; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F32_TO_F64]] %[[#C64_NULL_2]] 0 } From 695817f62c4c02f2adbf2eb71ac9dd64ab4ea349 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 6 Nov 2024 16:24:54 +0100 Subject: [PATCH 089/297] SPIR-V: Compare Signed-off-by: Carsten Uphoff --- src/pass/convert_to_spirv.cpp | 83 +++++++++++++++++++++++++++++++++++ test/CMakeLists.txt | 1 + test/spv/compare.ir | 46 +++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100644 test/spv/compare.ir diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 0234217c..afa5a313 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -32,6 +32,7 @@ class spirv_converter { void operator()(arith_inst const &in); void operator()(arith_unary_inst const &in); void operator()(cast_inst const &in); + void operator()(compare_inst const &in); void operator()(constant_inst const &in); void run_on_program(program_node const &p); @@ -65,6 +66,7 @@ class spirv_converter { std::unordered_set capabilities_; std::unordered_multimap function_tys_; spv::spv_inst *opencl_ext_ = nullptr; + spv::spv_inst *bool2_ty_ = nullptr; core_config core_cfg_ = {}; }; @@ -510,6 +512,87 @@ void spirv_converter::operator()(cast_inst const &in) { } } +void spirv_converter::operator()(compare_inst const &in) { + auto const compare_int = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (cond) { + case cmp_condition::eq: + return add(spv_to_ty, a, b); + case cmp_condition::ne: + return add(spv_to_ty, a, b); + case cmp_condition::gt: + return add(spv_to_ty, a, b); + case cmp_condition::ge: + return add(spv_to_ty, a, b); + case cmp_condition::lt: + return add(spv_to_ty, a, b); + case cmp_condition::le: + return add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_float = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + switch (cond) { + case cmp_condition::eq: + return add(spv_to_ty, a, b); + case cmp_condition::ne: + return add(spv_to_ty, a, b); + case cmp_condition::gt: + return add(spv_to_ty, a, b); + case cmp_condition::ge: + return add(spv_to_ty, a, b); + case cmp_condition::lt: + return add(spv_to_ty, a, b); + case cmp_condition::le: + return add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_complex = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, + spv::spv_inst *b) -> spv::spv_inst * { + if (!bool2_ty_) { + bool2_ty_ = add_to(spv_to_ty, 2); + } + switch (cond) { + case cmp_condition::eq: { + auto components_equal = add(bool2_ty_, a, b); + return add(spv_to_ty, components_equal); + } + case cmp_condition::ne: { + auto components_not_equal = add(bool2_ty_, a, b); + return add(spv_to_ty, components_not_equal); + } + default: + throw compilation_error(in.loc(), status::ir_complex_unsupported); + } + }; + auto const make = [&](scalar_type a_ty, cmp_condition cond, spv::spv_inst *spv_to_ty, + spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return compare_int(cond, spv_to_ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + return compare_float(cond, spv_to_ty, a, b); + case scalar_type::c32: + case scalar_type::c64: + return compare_complex(cond, spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto spv_to_ty = visit(*this, *in.result(0).ty()); + auto av = val(in.a()); + auto bv = val(in.b()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make(a_ty, in.cond(), spv_to_ty, av, bv)); +} + void spirv_converter::operator()(constant_inst const &in) { auto const make = [&](scalar_type sty, spv::spv_inst *spv_ty, constant_inst::value_type const &val) -> spv::spv_inst * { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9a1e9068..f4141d92 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -54,6 +54,7 @@ if(SPIRVTools_FOUND) spv/arith.ir spv/arith_unary.ir spv/cast.ir + spv/compare.ir spv/unique_function_type.ir ) foreach(SOURCE IN LISTS SPIRV_VAL_SOURCES) diff --git a/test/spv/compare.ir b/test/spv/compare.ir new file mode 100644 index 00000000..ffafb14a --- /dev/null +++ b/test/spv/compare.ir @@ -0,0 +1,46 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#BOOL2:]] = OpTypeVector %[[#BOOL]] 2 + +func @tint(%a: i64, %b: i64) { + %0 = cmp.eq %a, %b : i64 + %1 = cmp.ne %a, %b : i64 + %2 = cmp.gt %a, %b : i64 + %3 = cmp.ge %a, %b : i64 + %4 = cmp.lt %a, %b : i64 + %5 = cmp.le %a, %b : i64 +; CHECK: %[[#]] = OpIEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpINotEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSGreaterThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSGreaterThanEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSLessThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpSLessThanEqual %[[#BOOL]] %[[#]] %[[#]] +} + +func @tfloat(%a: f32, %b: f32) { + %0 = cmp.eq %a, %b : f32 + %1 = cmp.ne %a, %b : f32 + %2 = cmp.gt %a, %b : f32 + %3 = cmp.ge %a, %b : f32 + %4 = cmp.lt %a, %b : f32 + %5 = cmp.le %a, %b : f32 +; CHECK: %[[#]] = OpFOrdEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFUnordNotEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdGreaterThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdGreaterThanEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdLessThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdLessThanEqual %[[#BOOL]] %[[#]] %[[#]] +} + +func @tcomplex(%a: c32, %b: c32) { + %0 = cmp.eq %a, %b : c32 + %1 = cmp.ne %a, %b : c32 +; CHECK: %[[#COMPONENTS_EQUAL:]] = OpFOrdEqual %[[#BOOL2]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpAll %[[#BOOL]] %[[#COMPONENTS_EQUAL]] +; CHECK-NEXT: %[[#COMPONENTS_NOT_EQUAL:]] = OpFUnordNotEqual %[[#BOOL2]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpAll %[[#BOOL]] %[[#COMPONENTS_NOT_EQUAL]] +} From 4dafc7133cddd811d17361bce07a7103c1350c19 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 6 Nov 2024 18:54:13 +0100 Subject: [PATCH 090/297] SPIR-V: Add barrier Signed-off-by: Carsten Uphoff --- src/pass/convert_to_spirv.cpp | 98 ++++++++++++++++++++++++++++------- test/CMakeLists.txt | 1 + test/spv/barrier.ir | 23 ++++++++ test/spv/cast.ir | 3 +- 4 files changed, 105 insertions(+), 20 deletions(-) create mode 100644 test/spv/barrier.ir diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index afa5a313..46c5d518 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -23,7 +23,7 @@ class spirv_converter { public: inline spirv_converter(::tinytc_core_info const *info, spv::mod &mod, tinytc_compiler_context_t ctx) - : info_(info), mod_(&mod), ctx_(ctx) {} + : info_(info), mod_(&mod), ctx_(ctx), unique_(*this) {} auto operator()(data_type_node const &ty) -> spv::spv_inst *; @@ -31,6 +31,7 @@ class spirv_converter { void operator()(inst_node const &in); void operator()(arith_inst const &in); void operator()(arith_unary_inst const &in); + void operator()(barrier_inst const &in); void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); @@ -38,6 +39,63 @@ class spirv_converter { void run_on_program(program_node const &p); private: + class unique_insts { + public: + inline unique_insts(spirv_converter &conv) : conv_(conv) {} + inline auto bool2_ty() -> spv::spv_inst * { + if (!bool2_ty_) { + auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); + bool2_ty_ = + conv_.add_to(bool_ty, 2); + } + return bool2_ty_; + } + inline auto bool_constant(bool b) { + if (b) { + if (!bool_true_) { + auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); + bool_true_ = + conv_.add_to(bool_ty); + } + return bool_true_; + } + if (!bool_false_) { + auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); + bool_false_ = + conv_.add_to(bool_ty); + } + return bool_false_; + } + inline auto i32_constant(std::int32_t cst) -> spv::spv_inst * { + auto it = i32_cst_.find(cst); + if (it == i32_cst_.end()) { + auto i32_ty = visit(conv_, *scalar_data_type::get(conv_.ctx_, scalar_type::i32)); + auto cst_inst = conv_.add_to( + i32_ty, spv::LiteralContextDependentNumber{cst}); + i32_cst_[cst] = cst_inst; + return cst_inst; + } + return it->second; + } + + inline auto null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst * { + auto it = null_cst_.find(spv_ty); + if (it == null_cst_.end()) { + auto in = conv_.add_to(spv_ty); + null_cst_[spv_ty] = in; + return in; + } + return it->second; + } + + private: + spirv_converter &conv_; + spv::spv_inst *bool2_ty_ = nullptr; + spv::spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + std::unordered_map i32_cst_; + std::unordered_map null_cst_; + }; + auto get_scalar_type(value_node const &v) -> scalar_type; auto get_coopmatrix_type(value_node const &v) -> scalar_type; auto declare(value_node const &v, spv::spv_inst *in); @@ -60,13 +118,13 @@ class spirv_converter { ::tinytc_core_info const *info_; spv::mod *mod_; tinytc_compiler_context_t ctx_; + unique_insts unique_; std::unordered_map spv_tys_; std::unordered_map vals_; std::unordered_map> multi_vals_; std::unordered_set capabilities_; std::unordered_multimap function_tys_; spv::spv_inst *opencl_ext_ = nullptr; - spv::spv_inst *bool2_ty_ = nullptr; core_config core_cfg_ = {}; }; @@ -409,6 +467,21 @@ void spirv_converter::operator()(arith_unary_inst const &in) { } } +void spirv_converter::operator()(barrier_inst const &in) { + std::int32_t fence = 0; + if (in.has_fence(address_space::global)) { + fence = fence | static_cast(spv::MemorySemantics::CrossWorkgroupMemory) | + static_cast(spv::MemorySemantics::SequentiallyConsistent); + } + if (in.has_fence(address_space::local)) { + fence = fence | static_cast(spv::MemorySemantics::WorkgroupMemory) | + static_cast(spv::MemorySemantics::SequentiallyConsistent); + } + auto scope = unique_.i32_constant(static_cast(spv::Scope::Workgroup)); + auto memory_semantics = unique_.i32_constant(fence); + add(scope, scope, memory_semantics); +} + void spirv_converter::operator()(cast_inst const &in) { auto const cast_from_int = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, spv::spv_inst *a) -> spv::spv_inst * { @@ -425,9 +498,8 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::c32: case scalar_type::c64: { auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto c0 = add_to(spv_to_ty); auto re = add(spv_float_ty, a); - return add(spv_to_ty, re, c0, + return add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); } } @@ -448,9 +520,8 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::c32: case scalar_type::c64: { auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto c0 = add_to(spv_to_ty); auto re = add(spv_float_ty, a); - return add(spv_to_ty, re, c0, + return add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); } } @@ -551,16 +622,13 @@ void spirv_converter::operator()(compare_inst const &in) { }; auto const compare_complex = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { - if (!bool2_ty_) { - bool2_ty_ = add_to(spv_to_ty, 2); - } switch (cond) { case cmp_condition::eq: { - auto components_equal = add(bool2_ty_, a, b); + auto components_equal = add(unique_.bool2_ty(), a, b); return add(spv_to_ty, components_equal); } case cmp_condition::ne: { - auto components_not_equal = add(bool2_ty_, a, b); + auto components_not_equal = add(unique_.bool2_ty(), a, b); return add(spv_to_ty, components_not_equal); } default: @@ -664,13 +732,7 @@ void spirv_converter::operator()(constant_inst const &in) { if (!std::holds_alternative(in.value())) { throw compilation_error(in.loc(), status::internal_compiler_error); } - auto cst = [&](bool b) -> spv::spv_inst * { - if (b) { - return add_to(spv_ty); - } - return add_to(spv_ty); - }(std::get(in.value())); - declare(in.result(0), cst); + declare(in.result(0), unique_.bool_constant(std::get(in.value()))); } else if (auto st = dyn_cast(in.result(0).ty()); st) { declare(in.result(0), make(st->ty(), spv_ty, in.value())); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f4141d92..2364b7ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -53,6 +53,7 @@ if(SPIRVTools_FOUND) set(SPIRV_VAL_SOURCES spv/arith.ir spv/arith_unary.ir + spv/barrier.ir spv/cast.ir spv/compare.ir spv/unique_function_type.ir diff --git a/test/spv/barrier.ir b/test/spv/barrier.ir new file mode 100644 index 00000000..dbb6e6dd --- /dev/null +++ b/test/spv/barrier.ir @@ -0,0 +1,23 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#SCOPE:]] = OpConstant %[[#I32]] 2 +; CHECK: %[[#NO_SEMANTICS:]] = OpConstant %[[#I32]] 0 +; CHECK: %[[#GLOBAL_SEMANTICS:]] = OpConstant %[[#I32]] 528 +; CHECK: %[[#LOCAL_SEMANTICS:]] = OpConstant %[[#I32]] 272 +; CHECK: %[[#GLOBAL_LOCAL_SEMANTICS:]] = OpConstant %[[#I32]] 784 + +func @tbarrier() { + barrier + barrier.global + barrier.local + barrier.global.local +; CHECK: OpControlBarrier %[[#SCOPE]] %[[#SCOPE]] %[[#NO_SEMANTICS]] +; CHECK-NEXT: OpControlBarrier %[[#SCOPE]] %[[#SCOPE]] %[[#GLOBAL_SEMANTICS]] +; CHECK-NEXT: OpControlBarrier %[[#SCOPE]] %[[#SCOPE]] %[[#LOCAL_SEMANTICS]] +; CHECK-NEXT: OpControlBarrier %[[#SCOPE]] %[[#SCOPE]] %[[#GLOBAL_LOCAL_SEMANTICS]] +} + diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 1da2dc3b..f1dc66d9 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -11,7 +11,6 @@ ; CHECK: %[[#F64:]] = OpTypeFloat 64 ; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 ; CHECK: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] -; CHECK: %[[#C64_NULL_2:]] = OpConstantNull %[[#C64]] func @tint(%a: i64) { %0 = cast %a : i64 -> i64 @@ -32,7 +31,7 @@ func @tfloat(%a: f32) { ; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFConvert %[[#F64]] %[[#]] ; CHECK-NEXT: %[[#F32_TO_F64:]] = OpFConvert %[[#F64]] %[[#]] -; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F32_TO_F64]] %[[#C64_NULL_2]] 0 +; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F32_TO_F64]] %[[#C64_NULL]] 0 } func @tcomplex(%a: c32) { From fc52cfa3724bfee26fa3fedec8dc5a5f76b899ad Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 7 Nov 2024 14:45:18 +0100 Subject: [PATCH 091/297] Create SPIR-V uniquifier class; fix includes Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 1 + src/codegen_tools.cpp | 2 + src/codegen_tools.hpp | 3 + src/node/region_node.cpp | 2 + src/pass/convert_to_spirv.cpp | 405 ++++++++++----------------- src/pass/convert_to_spirv.hpp | 3 - src/pass/lower_linalg.cpp | 1 + src/spv/module.cpp | 1 + src/spv/module.hpp | 14 +- src/spv/pass/dump_asm.cpp | 7 + src/spv/pass/dump_asm.hpp | 2 +- src/spv/uniquifier.cpp | 155 ++++++++++ src/spv/uniquifier.hpp | 46 +++ test/cl/test_runtime.cpp | 3 + test/cl/test_runtime.hpp | 2 + test/linalg_blas_a2.cpp | 1 - test/linalg_blas_a2.hpp | 5 +- test/linalg_blas_a3.cpp | 2 +- test/linalg_blas_a3.hpp | 3 +- test/linalg_types.cpp | 2 + test/linalg_types.hpp | 5 +- test/spv/cast.ir | 2 +- test/ze/test_runtime.cpp | 1 + test/ze/test_runtime.hpp | 2 + tools/argparser/argparser_common.hpp | 1 + tools/offline_compiler/main.cpp | 1 + tools/opt/main.cpp | 1 + 27 files changed, 404 insertions(+), 269 deletions(-) create mode 100644 src/spv/uniquifier.cpp create mode 100644 src/spv/uniquifier.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 70a480e5..18dd85f9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,6 +67,7 @@ set(SOURCES spv/names.cpp spv/opencl.std.cpp spv/pass/dump_asm.cpp + spv/uniquifier.cpp source.cpp tiling.cpp value.cpp diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 5a3b626f..fe21f4bc 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "codegen_tools.hpp" +#include "compiler_context.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" @@ -22,6 +23,7 @@ #include #include #include +#include using namespace clir; diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 26255e27..6807c0d8 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -7,11 +7,13 @@ #include "device_info.hpp" #include "node/data_type_node.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include #include +#include #include #include @@ -20,6 +22,7 @@ #include #include #include +#include namespace tinytc { diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp index a5b42aad..2ef10d3d 100644 --- a/src/node/region_node.cpp +++ b/src/node/region_node.cpp @@ -2,7 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause #include "node/region_node.hpp" +#include "node/data_type_node.hpp" #include "node/inst_node.hpp" +#include "support/ilist_base.hpp" #include "tinytc/tinytc.h" #include diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 46c5d518..e8c1e3ae 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -2,20 +2,36 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/convert_to_spirv.hpp" +#include "compiler_context.hpp" +#include "error.hpp" #include "node/data_type_node.hpp" #include "node/function_node.hpp" #include "node/inst_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" +#include "spv/enums.hpp" #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" +#include "spv/uniquifier.hpp" +#include "support/casting.hpp" #include "support/fnv1a.hpp" +#include "support/ilist_base.hpp" +#include "support/util.hpp" #include "support/visit.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" #include +#include +#include +#include +#include +#include #include -#include +#include +#include +#include namespace tinytc { @@ -23,9 +39,7 @@ class spirv_converter { public: inline spirv_converter(::tinytc_core_info const *info, spv::mod &mod, tinytc_compiler_context_t ctx) - : info_(info), mod_(&mod), ctx_(ctx), unique_(*this) {} - - auto operator()(data_type_node const &ty) -> spv::spv_inst *; + : info_(info), mod_(&mod), ctx_(ctx), unique_(ctx, mod) {} // Instruction nodes void operator()(inst_node const &in); @@ -35,67 +49,16 @@ class spirv_converter { void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); + void operator()(group_id_inst const &in); + void operator()(group_size_inst const &in); + void operator()(num_subgroups_inst const &in); + void operator()(subgroup_id_inst const &in); + void operator()(subgroup_local_id_inst const &in); + void operator()(subgroup_size_inst const &in); void run_on_program(program_node const &p); private: - class unique_insts { - public: - inline unique_insts(spirv_converter &conv) : conv_(conv) {} - inline auto bool2_ty() -> spv::spv_inst * { - if (!bool2_ty_) { - auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); - bool2_ty_ = - conv_.add_to(bool_ty, 2); - } - return bool2_ty_; - } - inline auto bool_constant(bool b) { - if (b) { - if (!bool_true_) { - auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); - bool_true_ = - conv_.add_to(bool_ty); - } - return bool_true_; - } - if (!bool_false_) { - auto bool_ty = visit(conv_, *boolean_data_type::get(conv_.ctx_)); - bool_false_ = - conv_.add_to(bool_ty); - } - return bool_false_; - } - inline auto i32_constant(std::int32_t cst) -> spv::spv_inst * { - auto it = i32_cst_.find(cst); - if (it == i32_cst_.end()) { - auto i32_ty = visit(conv_, *scalar_data_type::get(conv_.ctx_, scalar_type::i32)); - auto cst_inst = conv_.add_to( - i32_ty, spv::LiteralContextDependentNumber{cst}); - i32_cst_[cst] = cst_inst; - return cst_inst; - } - return it->second; - } - - inline auto null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst * { - auto it = null_cst_.find(spv_ty); - if (it == null_cst_.end()) { - auto in = conv_.add_to(spv_ty); - null_cst_[spv_ty] = in; - return in; - } - return it->second; - } - - private: - spirv_converter &conv_; - spv::spv_inst *bool2_ty_ = nullptr; - spv::spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; - std::unordered_map i32_cst_; - std::unordered_map null_cst_; - }; - auto get_scalar_type(value_node const &v) -> scalar_type; auto get_coopmatrix_type(value_node const &v) -> scalar_type; auto declare(value_node const &v, spv::spv_inst *in); @@ -105,26 +68,14 @@ class spirv_converter { auto declare_function_type(std::vector params) -> spv::spv_inst *; void run_on_region(region_node const &fn); void run_on_function(function_node const &fn); - template auto add_to(Args &&...args) -> T * { - auto ptr = std::make_unique(std::forward(args)...).release(); - mod_->insts(S).push_back(ptr); - return ptr; - } - - template auto add(Args &&...args) -> T * { - return add_to(std::forward(args)...); - } ::tinytc_core_info const *info_; spv::mod *mod_; tinytc_compiler_context_t ctx_; - unique_insts unique_; - std::unordered_map spv_tys_; + spv::uniquifier unique_; std::unordered_map vals_; std::unordered_map> multi_vals_; - std::unordered_set capabilities_; std::unordered_multimap function_tys_; - spv::spv_inst *opencl_ext_ = nullptr; core_config core_cfg_ = {}; }; @@ -172,76 +123,13 @@ auto spirv_converter::declare_function_type(std::vector params) return it->second; } } - auto void_ty = visit(*this, *void_data_type::get(ctx_)); + auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); return function_tys_ - .emplace(map_key, add_to( - void_ty, std::move(params))) + .emplace(map_key, mod_->add_to(spv::section::type_const_var, void_ty, + std::move(params))) ->second; } -auto spirv_converter::operator()(data_type_node const &ty) -> spv::spv_inst * { - auto it = spv_tys_.find(&ty); - if (it == spv_tys_.end()) { - auto spv_ty = visit( - overloaded{ - [&](void_data_type const &) -> spv::spv_inst * { - return add_to(); - }, - [&](boolean_data_type const &) -> spv::spv_inst * { - return add_to(); - }, - [&](scalar_data_type const &ty) -> spv::spv_inst * { - switch (ty.ty()) { - case scalar_type::i8: - capabilities_.insert(spv::Capability::Int8); - return add_to(8, 0); - case scalar_type::i16: - capabilities_.insert(spv::Capability::Int16); - return add_to(16, 0); - case scalar_type::i32: - return add_to(32, 0); - case scalar_type::i64: - capabilities_.insert(spv::Capability::Int64); - return add_to(64, 0); - case scalar_type::index: { - const auto sz = size(ty.ty()); - if (sz == 8) { - capabilities_.insert(spv::Capability::Int64); - } - return add_to(sz * 8, 0); - } - case scalar_type::f32: - case scalar_type::f64: - capabilities_.insert(spv::Capability::Float64); - return add_to( - size(ty.ty()) * 8, std::nullopt); - case scalar_type::c32: { - auto float_ty = - visit(*this, *scalar_data_type::get(ctx_, scalar_type::f32)); - return add_to(float_ty, 2); - } - case scalar_type::c64: { - auto float_ty = - visit(*this, *scalar_data_type::get(ctx_, scalar_type::f64)); - return add_to(float_ty, 2); - } - } - throw status::internal_compiler_error; - }, - [&](coopmatrix_data_type const &ty) -> spv::spv_inst * { - return visit(*this, *ty.ty()); - }, - [](auto const &) -> spv::spv_inst * { - // @todo - throw status::not_implemented; - }}, - ty); - spv_tys_[&ty] = spv_ty; - return spv_ty; - } - return it->second; -} - void spirv_converter::operator()(inst_node const &in) { // @todo throw compilation_error(in.loc(), status::not_implemented); @@ -252,11 +140,11 @@ void spirv_converter::operator()(arith_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (op) { case arithmetic::and_: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::or_: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::xor_: - return add(ty, a, b); + return mod_->add(ty, a, b); default: break; } @@ -266,25 +154,25 @@ void spirv_converter::operator()(arith_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (op) { case arithmetic::add: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::sub: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::mul: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::div: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::rem: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::shl: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::shr: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::and_: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::or_: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::xor_: - return add(ty, a, b); + return mod_->add(ty, a, b); } throw compilation_error(in.loc(), status::internal_compiler_error); }; @@ -292,15 +180,15 @@ void spirv_converter::operator()(arith_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (op) { case arithmetic::add: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::sub: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::mul: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::div: - return add(ty, a, b); + return mod_->add(ty, a, b); case arithmetic::rem: - return add(ty, a, b); + return mod_->add(ty, a, b); default: break; } @@ -324,7 +212,7 @@ void spirv_converter::operator()(arith_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto ty = visit(*this, *in.result(0).ty()); + auto ty = unique_.spv_ty(*in.result(0).ty()); if (isa(*in.result(0).ty())) { auto av = val(in.a()); @@ -356,7 +244,7 @@ void spirv_converter::operator()(arith_unary_inst const &in) { spv::spv_inst *a) -> spv::spv_inst * { switch (op) { case arithmetic_unary::not_: - return add(ty, a); + return mod_->add(ty, a); default: break; } @@ -366,13 +254,13 @@ void spirv_converter::operator()(arith_unary_inst const &in) { spv::spv_inst *a) -> spv::spv_inst * { switch (op) { case arithmetic_unary::abs: - return add(ty, opencl_ext_, - static_cast(spv::OpenCLEntrypoint::s_abs), - std::vector{a}); + return mod_->add( + ty, unique_.opencl_ext(), static_cast(spv::OpenCLEntrypoint::s_abs), + std::vector{a}); case arithmetic_unary::neg: - return add(ty, a); + return mod_->add(ty, a); case arithmetic_unary::not_: - return add(ty, a); + return mod_->add(ty, a); default: break; } @@ -382,11 +270,11 @@ void spirv_converter::operator()(arith_unary_inst const &in) { spv::spv_inst *a) -> spv::spv_inst * { switch (op) { case arithmetic_unary::abs: - return add(ty, opencl_ext_, - static_cast(spv::OpenCLEntrypoint::fabs), - std::vector{a}); + return mod_->add(ty, unique_.opencl_ext(), + static_cast(spv::OpenCLEntrypoint::fabs), + std::vector{a}); case arithmetic_unary::neg: - return add(ty, a); + return mod_->add(ty, a); default: break; } @@ -396,29 +284,31 @@ void spirv_converter::operator()(arith_unary_inst const &in) { spv::spv_inst *a) -> spv::spv_inst * { switch (op) { case arithmetic_unary::abs: { - auto spv_a_ty = visit(*this, *scalar_data_type::get(ctx_, sty)); - auto a2 = add(spv_a_ty, a, a); - auto a2_0 = add(ty, a2, std::vector{0}); - auto a2_1 = add(ty, a2, std::vector{1}); - auto a2_0p1 = add(ty, a2_0, a2_1); - return add(ty, opencl_ext_, - static_cast(spv::OpenCLEntrypoint::sqrt), - std::vector{a2_0p1}); + auto spv_a_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, sty)); + auto a2 = mod_->add(spv_a_ty, a, a); + auto a2_0 = + mod_->add(ty, a2, std::vector{0}); + auto a2_1 = + mod_->add(ty, a2, std::vector{1}); + auto a2_0p1 = mod_->add(ty, a2_0, a2_1); + return mod_->add(ty, unique_.opencl_ext(), + static_cast(spv::OpenCLEntrypoint::sqrt), + std::vector{a2_0p1}); } case arithmetic_unary::neg: - return add(ty, a); + return mod_->add(ty, a); case arithmetic_unary::conj: { - auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(sty))); - auto a_im = - add(spv_float_ty, a, std::vector{1}); - auto neg_a_im = add(spv_float_ty, a_im); - return add(ty, neg_a_im, a, - std::vector{1}); + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(sty))); + auto a_im = mod_->add(spv_float_ty, a, + std::vector{1}); + auto neg_a_im = mod_->add(spv_float_ty, a_im); + return mod_->add(ty, neg_a_im, a, + std::vector{1}); } case arithmetic_unary::im: - return add(ty, a, std::vector{1}); + return mod_->add(ty, a, std::vector{1}); case arithmetic_unary::re: - return add(ty, a, std::vector{0}); + return mod_->add(ty, a, std::vector{0}); default: break; } @@ -444,7 +334,7 @@ void spirv_converter::operator()(arith_unary_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto ty = visit(*this, *in.result(0).ty()); + auto ty = unique_.spv_ty(*in.result(0).ty()); if (isa(*in.a().ty())) { auto av = val(in.a()); declare(in.result(0), make_boolean(in.operation(), ty, av)); @@ -479,7 +369,7 @@ void spirv_converter::operator()(barrier_inst const &in) { } auto scope = unique_.i32_constant(static_cast(spv::Scope::Workgroup)); auto memory_semantics = unique_.i32_constant(fence); - add(scope, scope, memory_semantics); + mod_->add(scope, scope, memory_semantics); } void spirv_converter::operator()(cast_inst const &in) { @@ -491,16 +381,17 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::i32: case scalar_type::i64: case scalar_type::index: - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); case scalar_type::f32: case scalar_type::f64: - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto re = add(spv_float_ty, a); - return add(spv_to_ty, re, unique_.null_constant(spv_to_ty), - std::vector{0}); + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, + unique_.null_constant(spv_to_ty), + std::vector{0}); } } throw compilation_error(in.loc(), status::ir_forbidden_cast); @@ -513,16 +404,17 @@ void spirv_converter::operator()(cast_inst const &in) { case scalar_type::i32: case scalar_type::i64: case scalar_type::index: - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); case scalar_type::f32: case scalar_type::f64: - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = visit(*this, *scalar_data_type::get(ctx_, element_type(to_ty))); - auto re = add(spv_float_ty, a); - return add(spv_to_ty, re, unique_.null_constant(spv_to_ty), - std::vector{0}); + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, + unique_.null_constant(spv_to_ty), + std::vector{0}); } } throw compilation_error(in.loc(), status::ir_forbidden_cast); @@ -532,7 +424,7 @@ void spirv_converter::operator()(cast_inst const &in) { switch (to_ty) { case scalar_type::c32: case scalar_type::c64: - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); default: throw compilation_error(in.loc(), status::ir_forbidden_cast); } @@ -540,7 +432,7 @@ void spirv_converter::operator()(cast_inst const &in) { auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv::spv_inst *spv_to_ty, spv::spv_inst *a) -> spv::spv_inst * { if (a_ty == to_ty) { - return add(spv_to_ty, a); + return mod_->add(spv_to_ty, a); } switch (a_ty) { case scalar_type::i8: @@ -560,7 +452,7 @@ void spirv_converter::operator()(cast_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto spv_to_ty = visit(*this, *in.result(0).ty()); + auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); if (auto st = dyn_cast(in.result(0).ty()); st) { auto av = val(in.a()); @@ -588,17 +480,17 @@ void spirv_converter::operator()(compare_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (cond) { case cmp_condition::eq: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::ne: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::gt: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::ge: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::lt: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::le: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); } throw compilation_error(in.loc(), status::internal_compiler_error); }; @@ -606,17 +498,17 @@ void spirv_converter::operator()(compare_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (cond) { case cmp_condition::eq: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::ne: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::gt: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::ge: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::lt: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); case cmp_condition::le: - return add(spv_to_ty, a, b); + return mod_->add(spv_to_ty, a, b); } throw compilation_error(in.loc(), status::internal_compiler_error); }; @@ -624,12 +516,12 @@ void spirv_converter::operator()(compare_inst const &in) { spv::spv_inst *b) -> spv::spv_inst * { switch (cond) { case cmp_condition::eq: { - auto components_equal = add(unique_.bool2_ty(), a, b); - return add(spv_to_ty, components_equal); + auto components_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_equal); } case cmp_condition::ne: { - auto components_not_equal = add(unique_.bool2_ty(), a, b); - return add(spv_to_ty, components_not_equal); + auto components_not_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_not_equal); } default: throw compilation_error(in.loc(), status::ir_complex_unsupported); @@ -654,7 +546,7 @@ void spirv_converter::operator()(compare_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto spv_to_ty = visit(*this, *in.result(0).ty()); + auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); auto av = val(in.a()); auto bv = val(in.b()); auto a_ty = get_scalar_type(in.a()); @@ -665,14 +557,16 @@ void spirv_converter::operator()(constant_inst const &in) { auto const make = [&](scalar_type sty, spv::spv_inst *spv_ty, constant_inst::value_type const &val) -> spv::spv_inst * { auto const add_constant = [this, &spv_ty](auto val) -> spv::spv_inst * { - return add_to(spv_ty, val); + return mod_->add_to(spv::section::type_const_var, spv_ty, val); }; auto const add_constant_complex = [this, &spv_ty](spv::spv_inst *spv_float_ty, auto re, auto im) -> spv::spv_inst * { - auto c_re = add_to(spv_float_ty, re); - auto c_im = add_to(spv_float_ty, im); - return add_to( - spv_ty, std::vector{c_re, c_im}); + auto c_re = + mod_->add_to(spv::section::type_const_var, spv_float_ty, re); + auto c_im = + mod_->add_to(spv::section::type_const_var, spv_float_ty, im); + return mod_->add_to(spv::section::type_const_var, spv_ty, + std::vector{c_re, c_im}); }; const auto visitor = overloaded{ [&](bool) -> spv::spv_inst * { return nullptr; }, @@ -705,13 +599,13 @@ void spirv_converter::operator()(constant_inst const &in) { switch (sty) { case scalar_type::c32: { auto spv_float_ty = - visit(*this, *scalar_data_type::get(ctx_, scalar_type::f32)); + unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); return add_constant_complex(spv_float_ty, static_cast(d.real()), static_cast(d.imag())); } case scalar_type::c64: { auto spv_float_ty = - visit(*this, *scalar_data_type::get(ctx_, scalar_type::f64)); + unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); return add_constant_complex(spv_float_ty, d.real(), d.imag()); } default: @@ -726,7 +620,7 @@ void spirv_converter::operator()(constant_inst const &in) { return cst; }; - auto spv_ty = visit(*this, *in.result(0).ty()); + auto spv_ty = unique_.spv_ty(*in.result(0).ty()); if (isa(*in.result(0).ty())) { if (!std::holds_alternative(in.value())) { @@ -745,12 +639,19 @@ void spirv_converter::operator()(constant_inst const &in) { } } +void spirv_converter::operator()(group_id_inst const &in) {} +void spirv_converter::operator()(group_size_inst const &in) {} +void spirv_converter::operator()(num_subgroups_inst const &in) {} +void spirv_converter::operator()(subgroup_id_inst const &in) {} +void spirv_converter::operator()(subgroup_local_id_inst const &in) {} +void spirv_converter::operator()(subgroup_size_inst const &in) {} + void spirv_converter::run_on_region(region_node const ®) { - add(); + mod_->add(); for (auto const &i : reg) { visit(*this, i); } - add(); + mod_->add(); } void spirv_converter::run_on_function(function_node const &fn) { @@ -766,50 +667,46 @@ void spirv_converter::run_on_function(function_node const &fn) { auto params = std::vector{}; params.reserve(fn.num_params()); for (auto const &p : fn.params()) { - params.push_back(visit(*this, *p.ty())); + params.push_back(unique_.spv_ty(*p.ty())); } return params; }()); // Function - auto void_ty = visit(*this, *void_data_type::get(ctx_)); - auto fun = add(void_ty, spv::FunctionControl::None, fun_ty); + auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); + auto fun = mod_->add(void_ty, spv::FunctionControl::None, fun_ty); for (auto const &p : fn.params()) { - declare(p, add(visit(*this, *p.ty()))); + declare(p, mod_->add(unique_.spv_ty(*p.ty()))); } run_on_region(fn.body()); - add(); + mod_->add(); // Entry point - add_to( - spv::ExecutionModel::Kernel, fun, std::string{fn.name()}, std::vector{}); + mod_->add_to(spv::section::entry_point, spv::ExecutionModel::Kernel, fun, + std::string{fn.name()}, std::vector{}); // Execution mode auto const work_group_size = fn.work_group_size(); - add_to( - fun, spv::ExecutionMode::LocalSize, - spv::ExecutionModeAttr{ - std::array{work_group_size[0], work_group_size[1], 1}}); - add_to( - fun, spv::ExecutionMode::SubgroupSize, spv::ExecutionModeAttr{fn.subgroup_size()}); + mod_->add_to(spv::section::execution_mode, fun, + spv::ExecutionMode::LocalSize, + spv::ExecutionModeAttr{std::array{ + work_group_size[0], work_group_size[1], 1}}); + mod_->add_to(spv::section::execution_mode, fun, + spv::ExecutionMode::SubgroupSize, + spv::ExecutionModeAttr{fn.subgroup_size()}); } void spirv_converter::run_on_program(program_node const &p) { - capabilities_.clear(); - capabilities_.insert(spv::Capability::Addresses); - capabilities_.insert(spv::Capability::Kernel); - capabilities_.insert(spv::Capability::SubgroupDispatch); - opencl_ext_ = add_to(spv::OpenCLExt); - add_to(spv::AddressingModel::Physical64, - spv::MemoryModel::OpenCL); + unique_.capability(spv::Capability::Addresses); + unique_.capability(spv::Capability::Kernel); + unique_.capability(spv::Capability::SubgroupDispatch); + + mod_->add_to(spv::section::memory_model, spv::AddressingModel::Physical64, + spv::MemoryModel::OpenCL); for (auto const &fn : p) { run_on_function(fn); } - - for (auto const &cap : capabilities_) { - add_to(cap); - } } convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) diff --git a/src/pass/convert_to_spirv.hpp b/src/pass/convert_to_spirv.hpp index a520a9ca..7559511a 100644 --- a/src/pass/convert_to_spirv.hpp +++ b/src/pass/convert_to_spirv.hpp @@ -7,11 +7,8 @@ #include "device_info.hpp" #include "node/program_node.hpp" #include "spv/module.hpp" -#include "tinytc/types.h" -#include "tinytc/types.hpp" #include -#include namespace tinytc { diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index eef9e37b..4ed4eb5f 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -23,6 +23,7 @@ #include #include +#include #include #include #include diff --git a/src/spv/module.cpp b/src/spv/module.cpp index 953968bd..03bb9511 100644 --- a/src/spv/module.cpp +++ b/src/spv/module.cpp @@ -3,6 +3,7 @@ #include "spv/module.hpp" #include "spv/instructions.hpp" +#include "support/ilist_base.hpp" namespace tinytc { void ilist_callbacks::node_added(spv::spv_inst *) {} diff --git a/src/spv/module.hpp b/src/spv/module.hpp index ab0bb0ff..5e9dec61 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -4,12 +4,13 @@ #ifndef MODULE_20241029_HPP #define MODULE_20241029_HPP -#include "reference_counted.hpp" #include "support/ilist.hpp" -#include "support/ilist_base.hpp" #include #include +#include +#include +#include namespace tinytc { @@ -55,6 +56,15 @@ class mod final { inline auto major_version() const -> std::int32_t { return major_version_; } inline auto minor_version() const -> std::int32_t { return minor_version_; } + template auto add_to(section s, Args &&...args) -> T * { + auto ptr = std::make_unique(std::forward(args)...).release(); + insts(s).push_back(ptr); + return ptr; + } + template auto add(Args &&...args) -> T * { + return add_to(section::function, std::forward(args)...); + } + private: std::array, num_module_sections> insts_; std::int32_t major_version_, minor_version_; diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 5b23dfb5..0a318e1f 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -2,12 +2,19 @@ // SPDX-License-Identifier: BSD-3-Clause #include "spv/pass/dump_asm.hpp" +#include "spv/enums.hpp" #include "spv/module.hpp" #include "spv/opencl.std.hpp" #include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" +#include "tinytc/types.hpp" #include +#include +#include #include +#include namespace tinytc::spv { diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp index eeb9763e..dee7128f 100644 --- a/src/spv/pass/dump_asm.hpp +++ b/src/spv/pass/dump_asm.hpp @@ -8,9 +8,9 @@ #include "spv/names.hpp" #include "spv/visit.hpp" +#include #include #include -#include namespace tinytc::spv { diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp new file mode 100644 index 00000000..a3191c6f --- /dev/null +++ b/src/spv/uniquifier.cpp @@ -0,0 +1,155 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/uniquifier.hpp" +#include "compiler_context.hpp" +#include "node/data_type_node.hpp" +#include "spv/instructions.hpp" +#include "spv/opencl.std.hpp" +#include "support/visit.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc::spv { + +uniquifier::uniquifier(tinytc_compiler_context_t ctx, spv::mod &m) : ctx_(ctx), mod_(&m) {} + +auto uniquifier::bool2_ty() -> spv::spv_inst * { + if (!bool2_ty_) { + auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); + bool2_ty_ = mod_->add_to(spv::section::type_const_var, bool_ty, 2); + } + return bool2_ty_; +} + +auto uniquifier::bool_constant(bool b) -> spv::spv_inst * { + if (b) { + if (!bool_true_) { + auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); + bool_true_ = mod_->add_to(spv::section::type_const_var, bool_ty); + } + return bool_true_; + } + if (!bool_false_) { + auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); + bool_false_ = mod_->add_to(spv::section::type_const_var, bool_ty); + } + return bool_false_; +} +// inline auto builtin(spv::BuiltIn b) -> spv::spv_inst* { +// auto it = builtin_.find(b); +// if (it == builtin_.end()) { +// auto var = add_to(); +// add_to(spv::section::decoration,); +// auto i32_ty = visit( *scalar_data_type::get(ctx_, scalar_type::i32)); +// auto cst_inst = add_to(spv::section::type_const_var, +// i32_ty, spv::LiteralContextDependentNumber{cst}); +// i32_cst_[cst] = cst_inst; +// return cst_inst; +//} +// return it->second; +//} + +void uniquifier::capability(spv::Capability cap) { + if (!capabilities_.contains(cap)) { + mod_->add_to(spv::section::capability, cap); + capabilities_.insert(cap); + } +} + +auto uniquifier::i32_constant(std::int32_t cst) -> spv::spv_inst * { + auto it = i32_cst_.find(cst); + if (it == i32_cst_.end()) { + auto i32_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::i32)); + auto cst_inst = mod_->add_to(spv::section::type_const_var, i32_ty, + spv::LiteralContextDependentNumber{cst}); + i32_cst_[cst] = cst_inst; + return cst_inst; + } + return it->second; +} + +auto uniquifier::null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst * { + auto it = null_cst_.find(spv_ty); + if (it == null_cst_.end()) { + auto in = mod_->add_to(spv::section::type_const_var, spv_ty); + null_cst_[spv_ty] = in; + return in; + } + return it->second; +} + +auto uniquifier::opencl_ext() -> spv::spv_inst * { + if (opencl_ext_ == nullptr) { + opencl_ext_ = mod_->add_to(spv::section::ext_inst, spv::OpenCLExt); + } + return opencl_ext_; +} + +auto uniquifier::spv_ty(data_type_node const &ty) -> spv::spv_inst * { + auto it = spv_tys_.find(&ty); + if (it == spv_tys_.end()) { + auto spv_ty_inst = visit( + overloaded{ + [&](void_data_type const &) -> spv::spv_inst * { + return mod_->add_to(spv::section::type_const_var); + }, + [&](boolean_data_type const &) -> spv::spv_inst * { + return mod_->add_to(spv::section::type_const_var); + }, + [&](scalar_data_type const &ty) -> spv::spv_inst * { + switch (ty.ty()) { + case scalar_type::i8: + capability(spv::Capability::Int8); + return mod_->add_to(spv::section::type_const_var, 8, 0); + case scalar_type::i16: + capability(spv::Capability::Int16); + return mod_->add_to(spv::section::type_const_var, 16, 0); + case scalar_type::i32: + return mod_->add_to(spv::section::type_const_var, 32, 0); + case scalar_type::i64: + capability(spv::Capability::Int64); + return mod_->add_to(spv::section::type_const_var, 64, 0); + case scalar_type::index: { + const auto sz = size(ty.ty()); + if (sz == 8) { + capability(spv::Capability::Int64); + } + return mod_->add_to(spv::section::type_const_var, sz * 8, + 0); + } + case scalar_type::f32: + case scalar_type::f64: + capability(spv::Capability::Float64); + return mod_->add_to(spv::section::type_const_var, + size(ty.ty()) * 8, std::nullopt); + case scalar_type::c32: { + auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); + return mod_->add_to(spv::section::type_const_var, + float_ty, 2); + } + case scalar_type::c64: { + auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); + return mod_->add_to(spv::section::type_const_var, + float_ty, 2); + } + } + throw status::internal_compiler_error; + }, + [&](coopmatrix_data_type const &ty) -> spv::spv_inst * { return spv_ty(*ty.ty()); }, + [](auto const &) -> spv::spv_inst * { + // @todo + throw status::not_implemented; + }}, + ty); + spv_tys_[&ty] = spv_ty_inst; + return spv_ty_inst; + } + return it->second; +} + +} // namespace tinytc::spv + diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp new file mode 100644 index 00000000..b2381706 --- /dev/null +++ b/src/spv/uniquifier.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef UNIQUIFIER_20241107_HPP +#define UNIQUIFIER_20241107_HPP + +#include "spv/enums.hpp" +#include "spv/module.hpp" +#include "tinytc/types.h" + +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst; + +class uniquifier { + public: + uniquifier(tinytc_compiler_context_t ctx, spv::mod &m); + + auto bool2_ty() -> spv::spv_inst *; + auto bool_constant(bool b) -> spv::spv_inst *; + void capability(spv::Capability cap); + auto i32_constant(std::int32_t cst) -> spv::spv_inst *; + auto null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst *; + auto opencl_ext() -> spv::spv_inst *; + auto spv_ty(tinytc_data_type const &ty) -> spv::spv_inst *; + + private: + tinytc_compiler_context_t ctx_; + spv::mod *mod_; + spv::spv_inst *bool2_ty_ = nullptr; + spv::spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + spv::spv_inst *opencl_ext_ = nullptr; + std::unordered_map builtin_; + std::unordered_set capabilities_; + std::unordered_map i32_cst_; + std::unordered_map null_cst_; + std::unordered_map spv_tys_; +}; + +} // namespace tinytc::spv + +#endif // UNIQUIFIER_20241107_HPP diff --git a/test/cl/test_runtime.cpp b/test/cl/test_runtime.cpp index f51875ba..0ca2e664 100644 --- a/test/cl/test_runtime.cpp +++ b/test/cl/test_runtime.cpp @@ -4,9 +4,12 @@ #include "test_runtime.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.hpp" +#include "tinytc/types.h" #include +#include #include +#include #include using tinytc::CL_CHECK_STATUS; diff --git a/test/cl/test_runtime.hpp b/test/cl/test_runtime.hpp index d645356b..27af6e07 100644 --- a/test/cl/test_runtime.hpp +++ b/test/cl/test_runtime.hpp @@ -7,9 +7,11 @@ #include "cl/argument_handler.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.hpp" +#include "tinytc/types.hpp" #include #include +#include class opencl_test_runtime { public: diff --git a/test/linalg_blas_a2.cpp b/test/linalg_blas_a2.cpp index 13d727b5..79223b85 100644 --- a/test/linalg_blas_a2.cpp +++ b/test/linalg_blas_a2.cpp @@ -3,7 +3,6 @@ #include "linalg_blas_a2.hpp" -#include #include namespace tinytc::test { diff --git a/test/linalg_blas_a2.hpp b/test/linalg_blas_a2.hpp index 31e4ee03..5b2a9f3e 100644 --- a/test/linalg_blas_a2.hpp +++ b/test/linalg_blas_a2.hpp @@ -6,12 +6,11 @@ #include "linalg_types.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" -#include -#include -#include #include #include +#include #include namespace tinytc::test { diff --git a/test/linalg_blas_a3.cpp b/test/linalg_blas_a3.cpp index e639407c..6c75946b 100644 --- a/test/linalg_blas_a3.cpp +++ b/test/linalg_blas_a3.cpp @@ -3,8 +3,8 @@ #include "linalg_blas_a3.hpp" -#include #include +#include namespace tinytc::test { diff --git a/test/linalg_blas_a3.hpp b/test/linalg_blas_a3.hpp index 6f2e5feb..970b2e09 100644 --- a/test/linalg_blas_a3.hpp +++ b/test/linalg_blas_a3.hpp @@ -6,10 +6,9 @@ #include "linalg_types.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" -#include #include -#include #include #include #include diff --git a/test/linalg_types.cpp b/test/linalg_types.cpp index 36172f4c..0e8dff0c 100644 --- a/test/linalg_types.cpp +++ b/test/linalg_types.cpp @@ -3,6 +3,8 @@ #include "linalg_types.hpp" +#include + namespace tinytc::test { tensor_layout::tensor_layout(array_view shape, array_view stride, diff --git a/test/linalg_types.hpp b/test/linalg_types.hpp index af8dc803..c6d85bfe 100644 --- a/test/linalg_types.hpp +++ b/test/linalg_types.hpp @@ -5,10 +5,13 @@ #define LINALG_TYPES_20241023_HPP #include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" -#include #include +#include +#include #include +#include #include namespace tinytc::test { diff --git a/test/spv/cast.ir b/test/spv/cast.ir index f1dc66d9..264ccc36 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -3,8 +3,8 @@ ; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s -; CHECK: OpCapability Int8 ; CHECK: OpCapability Int64 +; CHECK: OpCapability Int8 ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#I8:]] = OpTypeInt 8 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 diff --git a/test/ze/test_runtime.cpp b/test/ze/test_runtime.cpp index 13a43e76..1d0002a7 100644 --- a/test/ze/test_runtime.cpp +++ b/test/ze/test_runtime.cpp @@ -4,6 +4,7 @@ #include "test_runtime.hpp" #include +#include using tinytc::ZE_CHECK_STATUS; diff --git a/test/ze/test_runtime.hpp b/test/ze/test_runtime.hpp index 23cd7ca5..837bdb93 100644 --- a/test/ze/test_runtime.hpp +++ b/test/ze/test_runtime.hpp @@ -6,8 +6,10 @@ #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_ze.hpp" +#include "tinytc/types.hpp" #include +#include #include class level_zero_test_runtime { diff --git a/tools/argparser/argparser_common.hpp b/tools/argparser/argparser_common.hpp index 7919bebe..a91c378f 100644 --- a/tools/argparser/argparser_common.hpp +++ b/tools/argparser/argparser_common.hpp @@ -4,6 +4,7 @@ #ifndef ARGPARSER_COMMON_20241010_HPP #define ARGPARSER_COMMON_20241010_HPP +#include "tinytc/types.h" #include "tinytc/types.hpp" #include diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 360dcc6f..0123e5eb 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -5,6 +5,7 @@ #include "argparser_common.hpp" #include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index e858f7ba..543dfcc1 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -4,6 +4,7 @@ #include "argparser.hpp" #include "argparser_common.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include From 3432b1f929b04b28a0202b302e74682097cdfd00 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 7 Nov 2024 15:32:35 +0100 Subject: [PATCH 092/297] Reorganize SPIR-V conversion code Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 1 + src/pass/convert_to_spirv.cpp | 708 +--------------------------------- src/spv/converter.cpp | 645 +++++++++++++++++++++++++++++++ src/spv/converter.hpp | 63 +++ src/spv/uniquifier.cpp | 108 +++--- src/spv/uniquifier.hpp | 38 +- 6 files changed, 794 insertions(+), 769 deletions(-) create mode 100644 src/spv/converter.cpp create mode 100644 src/spv/converter.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 18dd85f9..c7fe5e1c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -63,6 +63,7 @@ set(SOURCES region.cpp required_extensions.cpp scalar_type.cpp + spv/converter.cpp spv/module.cpp spv/names.cpp spv/opencl.std.cpp diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index e8c1e3ae..5b221b13 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -2,713 +2,13 @@ // SPDX-License-Identifier: BSD-3-Clause #include "pass/convert_to_spirv.hpp" -#include "compiler_context.hpp" -#include "error.hpp" -#include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/region_node.hpp" -#include "node/value_node.hpp" -#include "scalar_type.hpp" -#include "spv/enums.hpp" -#include "spv/instructions.hpp" -#include "spv/opencl.std.hpp" -#include "spv/uniquifier.hpp" -#include "support/casting.hpp" -#include "support/fnv1a.hpp" -#include "support/ilist_base.hpp" -#include "support/util.hpp" -#include "support/visit.hpp" -#include "tinytc/types.h" +#include "spv/converter.hpp" #include "tinytc/types.hpp" -#include -#include -#include -#include -#include -#include -#include #include -#include -#include namespace tinytc { -class spirv_converter { - public: - inline spirv_converter(::tinytc_core_info const *info, spv::mod &mod, - tinytc_compiler_context_t ctx) - : info_(info), mod_(&mod), ctx_(ctx), unique_(ctx, mod) {} - - // Instruction nodes - void operator()(inst_node const &in); - void operator()(arith_inst const &in); - void operator()(arith_unary_inst const &in); - void operator()(barrier_inst const &in); - void operator()(cast_inst const &in); - void operator()(compare_inst const &in); - void operator()(constant_inst const &in); - void operator()(group_id_inst const &in); - void operator()(group_size_inst const &in); - void operator()(num_subgroups_inst const &in); - void operator()(subgroup_id_inst const &in); - void operator()(subgroup_local_id_inst const &in); - void operator()(subgroup_size_inst const &in); - - void run_on_program(program_node const &p); - - private: - auto get_scalar_type(value_node const &v) -> scalar_type; - auto get_coopmatrix_type(value_node const &v) -> scalar_type; - auto declare(value_node const &v, spv::spv_inst *in); - auto val(value_node const &v) -> spv::spv_inst *; - auto multi_declare(value_node const &v, std::vector insts); - auto multi_val(value_node const &v) -> std::vector &; - auto declare_function_type(std::vector params) -> spv::spv_inst *; - void run_on_region(region_node const &fn); - void run_on_function(function_node const &fn); - - ::tinytc_core_info const *info_; - spv::mod *mod_; - tinytc_compiler_context_t ctx_; - spv::uniquifier unique_; - std::unordered_map vals_; - std::unordered_map> multi_vals_; - std::unordered_multimap function_tys_; - core_config core_cfg_ = {}; -}; - -auto spirv_converter::get_scalar_type(value_node const &v) -> scalar_type { - auto st = dyn_cast(v.ty()); - if (!st) { - throw compilation_error(v.loc(), status::ir_expected_scalar); - } - return st->ty(); -} -auto spirv_converter::get_coopmatrix_type(value_node const &v) -> scalar_type { - auto ct = dyn_cast(v.ty()); - if (!ct) { - throw compilation_error(v.loc(), status::ir_expected_coopmatrix); - } - return ct->component_ty(); -} - -auto spirv_converter::declare(value_node const &v, spv::spv_inst *in) { vals_[&v] = in; } -auto spirv_converter::val(value_node const &v) -> spv::spv_inst * { - if (auto it = vals_.find(&v); it != vals_.end()) { - return it->second; - } - throw compilation_error(v.loc(), status::spirv_undefined_value); -} -auto spirv_converter::multi_declare(value_node const &v, std::vector insts) { - multi_vals_[&v] = std::move(insts); -} -auto spirv_converter::multi_val(value_node const &v) -> std::vector & { - if (auto it = multi_vals_.find(&v); it != multi_vals_.end()) { - return it->second; - } - throw compilation_error(v.loc(), status::spirv_undefined_value); -} -auto spirv_converter::declare_function_type(std::vector params) - -> spv::spv_inst * { - auto map_key = fnv1a0(); - for (auto const &p : params) { - map_key = fnv1a_step(map_key, p); - } - auto range = function_tys_.equal_range(map_key); - for (auto it = range.first; it != range.second; ++it) { - if (std::equal(params.begin(), params.end(), it->second->op1().begin(), - it->second->op1().end())) { - return it->second; - } - } - auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); - return function_tys_ - .emplace(map_key, mod_->add_to(spv::section::type_const_var, void_ty, - std::move(params))) - ->second; -} - -void spirv_converter::operator()(inst_node const &in) { - // @todo - throw compilation_error(in.loc(), status::not_implemented); -} - -void spirv_converter::operator()(arith_inst const &in) { - auto const make_boolean = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (op) { - case arithmetic::and_: - return mod_->add(ty, a, b); - case arithmetic::or_: - return mod_->add(ty, a, b); - case arithmetic::xor_: - return mod_->add(ty, a, b); - default: - break; - } - throw compilation_error(in.loc(), status::ir_boolean_unsupported); - }; - auto const make_int = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (op) { - case arithmetic::add: - return mod_->add(ty, a, b); - case arithmetic::sub: - return mod_->add(ty, a, b); - case arithmetic::mul: - return mod_->add(ty, a, b); - case arithmetic::div: - return mod_->add(ty, a, b); - case arithmetic::rem: - return mod_->add(ty, a, b); - case arithmetic::shl: - return mod_->add(ty, a, b); - case arithmetic::shr: - return mod_->add(ty, a, b); - case arithmetic::and_: - return mod_->add(ty, a, b); - case arithmetic::or_: - return mod_->add(ty, a, b); - case arithmetic::xor_: - return mod_->add(ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const make_float_complex = [&](arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (op) { - case arithmetic::add: - return mod_->add(ty, a, b); - case arithmetic::sub: - return mod_->add(ty, a, b); - case arithmetic::mul: - return mod_->add(ty, a, b); - case arithmetic::div: - return mod_->add(ty, a, b); - case arithmetic::rem: - return mod_->add(ty, a, b); - default: - break; - } - throw compilation_error(in.loc(), status::ir_fp_unsupported); - }; - auto const make = [&](scalar_type sty, arithmetic op, spv::spv_inst *ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (sty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return make_int(op, ty, a, b); - case scalar_type::f32: - case scalar_type::f64: - case scalar_type::c32: - case scalar_type::c64: - return make_float_complex(op, ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - - auto ty = unique_.spv_ty(*in.result(0).ty()); - - if (isa(*in.result(0).ty())) { - auto av = val(in.a()); - auto bv = val(in.b()); - declare(in.result(0), make_boolean(in.operation(), ty, av, bv)); - } else if (auto st = dyn_cast(in.result(0).ty()); st) { - auto av = val(in.a()); - auto bv = val(in.b()); - declare(in.result(0), make(st->ty(), in.operation(), ty, av, bv)); - } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { - auto const length = ct->length(core_cfg_.subgroup_size); - auto insts = std::vector{}; - insts.reserve(length); - - auto &av = multi_val(in.a()); - auto &bv = multi_val(in.b()); - for (std::int64_t i = 0; i < length; ++i) { - insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i], bv[i])); - } - - multi_declare(in.result(0), std::move(insts)); - } else { - throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); - } -} - -void spirv_converter::operator()(arith_unary_inst const &in) { - auto const make_boolean = [&](arithmetic_unary op, spv::spv_inst *ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (op) { - case arithmetic_unary::not_: - return mod_->add(ty, a); - default: - break; - } - throw compilation_error(in.loc(), status::ir_boolean_unsupported); - }; - auto const make_int = [&](arithmetic_unary op, spv::spv_inst *ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (op) { - case arithmetic_unary::abs: - return mod_->add( - ty, unique_.opencl_ext(), static_cast(spv::OpenCLEntrypoint::s_abs), - std::vector{a}); - case arithmetic_unary::neg: - return mod_->add(ty, a); - case arithmetic_unary::not_: - return mod_->add(ty, a); - default: - break; - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const make_float = [&](arithmetic_unary op, spv::spv_inst *ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (op) { - case arithmetic_unary::abs: - return mod_->add(ty, unique_.opencl_ext(), - static_cast(spv::OpenCLEntrypoint::fabs), - std::vector{a}); - case arithmetic_unary::neg: - return mod_->add(ty, a); - default: - break; - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const make_complex = [&](arithmetic_unary op, scalar_type sty, spv::spv_inst *ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (op) { - case arithmetic_unary::abs: { - auto spv_a_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, sty)); - auto a2 = mod_->add(spv_a_ty, a, a); - auto a2_0 = - mod_->add(ty, a2, std::vector{0}); - auto a2_1 = - mod_->add(ty, a2, std::vector{1}); - auto a2_0p1 = mod_->add(ty, a2_0, a2_1); - return mod_->add(ty, unique_.opencl_ext(), - static_cast(spv::OpenCLEntrypoint::sqrt), - std::vector{a2_0p1}); - } - case arithmetic_unary::neg: - return mod_->add(ty, a); - case arithmetic_unary::conj: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(sty))); - auto a_im = mod_->add(spv_float_ty, a, - std::vector{1}); - auto neg_a_im = mod_->add(spv_float_ty, a_im); - return mod_->add(ty, neg_a_im, a, - std::vector{1}); - } - case arithmetic_unary::im: - return mod_->add(ty, a, std::vector{1}); - case arithmetic_unary::re: - return mod_->add(ty, a, std::vector{0}); - default: - break; - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const make = [&](scalar_type sty, arithmetic_unary op, spv::spv_inst *ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (sty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return make_int(op, ty, a); - case scalar_type::f32: - case scalar_type::f64: - return make_float(op, ty, a); - case scalar_type::c32: - case scalar_type::c64: { - return make_complex(op, sty, ty, a); - } - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - - auto ty = unique_.spv_ty(*in.result(0).ty()); - if (isa(*in.a().ty())) { - auto av = val(in.a()); - declare(in.result(0), make_boolean(in.operation(), ty, av)); - } else if (auto st = dyn_cast(in.a().ty()); st) { - auto av = val(in.a()); - declare(in.result(0), make(st->ty(), in.operation(), ty, av)); - } else if (auto ct = dyn_cast(in.a().ty()); ct) { - auto const length = ct->length(core_cfg_.subgroup_size); - auto insts = std::vector{}; - insts.reserve(length); - - auto &av = multi_val(in.a()); - for (std::int64_t i = 0; i < length; ++i) { - insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i])); - } - - multi_declare(in.result(0), std::move(insts)); - } else { - throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); - } -} - -void spirv_converter::operator()(barrier_inst const &in) { - std::int32_t fence = 0; - if (in.has_fence(address_space::global)) { - fence = fence | static_cast(spv::MemorySemantics::CrossWorkgroupMemory) | - static_cast(spv::MemorySemantics::SequentiallyConsistent); - } - if (in.has_fence(address_space::local)) { - fence = fence | static_cast(spv::MemorySemantics::WorkgroupMemory) | - static_cast(spv::MemorySemantics::SequentiallyConsistent); - } - auto scope = unique_.i32_constant(static_cast(spv::Scope::Workgroup)); - auto memory_semantics = unique_.i32_constant(fence); - mod_->add(scope, scope, memory_semantics); -} - -void spirv_converter::operator()(cast_inst const &in) { - auto const cast_from_int = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (to_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return mod_->add(spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return mod_->add(spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); - auto re = mod_->add(spv_float_ty, a); - return mod_->add(spv_to_ty, re, - unique_.null_constant(spv_to_ty), - std::vector{0}); - } - } - throw compilation_error(in.loc(), status::ir_forbidden_cast); - }; - auto const cast_from_float = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (to_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return mod_->add(spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return mod_->add(spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); - auto re = mod_->add(spv_float_ty, a); - return mod_->add(spv_to_ty, re, - unique_.null_constant(spv_to_ty), - std::vector{0}); - } - } - throw compilation_error(in.loc(), status::ir_forbidden_cast); - }; - auto const cast_from_complex = [&](scalar_type to_ty, spv::spv_inst *spv_to_ty, - spv::spv_inst *a) -> spv::spv_inst * { - switch (to_ty) { - case scalar_type::c32: - case scalar_type::c64: - return mod_->add(spv_to_ty, a); - default: - throw compilation_error(in.loc(), status::ir_forbidden_cast); - } - }; - auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv::spv_inst *spv_to_ty, - spv::spv_inst *a) -> spv::spv_inst * { - if (a_ty == to_ty) { - return mod_->add(spv_to_ty, a); - } - switch (a_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return cast_from_int(to_ty, spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return cast_from_float(to_ty, spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - return cast_from_complex(to_ty, spv_to_ty, a); - } - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - - auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); - - if (auto st = dyn_cast(in.result(0).ty()); st) { - auto av = val(in.a()); - auto a_ty = get_scalar_type(in.a()); - declare(in.result(0), make(st->ty(), a_ty, spv_to_ty, av)); - } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { - auto const length = ct->length(core_cfg_.subgroup_size); - auto insts = std::vector{}; - insts.reserve(length); - - auto &av = multi_val(in.a()); - auto a_ty = get_coopmatrix_type(in.a()); - for (std::int64_t i = 0; i < length; ++i) { - insts.emplace_back(make(ct->component_ty(), a_ty, spv_to_ty, av[i])); - } - - multi_declare(in.result(0), std::move(insts)); - } else { - throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); - } -} - -void spirv_converter::operator()(compare_inst const &in) { - auto const compare_int = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (cond) { - case cmp_condition::eq: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::ne: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::gt: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::ge: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::lt: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::le: - return mod_->add(spv_to_ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const compare_float = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (cond) { - case cmp_condition::eq: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::ne: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::gt: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::ge: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::lt: - return mod_->add(spv_to_ty, a, b); - case cmp_condition::le: - return mod_->add(spv_to_ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const compare_complex = [&](cmp_condition cond, spv::spv_inst *spv_to_ty, spv::spv_inst *a, - spv::spv_inst *b) -> spv::spv_inst * { - switch (cond) { - case cmp_condition::eq: { - auto components_equal = mod_->add(unique_.bool2_ty(), a, b); - return mod_->add(spv_to_ty, components_equal); - } - case cmp_condition::ne: { - auto components_not_equal = mod_->add(unique_.bool2_ty(), a, b); - return mod_->add(spv_to_ty, components_not_equal); - } - default: - throw compilation_error(in.loc(), status::ir_complex_unsupported); - } - }; - auto const make = [&](scalar_type a_ty, cmp_condition cond, spv::spv_inst *spv_to_ty, - spv::spv_inst *a, spv::spv_inst *b) -> spv::spv_inst * { - switch (a_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return compare_int(cond, spv_to_ty, a, b); - case scalar_type::f32: - case scalar_type::f64: - return compare_float(cond, spv_to_ty, a, b); - case scalar_type::c32: - case scalar_type::c64: - return compare_complex(cond, spv_to_ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - - auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); - auto av = val(in.a()); - auto bv = val(in.b()); - auto a_ty = get_scalar_type(in.a()); - declare(in.result(0), make(a_ty, in.cond(), spv_to_ty, av, bv)); -} - -void spirv_converter::operator()(constant_inst const &in) { - auto const make = [&](scalar_type sty, spv::spv_inst *spv_ty, - constant_inst::value_type const &val) -> spv::spv_inst * { - auto const add_constant = [this, &spv_ty](auto val) -> spv::spv_inst * { - return mod_->add_to(spv::section::type_const_var, spv_ty, val); - }; - auto const add_constant_complex = [this, &spv_ty](spv::spv_inst *spv_float_ty, auto re, - auto im) -> spv::spv_inst * { - auto c_re = - mod_->add_to(spv::section::type_const_var, spv_float_ty, re); - auto c_im = - mod_->add_to(spv::section::type_const_var, spv_float_ty, im); - return mod_->add_to(spv::section::type_const_var, spv_ty, - std::vector{c_re, c_im}); - }; - const auto visitor = overloaded{ - [&](bool) -> spv::spv_inst * { return nullptr; }, - [&](std::int64_t i) -> spv::spv_inst * { - switch (sty) { - case scalar_type::i8: - return add_constant(static_cast(i)); - case scalar_type::i16: - return add_constant(static_cast(i)); - case scalar_type::i32: - return add_constant(static_cast(i)); - case scalar_type::i64: - case scalar_type::index: - return add_constant(i); - default: - return nullptr; - } - }, - [&](double d) -> spv::spv_inst * { - switch (sty) { - case scalar_type::f32: - return add_constant(static_cast(d)); - case scalar_type::f64: - return add_constant(d); - default: - return nullptr; - } - }, - [&](std::complex d) -> spv::spv_inst * { - switch (sty) { - case scalar_type::c32: { - auto spv_float_ty = - unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); - return add_constant_complex(spv_float_ty, static_cast(d.real()), - static_cast(d.imag())); - } - case scalar_type::c64: { - auto spv_float_ty = - unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); - return add_constant_complex(spv_float_ty, d.real(), d.imag()); - } - default: - return nullptr; - } - }, - }; - auto cst = std::visit(visitor, val); - if (cst == nullptr) { - throw compilation_error(in.loc(), status::internal_compiler_error); - } - return cst; - }; - - auto spv_ty = unique_.spv_ty(*in.result(0).ty()); - - if (isa(*in.result(0).ty())) { - if (!std::holds_alternative(in.value())) { - throw compilation_error(in.loc(), status::internal_compiler_error); - } - declare(in.result(0), unique_.bool_constant(std::get(in.value()))); - } else if (auto st = dyn_cast(in.result(0).ty()); st) { - declare(in.result(0), make(st->ty(), spv_ty, in.value())); - } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { - auto const length = ct->length(core_cfg_.subgroup_size); - auto cst = make(ct->component_ty(), spv_ty, in.value()); - - multi_declare(in.result(0), std::vector(length, cst)); - } else { - throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); - } -} - -void spirv_converter::operator()(group_id_inst const &in) {} -void spirv_converter::operator()(group_size_inst const &in) {} -void spirv_converter::operator()(num_subgroups_inst const &in) {} -void spirv_converter::operator()(subgroup_id_inst const &in) {} -void spirv_converter::operator()(subgroup_local_id_inst const &in) {} -void spirv_converter::operator()(subgroup_size_inst const &in) {} - -void spirv_converter::run_on_region(region_node const ®) { - mod_->add(); - for (auto const &i : reg) { - visit(*this, i); - } - mod_->add(); -} - -void spirv_converter::run_on_function(function_node const &fn) { - auto const subgroup_size = fn.subgroup_size(); - try { - core_cfg_ = info_->get_core_config(subgroup_size); - } catch (std::out_of_range const &e) { - throw compilation_error(fn.loc(), status::unsupported_subgroup_size); - } - - // Function type - auto fun_ty = declare_function_type([&] { - auto params = std::vector{}; - params.reserve(fn.num_params()); - for (auto const &p : fn.params()) { - params.push_back(unique_.spv_ty(*p.ty())); - } - return params; - }()); - - // Function - auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); - auto fun = mod_->add(void_ty, spv::FunctionControl::None, fun_ty); - for (auto const &p : fn.params()) { - declare(p, mod_->add(unique_.spv_ty(*p.ty()))); - } - run_on_region(fn.body()); - mod_->add(); - - // Entry point - mod_->add_to(spv::section::entry_point, spv::ExecutionModel::Kernel, fun, - std::string{fn.name()}, std::vector{}); - - // Execution mode - auto const work_group_size = fn.work_group_size(); - mod_->add_to(spv::section::execution_mode, fun, - spv::ExecutionMode::LocalSize, - spv::ExecutionModeAttr{std::array{ - work_group_size[0], work_group_size[1], 1}}); - mod_->add_to(spv::section::execution_mode, fun, - spv::ExecutionMode::SubgroupSize, - spv::ExecutionModeAttr{fn.subgroup_size()}); -} - -void spirv_converter::run_on_program(program_node const &p) { - unique_.capability(spv::Capability::Addresses); - unique_.capability(spv::Capability::Kernel); - unique_.capability(spv::Capability::SubgroupDispatch); - - mod_->add_to(spv::section::memory_model, spv::AddressingModel::Physical64, - spv::MemoryModel::OpenCL); - - for (auto const &fn : p) { - run_on_function(fn); - } -} - convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) : info_(std::move(info)) { if (info_ == nullptr) { @@ -717,11 +17,7 @@ convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) } auto convert_to_spirv_pass::run_on_program(program_node const &p) -> std::unique_ptr { - auto m = std::make_unique(); - - spirv_converter(info_, *m, p.context()).run_on_program(p); - - return m; + return spv::convert_prog_to_spirv(p, *info_); } } // namespace tinytc diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp new file mode 100644 index 00000000..90847511 --- /dev/null +++ b/src/spv/converter.cpp @@ -0,0 +1,645 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/converter.hpp" +#include "compiler_context.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/program_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/opencl.std.hpp" +#include "spv/uniquifier.hpp" +#include "support/casting.hpp" +#include "support/ilist_base.hpp" +#include "support/util.hpp" +#include "support/visit.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_prog_to_spirv(tinytc_prog const &p, + tinytc_core_info const &info) -> std::unique_ptr { + auto m = std::make_unique(); + + auto conv = inst_converter{p.context(), *m}; + + conv.unique().capability(Capability::Addresses); + conv.unique().capability(Capability::Kernel); + conv.unique().capability(Capability::SubgroupDispatch); + + m->add_to(section::memory_model, AddressingModel::Physical64, + MemoryModel::OpenCL); + + for (auto const &fn : p) { + try { + conv.run_on_function(fn, info.get_core_config(fn.subgroup_size())); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + } + + return m; +} + +inst_converter::inst_converter(tinytc_compiler_context_t ctx, mod &m) + : ctx_(ctx), mod_(&m), unique_(ctx, m) {} + +auto inst_converter::get_scalar_type(value_node const &v) -> scalar_type { + auto st = dyn_cast(v.ty()); + if (!st) { + throw compilation_error(v.loc(), status::ir_expected_scalar); + } + return st->ty(); +} +auto inst_converter::get_coopmatrix_type(value_node const &v) -> scalar_type { + auto ct = dyn_cast(v.ty()); + if (!ct) { + throw compilation_error(v.loc(), status::ir_expected_coopmatrix); + } + return ct->component_ty(); +} + +auto inst_converter::declare(value_node const &v, spv_inst *in) { vals_[&v] = in; } +auto inst_converter::val(value_node const &v) -> spv_inst * { + if (auto it = vals_.find(&v); it != vals_.end()) { + return it->second; + } + throw compilation_error(v.loc(), status::spirv_undefined_value); +} +auto inst_converter::multi_declare(value_node const &v, std::vector insts) { + multi_vals_[&v] = std::move(insts); +} +auto inst_converter::multi_val(value_node const &v) -> std::vector & { + if (auto it = multi_vals_.find(&v); it != multi_vals_.end()) { + return it->second; + } + throw compilation_error(v.loc(), status::spirv_undefined_value); +} + +void inst_converter::operator()(inst_node const &in) { + // @todo + throw compilation_error(in.loc(), status::not_implemented); +} + +void inst_converter::operator()(arith_inst const &in) { + auto const make_boolean = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::and_: + return mod_->add(ty, a, b); + case arithmetic::or_: + return mod_->add(ty, a, b); + case arithmetic::xor_: + return mod_->add(ty, a, b); + default: + break; + } + throw compilation_error(in.loc(), status::ir_boolean_unsupported); + }; + auto const make_int = [&](arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod_->add(ty, a, b); + case arithmetic::sub: + return mod_->add(ty, a, b); + case arithmetic::mul: + return mod_->add(ty, a, b); + case arithmetic::div: + return mod_->add(ty, a, b); + case arithmetic::rem: + return mod_->add(ty, a, b); + case arithmetic::shl: + return mod_->add(ty, a, b); + case arithmetic::shr: + return mod_->add(ty, a, b); + case arithmetic::and_: + return mod_->add(ty, a, b); + case arithmetic::or_: + return mod_->add(ty, a, b); + case arithmetic::xor_: + return mod_->add(ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make_float_complex = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod_->add(ty, a, b); + case arithmetic::sub: + return mod_->add(ty, a, b); + case arithmetic::mul: + return mod_->add(ty, a, b); + case arithmetic::div: + return mod_->add(ty, a, b); + case arithmetic::rem: + return mod_->add(ty, a, b); + default: + break; + } + throw compilation_error(in.loc(), status::ir_fp_unsupported); + }; + auto const make = [&](scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + case scalar_type::c32: + case scalar_type::c64: + return make_float_complex(op, ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto ty = unique_.spv_ty(*in.result(0).ty()); + + if (isa(*in.result(0).ty())) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make_boolean(in.operation(), ty, av, bv)); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make(st->ty(), in.operation(), ty, av, bv)); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + auto &bv = multi_val(in.b()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i], bv[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(arith_unary_inst const &in) { + auto const make_boolean = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::not_: + return mod_->add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::ir_boolean_unsupported); + }; + auto const make_int = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return mod_->add(ty, unique_.opencl_ext(), + static_cast(OpenCLEntrypoint::s_abs), + std::vector{a}); + case arithmetic_unary::neg: + return mod_->add(ty, a); + case arithmetic_unary::not_: + return mod_->add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make_float = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return mod_->add(ty, unique_.opencl_ext(), + static_cast(OpenCLEntrypoint::fabs), + std::vector{a}); + case arithmetic_unary::neg: + return mod_->add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make_complex = [&](arithmetic_unary op, scalar_type sty, spv_inst *ty, + spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::abs: { + auto spv_a_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, sty)); + auto a2 = mod_->add(spv_a_ty, a, a); + auto a2_0 = mod_->add(ty, a2, std::vector{0}); + auto a2_1 = mod_->add(ty, a2, std::vector{1}); + auto a2_0p1 = mod_->add(ty, a2_0, a2_1); + return mod_->add(ty, unique_.opencl_ext(), + static_cast(OpenCLEntrypoint::sqrt), + std::vector{a2_0p1}); + } + case arithmetic_unary::neg: + return mod_->add(ty, a); + case arithmetic_unary::conj: { + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(sty))); + auto a_im = + mod_->add(spv_float_ty, a, std::vector{1}); + auto neg_a_im = mod_->add(spv_float_ty, a_im); + return mod_->add(ty, neg_a_im, a, std::vector{1}); + } + case arithmetic_unary::im: + return mod_->add(ty, a, std::vector{1}); + case arithmetic_unary::re: + return mod_->add(ty, a, std::vector{0}); + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const make = [&](scalar_type sty, arithmetic_unary op, spv_inst *ty, + spv_inst *a) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a); + case scalar_type::f32: + case scalar_type::f64: + return make_float(op, ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return make_complex(op, sty, ty, a); + } + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto ty = unique_.spv_ty(*in.result(0).ty()); + if (isa(*in.a().ty())) { + auto av = val(in.a()); + declare(in.result(0), make_boolean(in.operation(), ty, av)); + } else if (auto st = dyn_cast(in.a().ty()); st) { + auto av = val(in.a()); + declare(in.result(0), make(st->ty(), in.operation(), ty, av)); + } else if (auto ct = dyn_cast(in.a().ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(barrier_inst const &in) { + std::int32_t fence = 0; + if (in.has_fence(address_space::global)) { + fence = fence | static_cast(MemorySemantics::CrossWorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + if (in.has_fence(address_space::local)) { + fence = fence | static_cast(MemorySemantics::WorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + auto scope = unique_.i32_constant(static_cast(Scope::Workgroup)); + auto memory_semantics = unique_.i32_constant(fence); + mod_->add(scope, scope, memory_semantics); +} + +void inst_converter::operator()(cast_inst const &in) { + auto const cast_from_int = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return mod_->add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(in.loc(), status::ir_forbidden_cast); + }; + auto const cast_from_float = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return mod_->add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(in.loc(), status::ir_forbidden_cast); + }; + auto const cast_from_complex = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::c32: + case scalar_type::c64: + return mod_->add(spv_to_ty, a); + default: + throw compilation_error(in.loc(), status::ir_forbidden_cast); + } + }; + auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + if (a_ty == to_ty) { + return mod_->add(spv_to_ty, a); + } + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return cast_from_int(to_ty, spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return cast_from_float(to_ty, spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return cast_from_complex(to_ty, spv_to_ty, a); + } + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); + + if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make(st->ty(), a_ty, spv_to_ty, av)); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto insts = std::vector{}; + insts.reserve(length); + + auto &av = multi_val(in.a()); + auto a_ty = get_coopmatrix_type(in.a()); + for (std::int64_t i = 0; i < length; ++i) { + insts.emplace_back(make(ct->component_ty(), a_ty, spv_to_ty, av[i])); + } + + multi_declare(in.result(0), std::move(insts)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(compare_inst const &in) { + auto const compare_int = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ne: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::gt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ge: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::lt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::le: + return mod_->add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_float = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ne: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::gt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ge: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::lt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::le: + return mod_->add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_complex = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: { + auto components_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_equal); + } + case cmp_condition::ne: { + auto components_not_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_not_equal); + } + default: + throw compilation_error(in.loc(), status::ir_complex_unsupported); + } + }; + auto const make = [&](scalar_type a_ty, cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return compare_int(cond, spv_to_ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + return compare_float(cond, spv_to_ty, a, b); + case scalar_type::c32: + case scalar_type::c64: + return compare_complex(cond, spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); + auto av = val(in.a()); + auto bv = val(in.b()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make(a_ty, in.cond(), spv_to_ty, av, bv)); +} + +void inst_converter::operator()(constant_inst const &in) { + auto const make = [&](scalar_type sty, spv_inst *spv_ty, + constant_inst::value_type const &val) -> spv_inst * { + auto const add_constant = [this, &spv_ty](auto val) -> spv_inst * { + return mod_->add_to(section::type_const_var, spv_ty, val); + }; + auto const add_constant_complex = [this, &spv_ty](spv_inst *spv_float_ty, auto re, + auto im) -> spv_inst * { + auto c_re = mod_->add_to(section::type_const_var, spv_float_ty, re); + auto c_im = mod_->add_to(section::type_const_var, spv_float_ty, im); + return mod_->add_to(section::type_const_var, spv_ty, + std::vector{c_re, c_im}); + }; + const auto visitor = overloaded{ + [&](bool) -> spv_inst * { return nullptr; }, + [&](std::int64_t i) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + return add_constant(static_cast(i)); + case scalar_type::i16: + return add_constant(static_cast(i)); + case scalar_type::i32: + return add_constant(static_cast(i)); + case scalar_type::i64: + case scalar_type::index: + return add_constant(i); + default: + return nullptr; + } + }, + [&](double d) -> spv_inst * { + switch (sty) { + case scalar_type::f32: + return add_constant(static_cast(d)); + case scalar_type::f64: + return add_constant(d); + default: + return nullptr; + } + }, + [&](std::complex d) -> spv_inst * { + switch (sty) { + case scalar_type::c32: { + auto spv_float_ty = + unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); + return add_constant_complex(spv_float_ty, static_cast(d.real()), + static_cast(d.imag())); + } + case scalar_type::c64: { + auto spv_float_ty = + unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); + return add_constant_complex(spv_float_ty, d.real(), d.imag()); + } + default: + return nullptr; + } + }, + }; + auto cst = std::visit(visitor, val); + if (cst == nullptr) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + return cst; + }; + + auto spv_ty = unique_.spv_ty(*in.result(0).ty()); + + if (isa(*in.result(0).ty())) { + if (!std::holds_alternative(in.value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + declare(in.result(0), unique_.bool_constant(std::get(in.value()))); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { + declare(in.result(0), make(st->ty(), spv_ty, in.value())); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto const length = ct->length(core_cfg_.subgroup_size); + auto cst = make(ct->component_ty(), spv_ty, in.value()); + + multi_declare(in.result(0), std::vector(length, cst)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(group_id_inst const &in) {} +void inst_converter::operator()(group_size_inst const &in) {} +void inst_converter::operator()(num_subgroups_inst const &in) {} +void inst_converter::operator()(subgroup_id_inst const &in) {} +void inst_converter::operator()(subgroup_local_id_inst const &in) {} +void inst_converter::operator()(subgroup_size_inst const &in) {} + +void inst_converter::run_on_region(region_node const ®) { + mod_->add(); + for (auto const &i : reg) { + visit(*this, i); + } + mod_->add(); +} + +void inst_converter::run_on_function(function_node const &fn, core_config const &core_cfg) { + core_cfg_ = core_cfg; + + // Function type + auto fun_ty = unique_.spv_function_ty([&] { + auto params = std::vector{}; + params.reserve(fn.num_params()); + for (auto const &p : fn.params()) { + params.emplace_back(unique_.spv_ty(*p.ty())); + } + return params; + }()); + + // Function + auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); + auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); + for (auto const &p : fn.params()) { + declare(p, mod_->add(unique_.spv_ty(*p.ty()))); + } + run_on_region(fn.body()); + mod_->add(); + + // Entry point + mod_->add_to(section::entry_point, ExecutionModel::Kernel, fun, + std::string{fn.name()}, std::vector{}); + + // Execution mode + auto const work_group_size = fn.work_group_size(); + mod_->add_to( + section::execution_mode, fun, ExecutionMode::LocalSize, + ExecutionModeAttr{std::array{work_group_size[0], work_group_size[1], 1}}); + mod_->add_to(section::execution_mode, fun, ExecutionMode::SubgroupSize, + ExecutionModeAttr{fn.subgroup_size()}); +} + +} // namespace tinytc::spv + diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp new file mode 100644 index 00000000..9855a7dd --- /dev/null +++ b/src/spv/converter.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context.hpp" +#include "device_info.hpp" +#include "node/inst_node.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst; + +auto convert_prog_to_spirv(tinytc_prog const &p, + tinytc_core_info const &info) -> std::unique_ptr; + +class inst_converter { + public: + inst_converter(tinytc_compiler_context_t ctx, mod &m); + + // Instruction nodes + void operator()(inst_node const &in); + void operator()(arith_inst const &in); + void operator()(arith_unary_inst const &in); + void operator()(barrier_inst const &in); + void operator()(cast_inst const &in); + void operator()(compare_inst const &in); + void operator()(constant_inst const &in); + void operator()(group_id_inst const &in); + void operator()(group_size_inst const &in); + void operator()(num_subgroups_inst const &in); + void operator()(subgroup_id_inst const &in); + void operator()(subgroup_local_id_inst const &in); + void operator()(subgroup_size_inst const &in); + + void run_on_region(tinytc_region const ®); + void run_on_function(tinytc_func const &fn, core_config const &core_cfg); + + inline auto unique() -> uniquifier & { return unique_; } + + private: + auto get_scalar_type(tinytc_value const &v) -> scalar_type; + auto get_coopmatrix_type(tinytc_value const &v) -> scalar_type; + auto declare(tinytc_value const &v, spv_inst *in); + auto val(tinytc_value const &v) -> spv_inst *; + auto multi_declare(tinytc_value const &v, std::vector insts); + auto multi_val(tinytc_value const &v) -> std::vector &; + + tinytc_compiler_context_t ctx_; + mod *mod_; + uniquifier unique_; + std::unordered_map vals_; + std::unordered_map> multi_vals_; + core_config core_cfg_ = {}; +}; + +} // namespace tinytc::spv diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index a3191c6f..16d25e2a 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -6,141 +6,157 @@ #include "node/data_type_node.hpp" #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" +#include "support/fnv1a.hpp" +#include "support/fnv1a_array_view.hpp" #include "support/visit.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include +#include namespace tinytc::spv { -uniquifier::uniquifier(tinytc_compiler_context_t ctx, spv::mod &m) : ctx_(ctx), mod_(&m) {} +uniquifier::uniquifier(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m) {} -auto uniquifier::bool2_ty() -> spv::spv_inst * { +auto uniquifier::bool2_ty() -> spv_inst * { if (!bool2_ty_) { auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool2_ty_ = mod_->add_to(spv::section::type_const_var, bool_ty, 2); + bool2_ty_ = mod_->add_to(section::type_const_var, bool_ty, 2); } return bool2_ty_; } -auto uniquifier::bool_constant(bool b) -> spv::spv_inst * { +auto uniquifier::bool_constant(bool b) -> spv_inst * { if (b) { if (!bool_true_) { auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool_true_ = mod_->add_to(spv::section::type_const_var, bool_ty); + bool_true_ = mod_->add_to(section::type_const_var, bool_ty); } return bool_true_; } if (!bool_false_) { auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool_false_ = mod_->add_to(spv::section::type_const_var, bool_ty); + bool_false_ = mod_->add_to(section::type_const_var, bool_ty); } return bool_false_; } -// inline auto builtin(spv::BuiltIn b) -> spv::spv_inst* { +// inline auto builtin(BuiltIn b) -> spv_inst* { // auto it = builtin_.find(b); // if (it == builtin_.end()) { -// auto var = add_to(); -// add_to(spv::section::decoration,); +// auto var = add_to(); +// add_to(section::decoration,); // auto i32_ty = visit( *scalar_data_type::get(ctx_, scalar_type::i32)); -// auto cst_inst = add_to(spv::section::type_const_var, -// i32_ty, spv::LiteralContextDependentNumber{cst}); +// auto cst_inst = add_to(section::type_const_var, +// i32_ty, LiteralContextDependentNumber{cst}); // i32_cst_[cst] = cst_inst; // return cst_inst; //} // return it->second; //} -void uniquifier::capability(spv::Capability cap) { +void uniquifier::capability(Capability cap) { if (!capabilities_.contains(cap)) { - mod_->add_to(spv::section::capability, cap); + mod_->add_to(section::capability, cap); capabilities_.insert(cap); } } -auto uniquifier::i32_constant(std::int32_t cst) -> spv::spv_inst * { +auto uniquifier::i32_constant(std::int32_t cst) -> spv_inst * { auto it = i32_cst_.find(cst); if (it == i32_cst_.end()) { auto i32_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::i32)); - auto cst_inst = mod_->add_to(spv::section::type_const_var, i32_ty, - spv::LiteralContextDependentNumber{cst}); + auto cst_inst = mod_->add_to(section::type_const_var, i32_ty, + LiteralContextDependentNumber{cst}); i32_cst_[cst] = cst_inst; return cst_inst; } return it->second; } -auto uniquifier::null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst * { +auto uniquifier::null_constant(spv_inst *spv_ty) -> spv_inst * { auto it = null_cst_.find(spv_ty); if (it == null_cst_.end()) { - auto in = mod_->add_to(spv::section::type_const_var, spv_ty); + auto in = mod_->add_to(section::type_const_var, spv_ty); null_cst_[spv_ty] = in; return in; } return it->second; } -auto uniquifier::opencl_ext() -> spv::spv_inst * { +auto uniquifier::opencl_ext() -> spv_inst * { if (opencl_ext_ == nullptr) { - opencl_ext_ = mod_->add_to(spv::section::ext_inst, spv::OpenCLExt); + opencl_ext_ = mod_->add_to(section::ext_inst, OpenCLExt); } return opencl_ext_; } -auto uniquifier::spv_ty(data_type_node const &ty) -> spv::spv_inst * { +auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { + const auto map_key = fnv1a_step(fnv1a0(), params); + auto range = spv_function_tys_.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (std::equal(params.begin(), params.end(), it->second->op1().begin(), + it->second->op1().end())) { + return it->second; + } + } + auto void_ty = spv_ty(*void_data_type::get(ctx_)); + return spv_function_tys_ + .emplace(map_key, + mod_->add_to(section::type_const_var, void_ty, std::move(params))) + ->second; +} + +auto uniquifier::spv_ty(data_type_node const &ty) -> spv_inst * { auto it = spv_tys_.find(&ty); if (it == spv_tys_.end()) { auto spv_ty_inst = visit( overloaded{ - [&](void_data_type const &) -> spv::spv_inst * { - return mod_->add_to(spv::section::type_const_var); + [&](void_data_type const &) -> spv_inst * { + return mod_->add_to(section::type_const_var); }, - [&](boolean_data_type const &) -> spv::spv_inst * { - return mod_->add_to(spv::section::type_const_var); + [&](boolean_data_type const &) -> spv_inst * { + return mod_->add_to(section::type_const_var); }, - [&](scalar_data_type const &ty) -> spv::spv_inst * { + [&](scalar_data_type const &ty) -> spv_inst * { switch (ty.ty()) { case scalar_type::i8: - capability(spv::Capability::Int8); - return mod_->add_to(spv::section::type_const_var, 8, 0); + capability(Capability::Int8); + return mod_->add_to(section::type_const_var, 8, 0); case scalar_type::i16: - capability(spv::Capability::Int16); - return mod_->add_to(spv::section::type_const_var, 16, 0); + capability(Capability::Int16); + return mod_->add_to(section::type_const_var, 16, 0); case scalar_type::i32: - return mod_->add_to(spv::section::type_const_var, 32, 0); + return mod_->add_to(section::type_const_var, 32, 0); case scalar_type::i64: - capability(spv::Capability::Int64); - return mod_->add_to(spv::section::type_const_var, 64, 0); + capability(Capability::Int64); + return mod_->add_to(section::type_const_var, 64, 0); case scalar_type::index: { const auto sz = size(ty.ty()); if (sz == 8) { - capability(spv::Capability::Int64); + capability(Capability::Int64); } - return mod_->add_to(spv::section::type_const_var, sz * 8, - 0); + return mod_->add_to(section::type_const_var, sz * 8, 0); } case scalar_type::f32: case scalar_type::f64: - capability(spv::Capability::Float64); - return mod_->add_to(spv::section::type_const_var, - size(ty.ty()) * 8, std::nullopt); + capability(Capability::Float64); + return mod_->add_to(section::type_const_var, size(ty.ty()) * 8, + std::nullopt); case scalar_type::c32: { auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); - return mod_->add_to(spv::section::type_const_var, - float_ty, 2); + return mod_->add_to(section::type_const_var, float_ty, 2); } case scalar_type::c64: { auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); - return mod_->add_to(spv::section::type_const_var, - float_ty, 2); + return mod_->add_to(section::type_const_var, float_ty, 2); } } throw status::internal_compiler_error; }, - [&](coopmatrix_data_type const &ty) -> spv::spv_inst * { return spv_ty(*ty.ty()); }, - [](auto const &) -> spv::spv_inst * { + [&](coopmatrix_data_type const &ty) -> spv_inst * { return spv_ty(*ty.ty()); }, + [](auto const &) -> spv_inst * { // @todo throw status::not_implemented; }}, diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index b2381706..320ad35c 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -6,6 +6,7 @@ #include "spv/enums.hpp" #include "spv/module.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include @@ -15,30 +16,33 @@ namespace tinytc::spv { class spv_inst; +class OpTypeFunction; class uniquifier { public: - uniquifier(tinytc_compiler_context_t ctx, spv::mod &m); + uniquifier(tinytc_compiler_context_t ctx, mod &m); - auto bool2_ty() -> spv::spv_inst *; - auto bool_constant(bool b) -> spv::spv_inst *; - void capability(spv::Capability cap); - auto i32_constant(std::int32_t cst) -> spv::spv_inst *; - auto null_constant(spv::spv_inst *spv_ty) -> spv::spv_inst *; - auto opencl_ext() -> spv::spv_inst *; - auto spv_ty(tinytc_data_type const &ty) -> spv::spv_inst *; + auto bool2_ty() -> spv_inst *; + auto bool_constant(bool b) -> spv_inst *; + void capability(Capability cap); + auto i32_constant(std::int32_t cst) -> spv_inst *; + auto null_constant(spv_inst *spv_ty) -> spv_inst *; + auto opencl_ext() -> spv_inst *; + auto spv_function_ty(array_view params) -> spv_inst *; + auto spv_ty(tinytc_data_type const &ty) -> spv_inst *; private: tinytc_compiler_context_t ctx_; - spv::mod *mod_; - spv::spv_inst *bool2_ty_ = nullptr; - spv::spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; - spv::spv_inst *opencl_ext_ = nullptr; - std::unordered_map builtin_; - std::unordered_set capabilities_; - std::unordered_map i32_cst_; - std::unordered_map null_cst_; - std::unordered_map spv_tys_; + mod *mod_; + spv_inst *bool2_ty_ = nullptr; + spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + spv_inst *opencl_ext_ = nullptr; + std::unordered_map builtin_; + std::unordered_set capabilities_; + std::unordered_map i32_cst_; + std::unordered_map null_cst_; + std::unordered_multimap spv_function_tys_; + std::unordered_map spv_tys_; }; } // namespace tinytc::spv From a13f454c0df437e04602b2685ada1767d5ac6cb5 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 7 Nov 2024 17:37:16 +0100 Subject: [PATCH 093/297] SPIR-V: Add builtin variables Signed-off-by: Carsten Uphoff --- src/scalar_type.cpp | 5 ++ src/scalar_type.hpp | 3 + src/spv/converter.cpp | 78 ++++++++++++----- src/spv/converter.hpp | 3 + src/spv/enums.hpp | 6 +- src/spv/instructions.hpp | 153 +++++++++++++++++++------------- src/spv/names.hpp | 6 +- src/spv/pass/dump_asm.cpp | 11 ++- src/spv/uniquifier.cpp | 174 ++++++++++++++++++++++++------------- src/spv/uniquifier.hpp | 35 +++++++- src/spv/visit.hpp | 34 +++++++- test/CMakeLists.txt | 1 + test/spv/builtin.ir | 50 +++++++++++ tools/spirvgen/spirvgen.py | 25 +++++- 14 files changed, 418 insertions(+), 166 deletions(-) create mode 100644 test/spv/builtin.ir diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index 121d43fd..374202a0 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -65,6 +65,11 @@ scalar_type compatible_type(scalar_type a_ty, scalar_type b_ty) { return enum_cast(max); } +std::int32_t alignment(scalar_type ty, component_count count) { + const std::int32_t scale = count == component_count::v3 ? 4 : static_cast(count); + return scale * tinytc_scalar_type_size(static_cast(ty)); +} + clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { return to_clir_ty(ty, 1, as, q); } diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index 47fe0b97..cca70e64 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -15,11 +15,14 @@ namespace tinytc { using host_index_type = std::int64_t; +enum class component_count { v1 = 1, v2 = 2, v3 = 3, v4 = 4, v8 = 8, v16 = 16 }; + bool is_floating_type(scalar_type ty); bool is_complex_type(scalar_type ty); bool is_integer_type(scalar_type ty); scalar_type element_type(scalar_type ty); scalar_type compatible_type(scalar_type a_ty, scalar_type b_ty); +std::int32_t alignment(scalar_type ty, component_count count = component_count::v1); clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, clir::type_qualifier q = clir::type_qualifier::none); clir::data_type to_clir_ty(scalar_type ty, short size, diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 90847511..61903f1f 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -76,6 +76,17 @@ auto inst_converter::get_coopmatrix_type(value_node const &v) -> scalar_type { return ct->component_ty(); } +auto inst_converter::load_builtin(BuiltIn b) -> spv_inst * { + auto builtin = unique_.builtin_var(b); + if (auto it = std::find(builtins_used_by_function_.begin(), builtins_used_by_function_.end(), + builtin); + it == builtins_used_by_function_.end()) { + builtins_used_by_function_.push_back(builtin); + } + return mod_->add(unique_.builtin_pointee_ty(b), builtin, MemoryAccess::Aligned, + unique_.builtin_alignment(b)); +} + auto inst_converter::declare(value_node const &v, spv_inst *in) { vals_[&v] = in; } auto inst_converter::val(value_node const &v) -> spv_inst * { if (auto it = vals_.find(&v); it != vals_.end()) { @@ -174,7 +185,7 @@ void inst_converter::operator()(arith_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto ty = unique_.spv_ty(*in.result(0).ty()); + auto ty = unique_.spv_ty(in.result(0).ty()); if (isa(*in.result(0).ty())) { auto av = val(in.a()); @@ -243,7 +254,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { spv_inst *a) -> spv_inst * { switch (op) { case arithmetic_unary::abs: { - auto spv_a_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, sty)); + auto spv_a_ty = unique_.spv_ty(scalar_data_type::get(ctx_, sty)); auto a2 = mod_->add(spv_a_ty, a, a); auto a2_0 = mod_->add(ty, a2, std::vector{0}); auto a2_1 = mod_->add(ty, a2, std::vector{1}); @@ -255,7 +266,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { case arithmetic_unary::neg: return mod_->add(ty, a); case arithmetic_unary::conj: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(sty))); + auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(sty))); auto a_im = mod_->add(spv_float_ty, a, std::vector{1}); auto neg_a_im = mod_->add(spv_float_ty, a_im); @@ -290,7 +301,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto ty = unique_.spv_ty(*in.result(0).ty()); + auto ty = unique_.spv_ty(in.result(0).ty()); if (isa(*in.a().ty())) { auto av = val(in.a()); declare(in.result(0), make_boolean(in.operation(), ty, av)); @@ -343,7 +354,7 @@ void inst_converter::operator()(cast_inst const &in) { return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(to_ty))); auto re = mod_->add(spv_float_ty, a); return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); @@ -365,7 +376,7 @@ void inst_converter::operator()(cast_inst const &in) { return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(*scalar_data_type::get(ctx_, element_type(to_ty))); + auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(to_ty))); auto re = mod_->add(spv_float_ty, a); return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); @@ -406,7 +417,7 @@ void inst_converter::operator()(cast_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); + auto spv_to_ty = unique_.spv_ty(in.result(0).ty()); if (auto st = dyn_cast(in.result(0).ty()); st) { auto av = val(in.a()); @@ -500,7 +511,7 @@ void inst_converter::operator()(compare_inst const &in) { throw compilation_error(in.loc(), status::internal_compiler_error); }; - auto spv_to_ty = unique_.spv_ty(*in.result(0).ty()); + auto spv_to_ty = unique_.spv_ty(in.result(0).ty()); auto av = val(in.a()); auto bv = val(in.b()); auto a_ty = get_scalar_type(in.a()); @@ -551,13 +562,13 @@ void inst_converter::operator()(constant_inst const &in) { switch (sty) { case scalar_type::c32: { auto spv_float_ty = - unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); + unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); return add_constant_complex(spv_float_ty, static_cast(d.real()), static_cast(d.imag())); } case scalar_type::c64: { auto spv_float_ty = - unique_.spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); + unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); return add_constant_complex(spv_float_ty, d.real(), d.imag()); } default: @@ -572,7 +583,7 @@ void inst_converter::operator()(constant_inst const &in) { return cst; }; - auto spv_ty = unique_.spv_ty(*in.result(0).ty()); + auto spv_ty = unique_.spv_ty(in.result(0).ty()); if (isa(*in.result(0).ty())) { if (!std::holds_alternative(in.value())) { @@ -591,19 +602,38 @@ void inst_converter::operator()(constant_inst const &in) { } } -void inst_converter::operator()(group_id_inst const &in) {} -void inst_converter::operator()(group_size_inst const &in) {} -void inst_converter::operator()(num_subgroups_inst const &in) {} -void inst_converter::operator()(subgroup_id_inst const &in) {} -void inst_converter::operator()(subgroup_local_id_inst const &in) {} -void inst_converter::operator()(subgroup_size_inst const &in) {} +void inst_converter::operator()(group_id_inst const &in) { + auto gid = load_builtin(BuiltIn::GlobalInvocationId); + auto index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + declare(in.result(0), + mod_->add(index_ty, gid, std::vector{2})); +} +void inst_converter::operator()(group_size_inst const &in) { + auto gs = load_builtin(BuiltIn::GlobalSize); + auto index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + declare(in.result(0), + mod_->add(index_ty, gs, std::vector{2})); +} +void inst_converter::operator()(num_subgroups_inst const &in) { + declare(in.result(0), load_builtin(BuiltIn::NumSubgroups)); +} + +void inst_converter::operator()(parallel_inst const &in) { run_on_region(in.body()); } + +void inst_converter::operator()(subgroup_id_inst const &in) { + declare(in.result(0), load_builtin(BuiltIn::SubgroupId)); +} +void inst_converter::operator()(subgroup_local_id_inst const &in) { + declare(in.result(0), load_builtin(BuiltIn::SubgroupLocalInvocationId)); +} +void inst_converter::operator()(subgroup_size_inst const &in) { + declare(in.result(0), load_builtin(BuiltIn::SubgroupSize)); +} void inst_converter::run_on_region(region_node const ®) { - mod_->add(); for (auto const &i : reg) { visit(*this, i); } - mod_->add(); } void inst_converter::run_on_function(function_node const &fn, core_config const &core_cfg) { @@ -614,23 +644,25 @@ void inst_converter::run_on_function(function_node const &fn, core_config const auto params = std::vector{}; params.reserve(fn.num_params()); for (auto const &p : fn.params()) { - params.emplace_back(unique_.spv_ty(*p.ty())); + params.emplace_back(unique_.spv_ty(p.ty())); } return params; }()); // Function - auto void_ty = unique_.spv_ty(*void_data_type::get(ctx_)); + auto void_ty = unique_.spv_ty(void_data_type::get(ctx_)); auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); for (auto const &p : fn.params()) { - declare(p, mod_->add(unique_.spv_ty(*p.ty()))); + declare(p, mod_->add(unique_.spv_ty(p.ty()))); } + mod_->add(); run_on_region(fn.body()); + mod_->add(); mod_->add(); // Entry point mod_->add_to(section::entry_point, ExecutionModel::Kernel, fun, - std::string{fn.name()}, std::vector{}); + std::string{fn.name()}, std::move(builtins_used_by_function_)); // Execution mode auto const work_group_size = fn.work_group_size(); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 9855a7dd..00db815b 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -35,6 +35,7 @@ class inst_converter { void operator()(group_id_inst const &in); void operator()(group_size_inst const &in); void operator()(num_subgroups_inst const &in); + void operator()(parallel_inst const &in); void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); void operator()(subgroup_size_inst const &in); @@ -47,6 +48,7 @@ class inst_converter { private: auto get_scalar_type(tinytc_value const &v) -> scalar_type; auto get_coopmatrix_type(tinytc_value const &v) -> scalar_type; + auto load_builtin(BuiltIn b) -> spv_inst *; auto declare(tinytc_value const &v, spv_inst *in); auto val(tinytc_value const &v) -> spv_inst *; auto multi_declare(tinytc_value const &v, std::vector insts); @@ -57,6 +59,7 @@ class inst_converter { uniquifier unique_; std::unordered_map vals_; std::unordered_map> multi_vals_; + std::vector builtins_used_by_function_; core_config core_cfg_ = {}; }; diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 52d9a706..3969f57b 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_2024115_HPP -#define GENERATED_ENUMS_2024115_HPP +#ifndef GENERATED_ENUMS_2024117_HPP +#define GENERATED_ENUMS_2024117_HPP namespace tinytc::spv { @@ -1422,4 +1422,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_2024115_HPP +#endif // GENERATED_ENUMS_2024117_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index 9a5d6e3c..76541412 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_2024115_HPP -#define GENERATED_INSTRUCTIONS_2024115_HPP +#ifndef GENERATED_INSTRUCTIONS_2024117_HPP +#define GENERATED_INSTRUCTIONS_2024117_HPP #include "enums.hpp" #include "error.hpp" @@ -40,7 +40,7 @@ class spv_inst : public ilist_node { bool has_result_id_; }; -using DecorationAttr = std::variant>; +using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber = std::variant; @@ -51,6 +51,7 @@ using IdResultType = spv_inst *; using IdRef = spv_inst *; using IdScope = spv_inst *; using IdMemorySemantics = spv_inst *; +using MemoryAccessAttr = std::int32_t; using PairIdRefIdRef = std::pair; using PairLiteralIntegerIdRef = std::pair, spv_inst *>; @@ -85,8 +86,8 @@ class OpSourceContinued : public spv_inst { class OpSource : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Source; } - OpSource(SourceLanguage op0, LiteralInteger op1, std::optional op2, - std::optional op3) + OpSource(SourceLanguage op0, LiteralInteger op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) : spv_inst{Op::Source, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto op0() const -> SourceLanguage const & { return op0_; } @@ -278,7 +279,7 @@ class OpTypeInt : public spv_inst { class OpTypeFloat : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFloat; } - OpTypeFloat(LiteralInteger op0, std::optional op1) + OpTypeFloat(LiteralInteger op0, std::optional op1 = std::nullopt) : spv_inst{Op::TypeFloat, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} inline auto op0() const -> LiteralInteger const & { return op0_; } inline auto op1() const -> std::optional const & { return op1_; } @@ -316,7 +317,8 @@ class OpTypeImage : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeImage; } OpTypeImage(IdRef op0, Dim op1, LiteralInteger op2, LiteralInteger op3, LiteralInteger op4, - LiteralInteger op5, ImageFormat op6, std::optional op7) + LiteralInteger op5, ImageFormat op6, + std::optional op7 = std::nullopt) : spv_inst{Op::TypeImage, true}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)) {} @@ -597,7 +599,7 @@ class OpFunctionCall : public spv_inst { class OpVariable : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Variable; } - OpVariable(IdResultType type, StorageClass op0, std::optional op1) + OpVariable(IdResultType type, StorageClass op0, std::optional op1 = std::nullopt) : spv_inst{Op::Variable, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -629,65 +631,77 @@ class OpImageTexelPointer : public spv_inst { class OpLoad : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Load; } - OpLoad(IdResultType type, IdRef op0, std::optional op1) + OpLoad(IdResultType type, IdRef op0, std::optional op1 = std::nullopt, + std::optional op2 = std::nullopt) : spv_inst{Op::Load, true}, type_(std::move(type)), op0_(std::move(op0)), - op1_(std::move(op1)) {} + op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> std::optional const & { return op1_; } + inline auto op2() const -> std::optional const & { return op2_; } private: IdResultType type_; IdRef op0_; std::optional op1_; + std::optional op2_; }; class OpStore : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Store; } - OpStore(IdRef op0, IdRef op1, std::optional op2) + OpStore(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) : spv_inst{Op::Store, false}, op0_(std::move(op0)), op1_(std::move(op1)), - op2_(std::move(op2)) {} + op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> IdRef const & { return op1_; } inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() const -> std::optional const & { return op3_; } private: IdRef op0_; IdRef op1_; std::optional op2_; + std::optional op3_; }; class OpCopyMemory : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemory; } - OpCopyMemory(IdRef op0, IdRef op1, std::optional op2, - std::optional op3) + OpCopyMemory(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) : spv_inst{Op::CopyMemory, false}, op0_(std::move(op0)), op1_(std::move(op1)), - op2_(std::move(op2)), op3_(std::move(op3)) {} + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> IdRef const & { return op1_; } inline auto op2() const -> std::optional const & { return op2_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() const -> std::optional const & { return op4_; } private: IdRef op0_; IdRef op1_; std::optional op2_; std::optional op3_; + std::optional op4_; }; class OpCopyMemorySized : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemorySized; } constexpr static std::array required_capabilities = { Capability::Addresses, Capability::UntypedPointersKHR}; - OpCopyMemorySized(IdRef op0, IdRef op1, IdRef op2, std::optional op3, - std::optional op4) + OpCopyMemorySized(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) : spv_inst{Op::CopyMemorySized, false}, op0_(std::move(op0)), op1_(std::move(op1)), - op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> IdRef const & { return op1_; } inline auto op2() const -> IdRef const & { return op2_; } inline auto op3() const -> std::optional const & { return op3_; } inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() const -> std::optional const & { return op5_; } private: IdRef op0_; @@ -695,6 +709,7 @@ class OpCopyMemorySized : public spv_inst { IdRef op2_; std::optional op3_; std::optional op4_; + std::optional op5_; }; class OpAccessChain : public spv_inst { public: @@ -801,17 +816,17 @@ class OpInBoundsPtrAccessChain : public spv_inst { class OpDecorate : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Decorate; } - OpDecorate(IdRef op0, Decoration op1, DecorationAttr op2) + OpDecorate(IdRef op0, Decoration op1, std::optional op2 = std::nullopt) : spv_inst{Op::Decorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> Decoration const & { return op1_; } - inline auto op2() const -> DecorationAttr const & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } private: IdRef op0_; Decoration op1_; - DecorationAttr op2_; + std::optional op2_; }; class OpMemberDecorate : public spv_inst { public: @@ -999,7 +1014,7 @@ class OpImageSampleImplicitLod : public spv_inst { } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, - std::optional op2) + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1039,7 +1054,7 @@ class OpImageSampleDrefImplicitLod : public spv_inst { } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSampleDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1085,7 +1100,7 @@ class OpImageSampleProjImplicitLod : public spv_inst { } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, - std::optional op2) + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSampleProjImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1126,7 +1141,7 @@ class OpImageSampleProjDrefImplicitLod : public spv_inst { } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSampleProjDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1168,7 +1183,8 @@ class OpImageSampleProjDrefExplicitLod : public spv_inst { class OpImageFetch : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageFetch; } - OpImageFetch(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + OpImageFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) : spv_inst{Op::ImageFetch, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1187,7 +1203,7 @@ class OpImageGather : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageGather; } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1208,7 +1224,7 @@ class OpImageDrefGather : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageDrefGather; } constexpr static std::array required_capabilities = {Capability::Shader}; OpImageDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1227,7 +1243,8 @@ class OpImageDrefGather : public spv_inst { class OpImageRead : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageRead; } - OpImageRead(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + OpImageRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) : spv_inst{Op::ImageRead, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -1244,7 +1261,7 @@ class OpImageRead : public spv_inst { class OpImageWrite : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageWrite; } - OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3) + OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3 = std::nullopt) : spv_inst{Op::ImageWrite, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto op0() const -> IdRef const & { return op0_; } @@ -4185,7 +4202,7 @@ class OpImageSparseSampleImplicitLod : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, - std::optional op2) + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4228,7 +4245,7 @@ class OpImageSparseSampleDrefImplicitLod : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseSampleDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4276,7 +4293,7 @@ class OpImageSparseSampleProjImplicitLod : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, - std::optional op2) + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseSampleProjImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4319,7 +4336,7 @@ class OpImageSparseSampleProjDrefImplicitLod : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseSampleProjDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4364,7 +4381,8 @@ class OpImageSparseFetch : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseFetch; } constexpr static std::array required_capabilities = { Capability::SparseResidency}; - OpImageSparseFetch(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + OpImageSparseFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseFetch, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4384,7 +4402,7 @@ class OpImageSparseGather : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4408,7 +4426,7 @@ class OpImageSparseDrefGather : public spv_inst { constexpr static std::array required_capabilities = { Capability::SparseResidency}; OpImageSparseDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4487,7 +4505,8 @@ class OpImageSparseRead : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseRead; } constexpr static std::array required_capabilities = { Capability::SparseResidency}; - OpImageSparseRead(IdResultType type, IdRef op0, IdRef op1, std::optional op2) + OpImageSparseRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseRead, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -4997,7 +5016,7 @@ class OpGroupNonUniformIAdd : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5020,7 +5039,7 @@ class OpGroupNonUniformFAdd : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5043,7 +5062,7 @@ class OpGroupNonUniformIMul : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformIMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformIMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5066,7 +5085,7 @@ class OpGroupNonUniformFMul : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformFMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5089,7 +5108,7 @@ class OpGroupNonUniformSMin : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformSMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5112,7 +5131,7 @@ class OpGroupNonUniformUMin : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformUMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5135,7 +5154,7 @@ class OpGroupNonUniformFMin : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5158,7 +5177,7 @@ class OpGroupNonUniformSMax : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformSMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5181,7 +5200,7 @@ class OpGroupNonUniformUMax : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformUMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5204,7 +5223,7 @@ class OpGroupNonUniformFMax : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5229,7 +5248,7 @@ class OpGroupNonUniformBitwiseAnd : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformBitwiseAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5254,7 +5273,7 @@ class OpGroupNonUniformBitwiseOr : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformBitwiseOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5279,7 +5298,7 @@ class OpGroupNonUniformBitwiseXor : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformBitwiseXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5304,7 +5323,7 @@ class OpGroupNonUniformLogicalAnd : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformLogicalAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5329,7 +5348,7 @@ class OpGroupNonUniformLogicalOr : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformLogicalOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5354,7 +5373,7 @@ class OpGroupNonUniformLogicalXor : public spv_inst { Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, Capability::GroupNonUniformPartitionedNV}; OpGroupNonUniformLogicalXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5502,15 +5521,19 @@ class OpCooperativeMatrixLoadKHR : public spv_inst { } constexpr static std::array required_capabilities = { Capability::CooperativeMatrixKHR}; - OpCooperativeMatrixLoadKHR(IdResultType type, IdRef op0, IdRef op1, std::optional op2, - std::optional op3) + OpCooperativeMatrixLoadKHR(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) : spv_inst{Op::CooperativeMatrixLoadKHR, true}, type_(std::move(type)), - op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} inline auto type() const -> IdResultType const & { return type_; } inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> IdRef const & { return op1_; } inline auto op2() const -> std::optional const & { return op2_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() const -> std::optional const & { return op4_; } private: IdResultType type_; @@ -5518,6 +5541,7 @@ class OpCooperativeMatrixLoadKHR : public spv_inst { IdRef op1_; std::optional op2_; std::optional op3_; + std::optional op4_; }; class OpCooperativeMatrixStoreKHR : public spv_inst { public: @@ -5526,15 +5550,19 @@ class OpCooperativeMatrixStoreKHR : public spv_inst { } constexpr static std::array required_capabilities = { Capability::CooperativeMatrixKHR}; - OpCooperativeMatrixStoreKHR(IdRef op0, IdRef op1, IdRef op2, std::optional op3, - std::optional op4) + OpCooperativeMatrixStoreKHR(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) : spv_inst{Op::CooperativeMatrixStoreKHR, false}, op0_(std::move(op0)), - op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} inline auto op0() const -> IdRef const & { return op0_; } inline auto op1() const -> IdRef const & { return op1_; } inline auto op2() const -> IdRef const & { return op2_; } inline auto op3() const -> std::optional const & { return op3_; } inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() const -> std::optional const & { return op5_; } private: IdRef op0_; @@ -5542,6 +5570,7 @@ class OpCooperativeMatrixStoreKHR : public spv_inst { IdRef op2_; std::optional op3_; std::optional op4_; + std::optional op5_; }; class OpCooperativeMatrixMulAddKHR : public spv_inst { public: @@ -5551,7 +5580,7 @@ class OpCooperativeMatrixMulAddKHR : public spv_inst { constexpr static std::array required_capabilities = { Capability::CooperativeMatrixKHR}; OpCooperativeMatrixMulAddKHR(IdResultType type, IdRef op0, IdRef op1, IdRef op2, - std::optional op3) + std::optional op3 = std::nullopt) : spv_inst{Op::CooperativeMatrixMulAddKHR, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} inline auto type() const -> IdResultType const & { return type_; } @@ -5587,4 +5616,4 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_2024115_HPP +#endif // GENERATED_INSTRUCTIONS_2024117_HPP diff --git a/src/spv/names.hpp b/src/spv/names.hpp index 3c48c560..af084c81 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_2024115_HPP -#define GENERATED_NAMES_2024115_HPP +#ifndef GENERATED_NAMES_2024117_HPP +#define GENERATED_NAMES_2024117_HPP #include "enums.hpp" @@ -68,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_2024115_HPP +#endif // GENERATED_NAMES_2024117_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 0a318e1f..2dded4ac 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -55,10 +55,11 @@ void dump_asm_pass::pre_visit(spv_inst const &in) { } void dump_asm_pass::operator()(DecorationAttr const &da) { - std::visit(overloaded{[&](std::pair const &a) { - *os_ << " \"" << a.first << '"'; - this->operator()(a.second); - }}, + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair const &a) { + *os_ << " \"" << a.first << '"'; + this->operator()(a.second); + }}, da); } void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { @@ -94,6 +95,8 @@ void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << s->second; } else if (isa(*in)) { *os_ << " %" << declare(in); + } else if (isa(*in)) { + *os_ << " %" << declare(in); } else { throw status::spirv_forbidden_forward_declaration; } diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 16d25e2a..e5a3027e 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -4,16 +4,15 @@ #include "spv/uniquifier.hpp" #include "compiler_context.hpp" #include "node/data_type_node.hpp" +#include "scalar_type.hpp" #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" -#include "support/fnv1a.hpp" #include "support/fnv1a_array_view.hpp" #include "support/visit.hpp" #include "tinytc/types.hpp" #include #include -#include #include namespace tinytc::spv { @@ -21,40 +20,91 @@ namespace tinytc::spv { uniquifier::uniquifier(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m) {} auto uniquifier::bool2_ty() -> spv_inst * { - if (!bool2_ty_) { - auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool2_ty_ = mod_->add_to(section::type_const_var, bool_ty, 2); - } - return bool2_ty_; + return lookup(bool2_ty_, [&] { + auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + return mod_->add_to(section::type_const_var, bool_ty, 2); + }); } auto uniquifier::bool_constant(bool b) -> spv_inst * { if (b) { - if (!bool_true_) { - auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool_true_ = mod_->add_to(section::type_const_var, bool_ty); - } - return bool_true_; + return lookup(bool_true_, [&] { + auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + return mod_->add_to(section::type_const_var, bool_ty); + }); } - if (!bool_false_) { - auto bool_ty = spv_ty(*boolean_data_type::get(ctx_)); - bool_false_ = mod_->add_to(section::type_const_var, bool_ty); + return lookup(bool_false_, [&] { + auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + return mod_->add_to(section::type_const_var, bool_ty); + }); +} + +auto uniquifier::builtin_alignment(BuiltIn b) -> std::int32_t { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case BuiltIn::SubgroupLocalInvocationId: + return alignment(scalar_type::i32); + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return alignment(scalar_type::index); + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: + return alignment(scalar_type::index, component_count::v3); + break; + default: + throw status::internal_compiler_error; } - return bool_false_; } -// inline auto builtin(BuiltIn b) -> spv_inst* { -// auto it = builtin_.find(b); -// if (it == builtin_.end()) { -// auto var = add_to(); -// add_to(section::decoration,); -// auto i32_ty = visit( *scalar_data_type::get(ctx_, scalar_type::i32)); -// auto cst_inst = add_to(section::type_const_var, -// i32_ty, LiteralContextDependentNumber{cst}); -// i32_cst_[cst] = cst_inst; -// return cst_inst; -//} -// return it->second; -//} + +auto uniquifier::builtin_pointee_ty(BuiltIn b) -> spv_inst * { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case BuiltIn::SubgroupLocalInvocationId: + return spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: + return index3_ty(); + break; + default: + throw status::internal_compiler_error; + } +} + +auto uniquifier::builtin_var(BuiltIn b) -> spv_inst * { + return lookup(builtin_, b, [&](BuiltIn b) { + auto pointer_ty = spv_pointer_ty(StorageClass::Input, builtin_pointee_ty(b)); + auto var = mod_->add_to(section::type_const_var, pointer_ty, + StorageClass::Input, std::nullopt); + mod_->add_to(section::decoration, var, Decoration::Constant); + mod_->add_to(section::decoration, var, Decoration::BuiltIn, b); + return var; + }); +} void uniquifier::capability(Capability cap) { if (!capabilities_.contains(cap)) { @@ -64,32 +114,29 @@ void uniquifier::capability(Capability cap) { } auto uniquifier::i32_constant(std::int32_t cst) -> spv_inst * { - auto it = i32_cst_.find(cst); - if (it == i32_cst_.end()) { - auto i32_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::i32)); - auto cst_inst = mod_->add_to(section::type_const_var, i32_ty, - LiteralContextDependentNumber{cst}); - i32_cst_[cst] = cst_inst; - return cst_inst; - } - return it->second; + return lookup(i32_cst_, cst, [&](std::int32_t cst) { + auto i32_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); + return mod_->add_to(section::type_const_var, i32_ty, + LiteralContextDependentNumber{cst}); + }); +} + +auto uniquifier::index3_ty() -> spv_inst * { + return lookup(index3_ty_, [&] { + auto index_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + return mod_->add_to(section::type_const_var, index_ty, 3); + }); } auto uniquifier::null_constant(spv_inst *spv_ty) -> spv_inst * { - auto it = null_cst_.find(spv_ty); - if (it == null_cst_.end()) { - auto in = mod_->add_to(section::type_const_var, spv_ty); - null_cst_[spv_ty] = in; - return in; - } - return it->second; + return lookup(null_cst_, spv_ty, [&](spv_inst *spv_ty) { + return mod_->add_to(section::type_const_var, spv_ty); + }); } auto uniquifier::opencl_ext() -> spv_inst * { - if (opencl_ext_ == nullptr) { - opencl_ext_ = mod_->add_to(section::ext_inst, OpenCLExt); - } - return opencl_ext_; + return lookup(opencl_ext_, + [&] { return mod_->add_to(section::ext_inst, OpenCLExt); }); } auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { @@ -101,17 +148,23 @@ auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { return it->second; } } - auto void_ty = spv_ty(*void_data_type::get(ctx_)); + auto void_ty = spv_ty(void_data_type::get(ctx_)); return spv_function_tys_ .emplace(map_key, mod_->add_to(section::type_const_var, void_ty, std::move(params))) ->second; } -auto uniquifier::spv_ty(data_type_node const &ty) -> spv_inst * { - auto it = spv_tys_.find(&ty); - if (it == spv_tys_.end()) { - auto spv_ty_inst = visit( +auto uniquifier::spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty) -> spv_inst * { + auto key = std::make_pair(cls, pointee_ty); + return lookup(spv_pointer_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, key.second); + }); +} + +auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { + return lookup(spv_tys_, ty, [&](const_tinytc_data_type_t ty) { + return visit( overloaded{ [&](void_data_type const &) -> spv_inst * { return mod_->add_to(section::type_const_var); @@ -145,26 +198,23 @@ auto uniquifier::spv_ty(data_type_node const &ty) -> spv_inst * { return mod_->add_to(section::type_const_var, size(ty.ty()) * 8, std::nullopt); case scalar_type::c32: { - auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f32)); + auto float_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); return mod_->add_to(section::type_const_var, float_ty, 2); } case scalar_type::c64: { - auto float_ty = spv_ty(*scalar_data_type::get(ctx_, scalar_type::f64)); + auto float_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); return mod_->add_to(section::type_const_var, float_ty, 2); } } throw status::internal_compiler_error; }, - [&](coopmatrix_data_type const &ty) -> spv_inst * { return spv_ty(*ty.ty()); }, + [&](coopmatrix_data_type const &ty) -> spv_inst * { return spv_ty(ty.ty()); }, [](auto const &) -> spv_inst * { // @todo throw status::not_implemented; }}, - ty); - spv_tys_[&ty] = spv_ty_inst; - return spv_ty_inst; - } - return it->second; + *ty); + }); } } // namespace tinytc::spv diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index 320ad35c..51c387b7 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -6,12 +6,14 @@ #include "spv/enums.hpp" #include "spv/module.hpp" +#include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include #include #include +#include namespace tinytc::spv { @@ -24,24 +26,55 @@ class uniquifier { auto bool2_ty() -> spv_inst *; auto bool_constant(bool b) -> spv_inst *; + auto builtin_alignment(BuiltIn b) -> std::int32_t; + auto builtin_pointee_ty(BuiltIn b) -> spv_inst *; + auto builtin_var(BuiltIn b) -> spv_inst *; void capability(Capability cap); auto i32_constant(std::int32_t cst) -> spv_inst *; + auto index3_ty() -> spv_inst *; auto null_constant(spv_inst *spv_ty) -> spv_inst *; auto opencl_ext() -> spv_inst *; auto spv_function_ty(array_view params) -> spv_inst *; - auto spv_ty(tinytc_data_type const &ty) -> spv_inst *; + auto spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty) -> spv_inst *; + auto spv_ty(const_tinytc_data_type_t ty) -> spv_inst *; private: + template + auto lookup(Map &map, Key &&key, Maker &&maker) { + auto it = map.find(key); + if (it == map.end()) { + map[key] = maker(key); + return map[key]; + } + return it->second; + } + template auto lookup(spv_inst *&var, Maker &&maker) -> spv_inst * { + if (!var) { + var = maker(); + } + return var; + } + + struct pointer_key_hash { + inline auto + operator()(std::pair const &key) const -> std::size_t { + return fnv1a_combine(key.first, key.second); + } + }; + tinytc_compiler_context_t ctx_; mod *mod_; spv_inst *bool2_ty_ = nullptr; spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + spv_inst *index3_ty_ = nullptr; spv_inst *opencl_ext_ = nullptr; std::unordered_map builtin_; std::unordered_set capabilities_; std::unordered_map i32_cst_; std::unordered_map null_cst_; std::unordered_multimap spv_function_tys_; + std::unordered_map, spv_inst *, pointer_key_hash> + spv_pointer_tys_; std::unordered_map spv_tys_; }; diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index 5c56e858..1f3d992c 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_2024115_HPP -#define GENERATED_VISIT_2024115_HPP +#ifndef GENERATED_VISIT_2024117_HPP +#define GENERATED_VISIT_2024117_HPP namespace tinytc::spv { @@ -955,6 +955,10 @@ template class default_visitor { if (in.op1()) { static_cast(this)->operator()(*in.op1()); } + + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } } auto operator()(OpStore const &in) { static_cast(this)->pre_visit(in); @@ -963,6 +967,10 @@ template class default_visitor { if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } } auto operator()(OpCopyMemory const &in) { static_cast(this)->pre_visit(in); @@ -975,6 +983,10 @@ template class default_visitor { if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } } auto operator()(OpCopyMemorySized const &in) { static_cast(this)->pre_visit(in); @@ -988,6 +1000,10 @@ template class default_visitor { if (in.op4()) { static_cast(this)->operator()(*in.op4()); } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } } auto operator()(OpAccessChain const &in) { static_cast(this)->pre_visit(in); @@ -1038,7 +1054,9 @@ template class default_visitor { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - static_cast(this)->operator()(in.op2()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } } auto operator()(OpMemberDecorate const &in) { static_cast(this)->pre_visit(in); @@ -2879,6 +2897,10 @@ template class default_visitor { if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } } auto operator()(OpCooperativeMatrixStoreKHR const &in) { static_cast(this)->pre_visit(in); @@ -2892,6 +2914,10 @@ template class default_visitor { if (in.op4()) { static_cast(this)->operator()(*in.op4()); } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } } auto operator()(OpCooperativeMatrixMulAddKHR const &in) { static_cast(this)->pre_visit(in); @@ -2912,4 +2938,4 @@ template class default_visitor { } // namespace tinytc::spv -#endif // GENERATED_VISIT_2024115_HPP +#endif // GENERATED_VISIT_2024117_HPP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2364b7ea..027d4479 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -54,6 +54,7 @@ if(SPIRVTools_FOUND) spv/arith.ir spv/arith_unary.ir spv/barrier.ir + spv/builtin.ir spv/cast.ir spv/compare.ir spv/unique_function_type.ir diff --git a/test/spv/builtin.ir b/test/spv/builtin.ir new file mode 100644 index 00000000..400240fc --- /dev/null +++ b/test/spv/builtin.ir @@ -0,0 +1,50 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpEntryPoint Kernel %[[#]] "tbuiltin" %[[#VAR1:]] %[[#VAR2:]] %[[#VAR3:]] %[[#VAR4:]] %[[#VAR5:]] %[[#VAR6:]] + +; CHECK: OpDecorate %[[#VAR1]] Constant +; CHECK: OpDecorate %[[#VAR1]] BuiltIn GlobalInvocationId +; CHECK: OpDecorate %[[#VAR2]] Constant +; CHECK: OpDecorate %[[#VAR2]] BuiltIn GlobalSize +; CHECK: OpDecorate %[[#VAR3]] Constant +; CHECK: OpDecorate %[[#VAR3]] BuiltIn NumSubgroups +; CHECK: OpDecorate %[[#VAR4]] Constant +; CHECK: OpDecorate %[[#VAR4]] BuiltIn SubgroupSize +; CHECK: OpDecorate %[[#VAR5]] Constant +; CHECK: OpDecorate %[[#VAR5]] BuiltIn SubgroupId +; CHECK: OpDecorate %[[#VAR6]] Constant +; CHECK: OpDecorate %[[#VAR6]] BuiltIn SubgroupLocalInvocationId + +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64V3:]] = OpTypeVector %[[#I64]] 3 +; CHECK: %[[#PTR_TO_I64V3:]] = OpTypePointer Input %[[#I64V3]] +; CHECK: %[[#VAR1]] = OpVariable %[[#PTR_TO_I64V3]] Input +; CHECK: %[[#VAR2]] = OpVariable %[[#PTR_TO_I64V3]] Input +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#PTR_TO_I32:]] = OpTypePointer Input %[[#I32]] +; CHECK: %[[#VAR3]] = OpVariable %[[#PTR_TO_I32]] Input +; CHECK: %[[#VAR4]] = OpVariable %[[#PTR_TO_I32]] Input +; CHECK: %[[#VAR5]] = OpVariable %[[#PTR_TO_I32]] Input +; CHECK: %[[#VAR6]] = OpVariable %[[#PTR_TO_I32]] Input + +func @tbuiltin() { + %0 = group_id + %1 = group_size + %2 = num_subgroups + %3 = subgroup_size + parallel { + %4 = subgroup_id + %5 = subgroup_local_id + } +; CHECK: %[[#VAR1_LOAD:]] = OpLoad %[[#I64V3]] %[[#VAR1]] Aligned 32 +; CHECK: %[[#]] = OpCompositeExtract %[[#I64]] %[[#VAR1_LOAD]] 2 +; CHECK: %[[#VAR2_LOAD:]] = OpLoad %[[#I64V3]] %[[#VAR2]] Aligned 32 +; CHECK: %[[#]] = OpCompositeExtract %[[#I64]] %[[#VAR2_LOAD]] 2 +; CHECK: %[[#]] = OpLoad %[[#I32]] %[[#VAR3]] Aligned 4 +; CHECK: %[[#]] = OpLoad %[[#I32]] %[[#VAR4]] Aligned 4 +; CHECK: %[[#]] = OpLoad %[[#I32]] %[[#VAR5]] Aligned 4 +; CHECK: %[[#]] = OpLoad %[[#I32]] %[[#VAR6]] Aligned 4 +} diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 894d5526..2a38f3c0 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -51,7 +51,7 @@ class spv_inst : public ilist_node { bool has_result_id_; }; -using DecorationAttr = std::variant>; +using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber = std::variant; @@ -62,6 +62,7 @@ class spv_inst : public ilist_node { using IdRef = spv_inst*; using IdScope = spv_inst*; using IdMemorySemantics = spv_inst*; +using MemoryAccessAttr = std::int32_t; using PairIdRefIdRef = std::pair; using PairLiteralIntegerIdRef = std::pair, spv_inst*>; @@ -152,6 +153,9 @@ def __init__(self, name, kind, quantifier): self.name = name self.kind = kind self.quantifier = quantifier + self.init = None + if self.quantifier == '?': + self.init = 'std::nullopt' def get_operands(instruction): @@ -189,7 +193,10 @@ def generate_op_classes(f, grammar): f'constexpr static std::array required_capabilities = {{{cap_str}}};', file=f) f.write(f'{get_class_name(instruction)}(') - f.write(','.join([f'{o.kind} {o.name}' for o in operands])) + f.write(','.join([ + f'{o.kind} {o.name}{f" = {o.init}" if o.init else ""}' + for o in operands + ])) f.write(') : ') initializer_list = [ f'spv_inst{{Op::{get_opcode_name(instruction)}, {"true" if has_result_id(instruction) else "false"}}}' @@ -270,10 +277,19 @@ def patch_grammar(grammar): for instruction in grammar['instructions']: if instruction['opname'] == 'OpDecorate': if instruction['operands'][-1]['kind'] == 'Decoration': - instruction['operands'].append({'kind': 'DecorationAttr'}) + instruction['operands'].append({ + 'kind': 'DecorationAttr', + 'quantifier': '?' + }) elif instruction['opname'] == 'OpExecutionMode': if instruction['operands'][-1]['kind'] == 'ExecutionMode': instruction['operands'].append({'kind': 'ExecutionModeAttr'}) + elif 'operands' in instruction and instruction['operands'][-1][ + 'kind'] == 'MemoryAccess': + instruction['operands'].append({ + 'kind': 'MemoryAccessAttr', + 'quantifier': '?' + }) return grammar @@ -303,7 +319,8 @@ def patch_grammar(grammar): grammar = filter_grammar(grammar, filt) grammar = patch_grammar(grammar) generate_header(args, spv_enums, grammar, generate_enums) - generate_header(args, spv_names, grammar, generate_names, spv_names_includes) + generate_header(args, spv_names, grammar, generate_names, + spv_names_includes) generate_cpp(args, spv_names_cpp, grammar, generate_names_cpp, spv_names_cpp_includes) generate_header(args, spv_ops, grammar, generate_op_classes, From 98129f96fd114b410f9960dbe62ce561075dea25 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 7 Nov 2024 18:22:25 +0100 Subject: [PATCH 094/297] SPIR-V: Add work group instruction Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 41 +++++++++++++++++++++++++++++++++++++++++ src/spv/converter.hpp | 4 ++++ src/spv/module.hpp | 2 +- test/CMakeLists.txt | 1 + test/spv/work_group.ir | 23 +++++++++++++++++++++++ 5 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 test/spv/work_group.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 61903f1f..72aab183 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -15,6 +15,7 @@ #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" #include "spv/uniquifier.hpp" +#include "spv/visit.hpp" #include "support/casting.hpp" #include "support/ilist_base.hpp" #include "support/util.hpp" @@ -55,6 +56,19 @@ auto convert_prog_to_spirv(tinytc_prog const &p, } } + // Add missing capabilites + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : m->insts(enum_cast

(s))) { + visit(overloaded{[&](I const &) { + for (auto const &cap : I::required_capabilities) { + conv.unique().capability(cap); + } + }, + [&](auto const &) {}}, + i); + } + } + return m; } @@ -630,6 +644,33 @@ void inst_converter::operator()(subgroup_size_inst const &in) { declare(in.result(0), load_builtin(BuiltIn::SubgroupSize)); } +void inst_converter::operator()(work_group_inst const &in) { + auto const make = [&](scalar_type sty, work_group_operation operation, spv_inst *spv_ty, + spv_inst *operand) -> spv_inst * { + auto scope = unique_.i32_constant(static_cast(Scope::Workgroup)); + if (operation == work_group_operation::reduce_add) { + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_ty, scope, GroupOperation::Reduce, operand); + case scalar_type::f32: + case scalar_type::f64: + case scalar_type::c32: + case scalar_type::c64: + return mod_->add(spv_ty, scope, GroupOperation::Reduce, operand); + } + } + throw compilation_error(in.loc(), status::not_implemented); + }; + + auto spv_ty = unique_.spv_ty(in.result(0).ty()); + auto sty = get_scalar_type(in.operand()); + declare(in.result(0), make(sty, in.operation(), spv_ty, val(in.operand()))); +} + void inst_converter::run_on_region(region_node const ®) { for (auto const &i : reg) { visit(*this, i); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 00db815b..fc6c675c 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -39,6 +39,7 @@ class inst_converter { void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); void operator()(subgroup_size_inst const &in); + void operator()(work_group_inst const &in); void run_on_region(tinytc_region const ®); void run_on_function(tinytc_func const &fn, core_config const &core_cfg); @@ -63,4 +64,7 @@ class inst_converter { core_config core_cfg_ = {}; }; +template +concept spv_inst_with_required_capabilities = requires() { T::required_capabilities; }; + } // namespace tinytc::spv diff --git a/src/spv/module.hpp b/src/spv/module.hpp index 5e9dec61..1d061688 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -35,7 +35,7 @@ enum class section { type_const_var = 6, function = 7 }; -inline constexpr std::size_t num_module_sections = 8; +inline constexpr std::int32_t num_module_sections = 8; class mod final { public: diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 027d4479..14f2935d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -57,6 +57,7 @@ if(SPIRVTools_FOUND) spv/builtin.ir spv/cast.ir spv/compare.ir + spv/work_group.ir spv/unique_function_type.ir ) foreach(SOURCE IN LISTS SPIRV_VAL_SOURCES) diff --git a/test/spv/work_group.ir b/test/spv/work_group.ir new file mode 100644 index 00000000..d6259db3 --- /dev/null +++ b/test/spv/work_group.ir @@ -0,0 +1,23 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpCapability Group +; CHECK: %[[#I16:]] = OpTypeInt 16 0 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F64:]] = OpTypeFloat 64 +; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 +; CHECK: %[[#SCOPE:]] = OpConstant %[[#]] 2 + +func @twg() { + %0 = constant 1 -> i16 + %1 = constant 1.0 -> f32 + %2 = constant [1.0, 0.0] -> c64 + %3 = work_group.reduce_add %0 : i16 + %4 = work_group.reduce_add %1 : f32 + %5 = work_group.reduce_add %2 : c64 +; CHECK: %[[#]] = OpGroupIAdd %[[#I16]] %[[#SCOPE]] Reduce %[[#]] +; CHECK: %[[#]] = OpGroupFAdd %[[#F32]] %[[#SCOPE]] Reduce %[[#]] +; CHECK: %[[#]] = OpGroupFAdd %[[#C64]] %[[#SCOPE]] Reduce %[[#]] +} From a7d94987471b394a8e7e703f4f4411ca9595f8a9 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 8 Nov 2024 14:41:53 +0100 Subject: [PATCH 095/297] SPIR-V: For and If Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 2 +- src/node/inst_node.cpp | 9 +- src/spv/converter.cpp | 357 +++++++++--- src/spv/converter.hpp | 23 + src/spv/enums.hpp | 6 +- src/spv/instructions.hpp | 1073 +++++++++++++++++++++++++++++++++++- src/spv/names.hpp | 6 +- src/spv/pass/dump_asm.cpp | 46 +- src/spv/pass/dump_asm.hpp | 1 + src/spv/uniquifier.cpp | 3 +- src/spv/visit.hpp | 6 +- test/CMakeLists.txt | 2 + test/spv/for.ir | 115 ++++ test/spv/if.ir | 108 ++++ tools/spirvgen/spirvgen.py | 3 + 15 files changed, 1665 insertions(+), 95 deletions(-) create mode 100644 test/spv/for.ir create mode 100644 test/spv/if.ir diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 1d892bb7..1717264f 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -1028,7 +1028,7 @@ For init-value-list = init-value *("," init-value) init-value = local-identifier "=" local-identifier return-type-list = return-type *("," return-type) - return-type = scalar-type / coopmatrix-type + return-type = boolean-type / scalar-type / coopmatrix-type Overview ~~~~~~~~ diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index c3c854ae..639d12c0 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -166,9 +166,10 @@ loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_va result(i) = value_node{init_values[i]->ty(), this, lc}; } for (std::size_t i = 0; i < init_values.size(); ++i) { - if (!isa(*init_values[i]->ty()) && + if (!isa(*init_values[i]->ty()) && + !isa(*init_values[i]->ty()) && !isa(*init_values[i]->ty())) { - throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } op(op_init() + i, init_values[i]); } @@ -841,9 +842,9 @@ if_inst::if_inst(tinytc_value_t condition, array_view return throw compilation_error(loc(), status::ir_expected_boolean); } for (std::size_t i = 0; i < return_types.size(); ++i) { - if (!isa(*return_types[i]) && + if (!isa(*return_types[i]) && !isa(*return_types[i]) && !isa(*return_types[i])) { - throw compilation_error(loc(), status::ir_expected_coopmatrix_or_scalar); + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } result(i) = value_node{return_types[i], this, lc}; } diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 72aab183..a686ec30 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -16,7 +16,6 @@ #include "spv/opencl.std.hpp" #include "spv/uniquifier.hpp" #include "spv/visit.hpp" -#include "support/casting.hpp" #include "support/ilist_base.hpp" #include "support/util.hpp" #include "support/visit.hpp" @@ -75,6 +74,18 @@ auto convert_prog_to_spirv(tinytc_prog const &p, inst_converter::inst_converter(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m), unique_(ctx, m) {} +auto inst_converter::get_last_label() -> spv_inst * { + auto &insts = mod_->insts(section::function); + auto it = insts.end(); + while (it != insts.begin()) { + auto in = (--it).get(); + if (isa(*in)) { + return in; + } + } + return nullptr; +} + auto inst_converter::get_scalar_type(value_node const &v) -> scalar_type { auto st = dyn_cast(v.ty()); if (!st) { @@ -118,6 +129,64 @@ auto inst_converter::multi_val(value_node const &v) -> std::vector & throw compilation_error(v.loc(), status::spirv_undefined_value); } +auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, + constant_inst::value_type const &val) -> spv_inst * { + auto const add_constant = [this, &spv_ty](auto val) -> spv_inst * { + return mod_->add_to(section::type_const_var, spv_ty, val); + }; + auto const add_constant_complex = [this, &spv_ty](spv_inst *spv_float_ty, auto re, + auto im) -> spv_inst * { + auto c_re = mod_->add_to(section::type_const_var, spv_float_ty, re); + auto c_im = mod_->add_to(section::type_const_var, spv_float_ty, im); + return mod_->add_to(section::type_const_var, spv_ty, + std::vector{c_re, c_im}); + }; + const auto visitor = overloaded{ + [&](bool) -> spv_inst * { return nullptr; }, + [&](std::int64_t i) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + return add_constant(static_cast(i)); + case scalar_type::i16: + return add_constant(static_cast(i)); + case scalar_type::i32: + return add_constant(static_cast(i)); + case scalar_type::i64: + case scalar_type::index: + return add_constant(i); + default: + return nullptr; + } + }, + [&](double d) -> spv_inst * { + switch (sty) { + case scalar_type::f32: + return add_constant(static_cast(d)); + case scalar_type::f64: + return add_constant(d); + default: + return nullptr; + } + }, + [&](std::complex d) -> spv_inst * { + switch (sty) { + case scalar_type::c32: { + auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); + return add_constant_complex(spv_float_ty, static_cast(d.real()), + static_cast(d.imag())); + } + case scalar_type::c64: { + auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); + return add_constant_complex(spv_float_ty, d.real(), d.imag()); + } + default: + return nullptr; + } + }, + }; + return std::visit(visitor, val); +}; + void inst_converter::operator()(inst_node const &in) { // @todo throw compilation_error(in.loc(), status::not_implemented); @@ -533,70 +602,6 @@ void inst_converter::operator()(compare_inst const &in) { } void inst_converter::operator()(constant_inst const &in) { - auto const make = [&](scalar_type sty, spv_inst *spv_ty, - constant_inst::value_type const &val) -> spv_inst * { - auto const add_constant = [this, &spv_ty](auto val) -> spv_inst * { - return mod_->add_to(section::type_const_var, spv_ty, val); - }; - auto const add_constant_complex = [this, &spv_ty](spv_inst *spv_float_ty, auto re, - auto im) -> spv_inst * { - auto c_re = mod_->add_to(section::type_const_var, spv_float_ty, re); - auto c_im = mod_->add_to(section::type_const_var, spv_float_ty, im); - return mod_->add_to(section::type_const_var, spv_ty, - std::vector{c_re, c_im}); - }; - const auto visitor = overloaded{ - [&](bool) -> spv_inst * { return nullptr; }, - [&](std::int64_t i) -> spv_inst * { - switch (sty) { - case scalar_type::i8: - return add_constant(static_cast(i)); - case scalar_type::i16: - return add_constant(static_cast(i)); - case scalar_type::i32: - return add_constant(static_cast(i)); - case scalar_type::i64: - case scalar_type::index: - return add_constant(i); - default: - return nullptr; - } - }, - [&](double d) -> spv_inst * { - switch (sty) { - case scalar_type::f32: - return add_constant(static_cast(d)); - case scalar_type::f64: - return add_constant(d); - default: - return nullptr; - } - }, - [&](std::complex d) -> spv_inst * { - switch (sty) { - case scalar_type::c32: { - auto spv_float_ty = - unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); - return add_constant_complex(spv_float_ty, static_cast(d.real()), - static_cast(d.imag())); - } - case scalar_type::c64: { - auto spv_float_ty = - unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); - return add_constant_complex(spv_float_ty, d.real(), d.imag()); - } - default: - return nullptr; - } - }, - }; - auto cst = std::visit(visitor, val); - if (cst == nullptr) { - throw compilation_error(in.loc(), status::internal_compiler_error); - } - return cst; - }; - auto spv_ty = unique_.spv_ty(in.result(0).ty()); if (isa(*in.result(0).ty())) { @@ -605,10 +610,17 @@ void inst_converter::operator()(constant_inst const &in) { } declare(in.result(0), unique_.bool_constant(std::get(in.value()))); } else if (auto st = dyn_cast(in.result(0).ty()); st) { - declare(in.result(0), make(st->ty(), spv_ty, in.value())); + auto cst = make_constant(st->ty(), spv_ty, in.value()); + if (cst == nullptr) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + declare(in.result(0), cst); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { auto const length = ct->length(core_cfg_.subgroup_size); - auto cst = make(ct->component_ty(), spv_ty, in.value()); + auto cst = make_constant(ct->component_ty(), spv_ty, in.value()); + if (cst == nullptr) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } multi_declare(in.result(0), std::vector(length, cst)); } else { @@ -616,6 +628,125 @@ void inst_converter::operator()(constant_inst const &in) { } } +void inst_converter::operator()(for_inst const &in) { + const std::int64_t num_results = num_yielded_vals(in.result_begin(), in.result_end()); + + auto header_label = std::make_unique(); + auto body_label = std::make_unique(); + auto continue_label = std::make_unique(); + auto merge_label = std::make_unique(); + + mod_->add(merge_label.get(), continue_label.get(), LoopControl::None); + mod_->add(header_label.get()); + + // Header block + auto spv_bool_ty = unique_.spv_ty(boolean_data_type::get(ctx_)); + auto spv_loop_var_ty = unique_.spv_ty(in.loop_var().ty()); + auto header_block_last_label = header_label.get(); + mod_->insts(section::function).push_back(header_label.release()); + + auto condition = mod_->add(spv_bool_ty, val(in.from()), val(in.to())); + mod_->add(condition, body_label.get(), merge_label.get(), + std::vector{}); + + // Body block + auto body_first_label = body_label.get(); + mod_->insts(section::function).push_back(body_label.release()); + // nullptr needs to be replaced by the loop var update once it is defined + auto loop_var_phi = mod_->add( + spv_loop_var_ty, + std::vector{PairIdRefIdRef{val(in.from()), header_block_last_label}, + PairIdRefIdRef{nullptr, continue_label.get()}}); + declare(in.loop_var(), loop_var_phi); + + auto const &make_iter_arg_phi = [&]() -> std::vector { + auto phis = std::vector{}; + phis.reserve(num_results); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto ty = unique_.spv_ty(in.iter_arg(i).ty()); + if (isa(*in.iter_arg(i).ty())) { + auto &init_vals = multi_val(in.iter_init(i)); + auto iter_arg_vals = std::vector(init_vals.size(), nullptr); + for (auto init_val = init_vals.begin(), iter_arg_val = iter_arg_vals.begin(); + init_val != init_vals.end(); ++init_val, ++iter_arg_val) { + auto phi = + mod_->add(ty, std::vector{ + PairIdRefIdRef{*init_val, header_block_last_label}, + PairIdRefIdRef{nullptr, continue_label.get()}}); + *iter_arg_val = phi; + phis.emplace_back(phi); + } + multi_declare(in.iter_arg(i), std::move(iter_arg_vals)); + } else { + phis.emplace_back(mod_->add( + ty, std::vector{ + PairIdRefIdRef{val(in.iter_init(i)), header_block_last_label}, + PairIdRefIdRef{nullptr, continue_label.get()}})); + declare(in.iter_arg(i), phis.back()); + } + } + return phis; + }; + auto iter_arg_phis = make_iter_arg_phi(); + + auto yielded_for = run_on_region_with_yield(in.body(), num_results); + // Update phis with yielded values + for (std::int64_t i = 0; i < num_results; ++i) { + iter_arg_phis[i]->op0().back().first = yielded_for[i]; + } + + auto body_last_label = get_last_label(); + if (!body_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + mod_->add(continue_label.get()); + + // Continue block + auto continue_block_last_label = continue_label.get(); + mod_->insts(section::function).push_back(continue_label.release()); + auto step = [&]() -> spv_inst * { + if (in.has_step()) { + return val(in.step()); + } + return make_constant(get_scalar_type(in.loop_var()), spv_loop_var_ty, std::int64_t{1}); + }(); + auto loop_var_update = mod_->add(spv_loop_var_ty, val(in.loop_var()), step); + loop_var_phi->op0().back().first = loop_var_update; + auto condition2 = mod_->add(spv_bool_ty, loop_var_update, val(in.to())); + mod_->add(condition2, body_first_label, merge_label.get(), + std::vector{}); + + // Merge block + mod_->insts(section::function).push_back(merge_label.release()); + + auto const &set_results = [&] { + std::int64_t val_no = 0; + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto ty = unique_.spv_ty(in.result(i).ty()); + if (isa(*in.result(i).ty())) { + auto &init_vals = multi_val(in.iter_init(i)); + auto results = std::vector(init_vals.size(), nullptr); + for (auto init_val = init_vals.begin(), result = results.begin(); + init_val != init_vals.end(); ++init_val, ++result) { + *result = mod_->add( + ty, std::vector{ + PairIdRefIdRef{*init_val, header_block_last_label}, + PairIdRefIdRef{yielded_for[val_no++], continue_block_last_label}}); + } + multi_declare(in.result(i), std::move(results)); + } else { + declare( + in.result(i), + mod_->add( + ty, std::vector{ + PairIdRefIdRef{val(in.iter_init(i)), header_block_last_label}, + PairIdRefIdRef{yielded_for[val_no++], continue_block_last_label}})); + } + } + }; + set_results(); +} + void inst_converter::operator()(group_id_inst const &in) { auto gid = load_builtin(BuiltIn::GlobalInvocationId); auto index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); @@ -628,6 +759,60 @@ void inst_converter::operator()(group_size_inst const &in) { declare(in.result(0), mod_->add(index_ty, gs, std::vector{2})); } + +void inst_converter::operator()(if_inst const &in) { + const std::int64_t num_results = num_yielded_vals(in.result_begin(), in.result_end()); + + auto then_label = std::make_unique(); + auto otherwise_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto conditionv = val(in.condition()); + mod_->add(merge_label.get(), SelectionControl::None); + mod_->add(conditionv, then_label.get(), otherwise_label.get(), + std::vector{}); + mod_->insts(section::function).push_back(then_label.release()); + auto yielded_then = run_on_region_with_yield(in.then(), num_results); + mod_->add(merge_label.get()); + auto then_last_label = get_last_label(); + if (!then_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + mod_->insts(section::function).push_back(otherwise_label.release()); + auto yielded_otherwise = run_on_region_with_yield(in.otherwise(), num_results); + mod_->add(merge_label.get()); + auto otherwise_last_label = get_last_label(); + if (!otherwise_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + mod_->insts(section::function).push_back(merge_label.release()); + + std::int64_t val_no = 0; + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto ty = unique_.spv_ty(in.result(i).ty()); + if (auto ct = dyn_cast(in.result(i).ty()); ct) { + const auto length = ct->length(core_cfg_.subgroup_size); + auto phi_insts = std::vector(length, nullptr); + for (auto &phi_inst : phi_insts) { + phi_inst = mod_->add( + ty, std::vector{ + PairIdRefIdRef{yielded_then[val_no], then_last_label}, + PairIdRefIdRef{yielded_otherwise[val_no], otherwise_last_label}}); + ++val_no; + } + multi_declare(in.result(i), std::move(phi_insts)); + } else { + auto phi_inst = mod_->add( + ty, std::vector{ + PairIdRefIdRef{yielded_then[val_no], then_last_label}, + PairIdRefIdRef{yielded_otherwise[val_no], otherwise_last_label}}); + ++val_no; + declare(in.result(i), phi_inst); + } + } +} + void inst_converter::operator()(num_subgroups_inst const &in) { declare(in.result(0), load_builtin(BuiltIn::NumSubgroups)); } @@ -671,12 +856,50 @@ void inst_converter::operator()(work_group_inst const &in) { declare(in.result(0), make(sty, in.operation(), spv_ty, val(in.operand()))); } +void inst_converter::operator()(yield_inst const &in) { + if (yielded_vals_.empty()) { + throw compilation_error(in.loc(), status::ir_unexpected_yield); + } + + auto &top = yielded_vals_.top(); + const std::int64_t num = num_yielded_vals(in.op_begin(), in.op_end()); + if (static_cast(top.size()) != num) { + throw compilation_error(in.loc(), status::ir_yield_mismatch); + } + + std::int64_t i = 0; + for (auto &op : in.operands()) { + if (auto ct = dyn_cast(op.ty()); ct) { + auto &vals = multi_val(op); + for (auto &v : vals) { + top[i++] = v; + } + } else { + top[i++] = val(op); + } + } +} + void inst_converter::run_on_region(region_node const ®) { for (auto const &i : reg) { visit(*this, i); } } +auto inst_converter::run_on_region_with_yield(region_node const ®, + std::int64_t num_results) -> std::vector { + yielded_vals_.push(std::vector(num_results, nullptr)); + run_on_region(reg); + auto yielded_vals = std::move(yielded_vals_.top()); + if (static_cast(yielded_vals.size()) != num_results || + std::any_of(yielded_vals.begin(), yielded_vals.end(), + [](spv_inst *in) { return in == nullptr; })) { + throw compilation_error(reg.loc(), status::ir_yield_mismatch); + } + yielded_vals_.pop(); + return yielded_vals; +} + void inst_converter::run_on_function(function_node const &fn, core_config const &core_cfg) { core_cfg_ = core_cfg; diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index fc6c675c..4f23b200 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -6,10 +6,12 @@ #include "node/inst_node.hpp" #include "spv/module.hpp" #include "spv/uniquifier.hpp" +#include "support/casting.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" #include +#include #include #include @@ -32,21 +34,39 @@ class inst_converter { void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); + void operator()(for_inst const &in); void operator()(group_id_inst const &in); void operator()(group_size_inst const &in); + void operator()(if_inst const &in); void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); void operator()(subgroup_size_inst const &in); void operator()(work_group_inst const &in); + void operator()(yield_inst const &in); void run_on_region(tinytc_region const ®); + auto run_on_region_with_yield(region_node const ®, + std::int64_t num_results) -> std::vector; void run_on_function(tinytc_func const &fn, core_config const &core_cfg); inline auto unique() -> uniquifier & { return unique_; } private: + template + auto num_yielded_vals(Iterator begin, Iterator end) -> std::int64_t { + std::int64_t num_results = 0; + for (; begin != end; ++begin) { + if (auto ct = dyn_cast(begin->ty()); ct) { + num_results += ct->length(core_cfg_.subgroup_size); + } else { + ++num_results; + } + } + return num_results; + } + auto get_last_label() -> spv_inst *; auto get_scalar_type(tinytc_value const &v) -> scalar_type; auto get_coopmatrix_type(tinytc_value const &v) -> scalar_type; auto load_builtin(BuiltIn b) -> spv_inst *; @@ -54,12 +74,15 @@ class inst_converter { auto val(tinytc_value const &v) -> spv_inst *; auto multi_declare(tinytc_value const &v, std::vector insts); auto multi_val(tinytc_value const &v) -> std::vector &; + auto make_constant(scalar_type sty, spv_inst *spv_ty, + constant_inst::value_type const &val) -> spv_inst *; tinytc_compiler_context_t ctx_; mod *mod_; uniquifier unique_; std::unordered_map vals_; std::unordered_map> multi_vals_; + std::stack> yielded_vals_; std::vector builtins_used_by_function_; core_config core_cfg_ = {}; }; diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 3969f57b..2ee1718b 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_2024117_HPP -#define GENERATED_ENUMS_2024117_HPP +#ifndef GENERATED_ENUMS_2024118_HPP +#define GENERATED_ENUMS_2024118_HPP namespace tinytc::spv { @@ -1422,4 +1422,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_2024117_HPP +#endif // GENERATED_ENUMS_2024118_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index 76541412..21db47a7 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_2024117_HPP -#define GENERATED_INSTRUCTIONS_2024117_HPP +#ifndef GENERATED_INSTRUCTIONS_2024118_HPP +#define GENERATED_INSTRUCTIONS_2024118_HPP #include "enums.hpp" #include "error.hpp" @@ -68,6 +68,7 @@ class OpUndef : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Undef; } OpUndef(IdResultType type) : spv_inst{Op::Undef, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -78,6 +79,7 @@ class OpSourceContinued : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceContinued; } OpSourceContinued(LiteralString op0) : spv_inst{Op::SourceContinued, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -90,9 +92,13 @@ class OpSource : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::Source, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> SourceLanguage & { return op0_; } inline auto op0() const -> SourceLanguage const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -106,6 +112,7 @@ class OpSourceExtension : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceExtension; } OpSourceExtension(LiteralString op0) : spv_inst{Op::SourceExtension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -116,7 +123,9 @@ class OpName : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Name; } OpName(IdRef op0, LiteralString op1) : spv_inst{Op::Name, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralString & { return op1_; } inline auto op1() const -> LiteralString const & { return op1_; } private: @@ -129,8 +138,11 @@ class OpMemberName : public spv_inst { OpMemberName(IdRef op0, LiteralInteger op1, LiteralString op2) : spv_inst{Op::MemberName, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } inline auto op2() const -> LiteralString const & { return op2_; } private: @@ -142,6 +154,7 @@ class OpString : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::String; } OpString(LiteralString op0) : spv_inst{Op::String, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -153,8 +166,11 @@ class OpLine : public spv_inst { OpLine(IdRef op0, LiteralInteger op1, LiteralInteger op2) : spv_inst{Op::Line, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } inline auto op2() const -> LiteralInteger const & { return op2_; } private: @@ -166,6 +182,7 @@ class OpExtension : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Extension; } OpExtension(LiteralString op0) : spv_inst{Op::Extension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -175,6 +192,7 @@ class OpExtInstImport : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInstImport; } OpExtInstImport(LiteralString op0) : spv_inst{Op::ExtInstImport, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -186,9 +204,13 @@ class OpExtInst : public spv_inst { OpExtInst(IdResultType type, IdRef op0, LiteralExtInstInteger op1, std::vector op2) : spv_inst{Op::ExtInst, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralExtInstInteger & { return op1_; } inline auto op1() const -> LiteralExtInstInteger const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -202,7 +224,9 @@ class OpMemoryModel : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryModel; } OpMemoryModel(AddressingModel op0, MemoryModel op1) : spv_inst{Op::MemoryModel, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> AddressingModel & { return op0_; } inline auto op0() const -> AddressingModel const & { return op0_; } + inline auto op1() -> MemoryModel & { return op1_; } inline auto op1() const -> MemoryModel const & { return op1_; } private: @@ -215,9 +239,13 @@ class OpEntryPoint : public spv_inst { OpEntryPoint(ExecutionModel op0, IdRef op1, LiteralString op2, std::vector op3) : spv_inst{Op::EntryPoint, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> ExecutionModel & { return op0_; } inline auto op0() const -> ExecutionModel const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } inline auto op3() const -> std::vector const & { return op3_; } private: @@ -232,8 +260,11 @@ class OpExecutionMode : public spv_inst { OpExecutionMode(IdRef op0, ExecutionMode op1, ExecutionModeAttr op2) : spv_inst{Op::ExecutionMode, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } inline auto op1() const -> ExecutionMode const & { return op1_; } + inline auto op2() -> ExecutionModeAttr & { return op2_; } inline auto op2() const -> ExecutionModeAttr const & { return op2_; } private: @@ -245,6 +276,7 @@ class OpCapability : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Capability; } OpCapability(Capability op0) : spv_inst{Op::Capability, false}, op0_(std::move(op0)) {} + inline auto op0() -> Capability & { return op0_; } inline auto op0() const -> Capability const & { return op0_; } private: @@ -269,7 +301,9 @@ class OpTypeInt : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeInt; } OpTypeInt(LiteralInteger op0, LiteralInteger op1) : spv_inst{Op::TypeInt, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -281,7 +315,9 @@ class OpTypeFloat : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFloat; } OpTypeFloat(LiteralInteger op0, std::optional op1 = std::nullopt) : spv_inst{Op::TypeFloat, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } inline auto op1() const -> std::optional const & { return op1_; } private: @@ -293,7 +329,9 @@ class OpTypeVector : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVector; } OpTypeVector(IdRef op0, LiteralInteger op1) : spv_inst{Op::TypeVector, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -306,7 +344,9 @@ class OpTypeMatrix : public spv_inst { constexpr static std::array required_capabilities = {Capability::Matrix}; OpTypeMatrix(IdRef op0, LiteralInteger op1) : spv_inst{Op::TypeMatrix, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -322,13 +362,21 @@ class OpTypeImage : public spv_inst { : spv_inst{Op::TypeImage, true}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Dim & { return op1_; } inline auto op1() const -> Dim const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } inline auto op2() const -> LiteralInteger const & { return op2_; } + inline auto op3() -> LiteralInteger & { return op3_; } inline auto op3() const -> LiteralInteger const & { return op3_; } + inline auto op4() -> LiteralInteger & { return op4_; } inline auto op4() const -> LiteralInteger const & { return op4_; } + inline auto op5() -> LiteralInteger & { return op5_; } inline auto op5() const -> LiteralInteger const & { return op5_; } + inline auto op6() -> ImageFormat & { return op6_; } inline auto op6() const -> ImageFormat const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } inline auto op7() const -> std::optional const & { return op7_; } private: @@ -352,6 +400,7 @@ class OpTypeSampledImage : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampledImage; } OpTypeSampledImage(IdRef op0) : spv_inst{Op::TypeSampledImage, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -362,7 +411,9 @@ class OpTypeArray : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeArray; } OpTypeArray(IdRef op0, IdRef op1) : spv_inst{Op::TypeArray, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -374,6 +425,7 @@ class OpTypeRuntimeArray : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeRuntimeArray; } constexpr static std::array required_capabilities = {Capability::Shader}; OpTypeRuntimeArray(IdRef op0) : spv_inst{Op::TypeRuntimeArray, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -383,6 +435,7 @@ class OpTypeStruct : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeStruct; } OpTypeStruct(std::vector op0) : spv_inst{Op::TypeStruct, true}, op0_(std::move(op0)) {} + inline auto op0() -> std::vector & { return op0_; } inline auto op0() const -> std::vector const & { return op0_; } private: @@ -393,6 +446,7 @@ class OpTypeOpaque : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeOpaque; } constexpr static std::array required_capabilities = {Capability::Kernel}; OpTypeOpaque(LiteralString op0) : spv_inst{Op::TypeOpaque, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -403,7 +457,9 @@ class OpTypePointer : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePointer; } OpTypePointer(StorageClass op0, IdRef op1) : spv_inst{Op::TypePointer, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> StorageClass & { return op0_; } inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -415,7 +471,9 @@ class OpTypeFunction : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFunction; } OpTypeFunction(IdRef op0, std::vector op1) : spv_inst{Op::TypeFunction, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -459,6 +517,7 @@ class OpTypePipe : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipe; } constexpr static std::array required_capabilities = {Capability::Pipes}; OpTypePipe(AccessQualifier op0) : spv_inst{Op::TypePipe, true}, op0_(std::move(op0)) {} + inline auto op0() -> AccessQualifier & { return op0_; } inline auto op0() const -> AccessQualifier const & { return op0_; } private: @@ -471,7 +530,9 @@ class OpTypeForwardPointer : public spv_inst { Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; OpTypeForwardPointer(IdRef op0, StorageClass op1) : spv_inst{Op::TypeForwardPointer, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } inline auto op1() const -> StorageClass const & { return op1_; } private: @@ -482,6 +543,7 @@ class OpConstantTrue : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantTrue; } OpConstantTrue(IdResultType type) : spv_inst{Op::ConstantTrue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -492,6 +554,7 @@ class OpConstantFalse : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantFalse; } OpConstantFalse(IdResultType type) : spv_inst{Op::ConstantFalse, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -502,7 +565,9 @@ class OpConstant : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Constant; } OpConstant(IdResultType type, LiteralContextDependentNumber op0) : spv_inst{Op::Constant, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> LiteralContextDependentNumber & { return op0_; } inline auto op0() const -> LiteralContextDependentNumber const & { return op0_; } private: @@ -514,7 +579,9 @@ class OpConstantComposite : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantComposite; } OpConstantComposite(IdResultType type, std::vector op0) : spv_inst{Op::ConstantComposite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } inline auto op0() const -> std::vector const & { return op0_; } private: @@ -529,9 +596,13 @@ class OpConstantSampler : public spv_inst { SamplerFilterMode op2) : spv_inst{Op::ConstantSampler, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> SamplerAddressingMode & { return op0_; } inline auto op0() const -> SamplerAddressingMode const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> SamplerFilterMode & { return op2_; } inline auto op2() const -> SamplerFilterMode const & { return op2_; } private: @@ -544,6 +615,7 @@ class OpConstantNull : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantNull; } OpConstantNull(IdResultType type) : spv_inst{Op::ConstantNull, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -555,8 +627,11 @@ class OpFunction : public spv_inst { OpFunction(IdResultType type, FunctionControl op0, IdRef op1) : spv_inst{Op::Function, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> FunctionControl & { return op0_; } inline auto op0() const -> FunctionControl const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -569,6 +644,7 @@ class OpFunctionParameter : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionParameter; } OpFunctionParameter(IdResultType type) : spv_inst{Op::FunctionParameter, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -587,8 +663,11 @@ class OpFunctionCall : public spv_inst { OpFunctionCall(IdResultType type, IdRef op0, std::vector op1) : spv_inst{Op::FunctionCall, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -602,8 +681,11 @@ class OpVariable : public spv_inst { OpVariable(IdResultType type, StorageClass op0, std::optional op1 = std::nullopt) : spv_inst{Op::Variable, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> StorageClass & { return op0_; } inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } inline auto op1() const -> std::optional const & { return op1_; } private: @@ -617,9 +699,13 @@ class OpImageTexelPointer : public spv_inst { OpImageTexelPointer(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::ImageTexelPointer, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -635,9 +721,13 @@ class OpLoad : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::Load, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } inline auto op1() const -> std::optional const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -653,9 +743,13 @@ class OpStore : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::Store, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -672,10 +766,15 @@ class OpCopyMemory : public spv_inst { std::optional op4 = std::nullopt) : spv_inst{Op::CopyMemory, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } inline auto op4() const -> std::optional const & { return op4_; } private: @@ -696,11 +795,17 @@ class OpCopyMemorySized : public spv_inst { std::optional op5 = std::nullopt) : spv_inst{Op::CopyMemorySized, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } inline auto op5() const -> std::optional const & { return op5_; } private: @@ -717,8 +822,11 @@ class OpAccessChain : public spv_inst { OpAccessChain(IdResultType type, IdRef op0, std::vector op1) : spv_inst{Op::AccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -732,8 +840,11 @@ class OpInBoundsAccessChain : public spv_inst { OpInBoundsAccessChain(IdResultType type, IdRef op0, std::vector op1) : spv_inst{Op::InBoundsAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -750,9 +861,13 @@ class OpPtrAccessChain : public spv_inst { OpPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) : spv_inst{Op::PtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -768,8 +883,11 @@ class OpArrayLength : public spv_inst { OpArrayLength(IdResultType type, IdRef op0, LiteralInteger op1) : spv_inst{Op::ArrayLength, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -786,7 +904,9 @@ class OpGenericPtrMemSemantics : public spv_inst { OpGenericPtrMemSemantics(IdResultType type, IdRef op0) : spv_inst{Op::GenericPtrMemSemantics, true}, type_(std::move(type)), op0_(std::move(op0)) { } + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -802,9 +922,13 @@ class OpInBoundsPtrAccessChain : public spv_inst { OpInBoundsPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) : spv_inst{Op::InBoundsPtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -819,8 +943,11 @@ class OpDecorate : public spv_inst { OpDecorate(IdRef op0, Decoration op1, std::optional op2 = std::nullopt) : spv_inst{Op::Decorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } inline auto op1() const -> Decoration const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -834,8 +961,11 @@ class OpMemberDecorate : public spv_inst { OpMemberDecorate(IdRef op0, LiteralInteger op1, Decoration op2) : spv_inst{Op::MemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> Decoration & { return op2_; } inline auto op2() const -> Decoration const & { return op2_; } private: @@ -855,7 +985,9 @@ class OpGroupDecorate : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupDecorate; } OpGroupDecorate(IdRef op0, std::vector op1) : spv_inst{Op::GroupDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -867,7 +999,9 @@ class OpGroupMemberDecorate : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupMemberDecorate; } OpGroupMemberDecorate(IdRef op0, std::vector op1) : spv_inst{Op::GroupMemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -880,8 +1014,11 @@ class OpVectorExtractDynamic : public spv_inst { OpVectorExtractDynamic(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::VectorExtractDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -895,9 +1032,13 @@ class OpVectorInsertDynamic : public spv_inst { OpVectorInsertDynamic(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::VectorInsertDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -912,9 +1053,13 @@ class OpVectorShuffle : public spv_inst { OpVectorShuffle(IdResultType type, IdRef op0, IdRef op1, std::vector op2) : spv_inst{Op::VectorShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -928,7 +1073,9 @@ class OpCompositeConstruct : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeConstruct; } OpCompositeConstruct(IdResultType type, std::vector op0) : spv_inst{Op::CompositeConstruct, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } inline auto op0() const -> std::vector const & { return op0_; } private: @@ -941,8 +1088,11 @@ class OpCompositeExtract : public spv_inst { OpCompositeExtract(IdResultType type, IdRef op0, std::vector op1) : spv_inst{Op::CompositeExtract, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } inline auto op1() const -> std::vector const & { return op1_; } private: @@ -956,9 +1106,13 @@ class OpCompositeInsert : public spv_inst { OpCompositeInsert(IdResultType type, IdRef op0, IdRef op1, std::vector op2) : spv_inst{Op::CompositeInsert, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -972,7 +1126,9 @@ class OpCopyObject : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyObject; } OpCopyObject(IdResultType type, IdRef op0) : spv_inst{Op::CopyObject, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -985,7 +1141,9 @@ class OpTranspose : public spv_inst { constexpr static std::array required_capabilities = {Capability::Matrix}; OpTranspose(IdResultType type, IdRef op0) : spv_inst{Op::Transpose, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -998,8 +1156,11 @@ class OpSampledImage : public spv_inst { OpSampledImage(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SampledImage, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1017,9 +1178,13 @@ class OpImageSampleImplicitLod : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -1036,9 +1201,13 @@ class OpImageSampleExplicitLod : public spv_inst { OpImageSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) : spv_inst{Op::ImageSampleExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } inline auto op2() const -> ImageOperands const & { return op2_; } private: @@ -1057,10 +1226,15 @@ class OpImageSampleDrefImplicitLod : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSampleDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -1080,10 +1254,15 @@ class OpImageSampleDrefExplicitLod : public spv_inst { ImageOperands op3) : spv_inst{Op::ImageSampleDrefExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } inline auto op3() const -> ImageOperands const & { return op3_; } private: @@ -1103,9 +1282,13 @@ class OpImageSampleProjImplicitLod : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSampleProjImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -1123,9 +1306,13 @@ class OpImageSampleProjExplicitLod : public spv_inst { OpImageSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) : spv_inst{Op::ImageSampleProjExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } inline auto op2() const -> ImageOperands const & { return op2_; } private: @@ -1144,10 +1331,15 @@ class OpImageSampleProjDrefImplicitLod : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSampleProjDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -1167,10 +1359,15 @@ class OpImageSampleProjDrefExplicitLod : public spv_inst { ImageOperands op3) : spv_inst{Op::ImageSampleProjDrefExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } inline auto op3() const -> ImageOperands const & { return op3_; } private: @@ -1187,9 +1384,13 @@ class OpImageFetch : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageFetch, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -1206,10 +1407,15 @@ class OpImageGather : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -1227,10 +1433,15 @@ class OpImageDrefGather : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -1247,9 +1458,13 @@ class OpImageRead : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageRead, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -1264,9 +1479,13 @@ class OpImageWrite : public spv_inst { OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3 = std::nullopt) : spv_inst{Op::ImageWrite, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -1280,7 +1499,9 @@ class OpImage : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Image; } OpImage(IdResultType type, IdRef op0) : spv_inst{Op::Image, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1293,7 +1514,9 @@ class OpImageQueryFormat : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpImageQueryFormat(IdResultType type, IdRef op0) : spv_inst{Op::ImageQueryFormat, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1306,7 +1529,9 @@ class OpImageQueryOrder : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpImageQueryOrder(IdResultType type, IdRef op0) : spv_inst{Op::ImageQueryOrder, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1321,8 +1546,11 @@ class OpImageQuerySizeLod : public spv_inst { OpImageQuerySizeLod(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ImageQuerySizeLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1337,7 +1565,9 @@ class OpImageQuerySize : public spv_inst { Capability::ImageQuery}; OpImageQuerySize(IdResultType type, IdRef op0) : spv_inst{Op::ImageQuerySize, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1351,8 +1581,11 @@ class OpImageQueryLod : public spv_inst { OpImageQueryLod(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ImageQueryLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1367,7 +1600,9 @@ class OpImageQueryLevels : public spv_inst { Capability::ImageQuery}; OpImageQueryLevels(IdResultType type, IdRef op0) : spv_inst{Op::ImageQueryLevels, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1381,7 +1616,9 @@ class OpImageQuerySamples : public spv_inst { Capability::ImageQuery}; OpImageQuerySamples(IdResultType type, IdRef op0) : spv_inst{Op::ImageQuerySamples, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1393,7 +1630,9 @@ class OpConvertFToU : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToU; } OpConvertFToU(IdResultType type, IdRef op0) : spv_inst{Op::ConvertFToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1405,7 +1644,9 @@ class OpConvertFToS : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToS; } OpConvertFToS(IdResultType type, IdRef op0) : spv_inst{Op::ConvertFToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1417,7 +1658,9 @@ class OpConvertSToF : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertSToF; } OpConvertSToF(IdResultType type, IdRef op0) : spv_inst{Op::ConvertSToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1429,7 +1672,9 @@ class OpConvertUToF : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToF; } OpConvertUToF(IdResultType type, IdRef op0) : spv_inst{Op::ConvertUToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1441,7 +1686,9 @@ class OpUConvert : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UConvert; } OpUConvert(IdResultType type, IdRef op0) : spv_inst{Op::UConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1453,7 +1700,9 @@ class OpSConvert : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SConvert; } OpSConvert(IdResultType type, IdRef op0) : spv_inst{Op::SConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1465,7 +1714,9 @@ class OpFConvert : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FConvert; } OpFConvert(IdResultType type, IdRef op0) : spv_inst{Op::FConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1477,7 +1728,9 @@ class OpQuantizeToF16 : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::QuantizeToF16; } OpQuantizeToF16(IdResultType type, IdRef op0) : spv_inst{Op::QuantizeToF16, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1491,7 +1744,9 @@ class OpConvertPtrToU : public spv_inst { Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; OpConvertPtrToU(IdResultType type, IdRef op0) : spv_inst{Op::ConvertPtrToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1504,7 +1759,9 @@ class OpSatConvertSToU : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpSatConvertSToU(IdResultType type, IdRef op0) : spv_inst{Op::SatConvertSToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1517,7 +1774,9 @@ class OpSatConvertUToS : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpSatConvertUToS(IdResultType type, IdRef op0) : spv_inst{Op::SatConvertUToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1531,7 +1790,9 @@ class OpConvertUToPtr : public spv_inst { Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; OpConvertUToPtr(IdResultType type, IdRef op0) : spv_inst{Op::ConvertUToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1544,7 +1805,9 @@ class OpPtrCastToGeneric : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpPtrCastToGeneric(IdResultType type, IdRef op0) : spv_inst{Op::PtrCastToGeneric, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1557,7 +1820,9 @@ class OpGenericCastToPtr : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpGenericCastToPtr(IdResultType type, IdRef op0) : spv_inst{Op::GenericCastToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1573,8 +1838,11 @@ class OpGenericCastToPtrExplicit : public spv_inst { OpGenericCastToPtrExplicit(IdResultType type, IdRef op0, StorageClass op1) : spv_inst{Op::GenericCastToPtrExplicit, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } inline auto op1() const -> StorageClass const & { return op1_; } private: @@ -1587,7 +1855,9 @@ class OpBitcast : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Bitcast; } OpBitcast(IdResultType type, IdRef op0) : spv_inst{Op::Bitcast, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1599,7 +1869,9 @@ class OpSNegate : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SNegate; } OpSNegate(IdResultType type, IdRef op0) : spv_inst{Op::SNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1611,7 +1883,9 @@ class OpFNegate : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FNegate; } OpFNegate(IdResultType type, IdRef op0) : spv_inst{Op::FNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -1624,8 +1898,11 @@ class OpIAdd : public spv_inst { OpIAdd(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::IAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1639,8 +1916,11 @@ class OpFAdd : public spv_inst { OpFAdd(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1654,8 +1934,11 @@ class OpISub : public spv_inst { OpISub(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ISub, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1669,8 +1952,11 @@ class OpFSub : public spv_inst { OpFSub(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FSub, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1684,8 +1970,11 @@ class OpIMul : public spv_inst { OpIMul(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::IMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1699,8 +1988,11 @@ class OpFMul : public spv_inst { OpFMul(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1714,8 +2006,11 @@ class OpUDiv : public spv_inst { OpUDiv(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::UDiv, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1729,8 +2024,11 @@ class OpSDiv : public spv_inst { OpSDiv(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SDiv, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1744,8 +2042,11 @@ class OpFDiv : public spv_inst { OpFDiv(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FDiv, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1759,8 +2060,11 @@ class OpUMod : public spv_inst { OpUMod(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::UMod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1774,8 +2078,11 @@ class OpSRem : public spv_inst { OpSRem(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SRem, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1789,8 +2096,11 @@ class OpSMod : public spv_inst { OpSMod(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SMod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1804,8 +2114,11 @@ class OpFRem : public spv_inst { OpFRem(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FRem, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1819,8 +2132,11 @@ class OpFMod : public spv_inst { OpFMod(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FMod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1834,8 +2150,11 @@ class OpVectorTimesScalar : public spv_inst { OpVectorTimesScalar(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::VectorTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1850,8 +2169,11 @@ class OpMatrixTimesScalar : public spv_inst { OpMatrixTimesScalar(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::MatrixTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1866,8 +2188,11 @@ class OpVectorTimesMatrix : public spv_inst { OpVectorTimesMatrix(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::VectorTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1882,8 +2207,11 @@ class OpMatrixTimesVector : public spv_inst { OpMatrixTimesVector(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::MatrixTimesVector, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1898,8 +2226,11 @@ class OpMatrixTimesMatrix : public spv_inst { OpMatrixTimesMatrix(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::MatrixTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1914,8 +2245,11 @@ class OpOuterProduct : public spv_inst { OpOuterProduct(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::OuterProduct, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1929,8 +2263,11 @@ class OpDot : public spv_inst { OpDot(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::Dot, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1944,8 +2281,11 @@ class OpIAddCarry : public spv_inst { OpIAddCarry(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::IAddCarry, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1959,8 +2299,11 @@ class OpISubBorrow : public spv_inst { OpISubBorrow(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ISubBorrow, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1974,8 +2317,11 @@ class OpUMulExtended : public spv_inst { OpUMulExtended(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::UMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -1989,8 +2335,11 @@ class OpSMulExtended : public spv_inst { OpSMulExtended(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2003,7 +2352,9 @@ class OpAny : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Any; } OpAny(IdResultType type, IdRef op0) : spv_inst{Op::Any, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2015,7 +2366,9 @@ class OpAll : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::All; } OpAll(IdResultType type, IdRef op0) : spv_inst{Op::All, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2027,7 +2380,9 @@ class OpIsNan : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNan; } OpIsNan(IdResultType type, IdRef op0) : spv_inst{Op::IsNan, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2039,7 +2394,9 @@ class OpIsInf : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsInf; } OpIsInf(IdResultType type, IdRef op0) : spv_inst{Op::IsInf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2052,7 +2409,9 @@ class OpIsFinite : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpIsFinite(IdResultType type, IdRef op0) : spv_inst{Op::IsFinite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2065,7 +2424,9 @@ class OpIsNormal : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpIsNormal(IdResultType type, IdRef op0) : spv_inst{Op::IsNormal, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2078,7 +2439,9 @@ class OpSignBitSet : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpSignBitSet(IdResultType type, IdRef op0) : spv_inst{Op::SignBitSet, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2092,8 +2455,11 @@ class OpLessOrGreater : public spv_inst { OpLessOrGreater(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::LessOrGreater, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2108,8 +2474,11 @@ class OpOrdered : public spv_inst { OpOrdered(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::Ordered, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2124,8 +2493,11 @@ class OpUnordered : public spv_inst { OpUnordered(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::Unordered, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2139,8 +2511,11 @@ class OpLogicalEqual : public spv_inst { OpLogicalEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::LogicalEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2154,8 +2529,11 @@ class OpLogicalNotEqual : public spv_inst { OpLogicalNotEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::LogicalNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2169,8 +2547,11 @@ class OpLogicalOr : public spv_inst { OpLogicalOr(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::LogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2184,8 +2565,11 @@ class OpLogicalAnd : public spv_inst { OpLogicalAnd(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::LogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2198,7 +2582,9 @@ class OpLogicalNot : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNot; } OpLogicalNot(IdResultType type, IdRef op0) : spv_inst{Op::LogicalNot, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2211,9 +2597,13 @@ class OpSelect : public spv_inst { OpSelect(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::Select, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -2228,8 +2618,11 @@ class OpIEqual : public spv_inst { OpIEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::IEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2243,8 +2636,11 @@ class OpINotEqual : public spv_inst { OpINotEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::INotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2258,8 +2654,11 @@ class OpUGreaterThan : public spv_inst { OpUGreaterThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::UGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2273,8 +2672,11 @@ class OpSGreaterThan : public spv_inst { OpSGreaterThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2288,8 +2690,11 @@ class OpUGreaterThanEqual : public spv_inst { OpUGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::UGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2303,8 +2708,11 @@ class OpSGreaterThanEqual : public spv_inst { OpSGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2318,8 +2726,11 @@ class OpULessThan : public spv_inst { OpULessThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ULessThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2333,8 +2744,11 @@ class OpSLessThan : public spv_inst { OpSLessThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2348,8 +2762,11 @@ class OpULessThanEqual : public spv_inst { OpULessThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ULessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2363,8 +2780,11 @@ class OpSLessThanEqual : public spv_inst { OpSLessThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::SLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2378,8 +2798,11 @@ class OpFOrdEqual : public spv_inst { OpFOrdEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2393,8 +2816,11 @@ class OpFUnordEqual : public spv_inst { OpFUnordEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2408,8 +2834,11 @@ class OpFOrdNotEqual : public spv_inst { OpFOrdNotEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2423,8 +2852,11 @@ class OpFUnordNotEqual : public spv_inst { OpFUnordNotEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2438,8 +2870,11 @@ class OpFOrdLessThan : public spv_inst { OpFOrdLessThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2453,8 +2888,11 @@ class OpFUnordLessThan : public spv_inst { OpFUnordLessThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2468,8 +2906,11 @@ class OpFOrdGreaterThan : public spv_inst { OpFOrdGreaterThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2483,8 +2924,11 @@ class OpFUnordGreaterThan : public spv_inst { OpFUnordGreaterThan(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2498,8 +2942,11 @@ class OpFOrdLessThanEqual : public spv_inst { OpFOrdLessThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2513,8 +2960,11 @@ class OpFUnordLessThanEqual : public spv_inst { OpFUnordLessThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2528,8 +2978,11 @@ class OpFOrdGreaterThanEqual : public spv_inst { OpFOrdGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FOrdGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2545,8 +2998,11 @@ class OpFUnordGreaterThanEqual : public spv_inst { OpFUnordGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::FUnordGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2560,8 +3016,11 @@ class OpShiftRightLogical : public spv_inst { OpShiftRightLogical(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ShiftRightLogical, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2575,8 +3034,11 @@ class OpShiftRightArithmetic : public spv_inst { OpShiftRightArithmetic(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ShiftRightArithmetic, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2590,8 +3052,11 @@ class OpShiftLeftLogical : public spv_inst { OpShiftLeftLogical(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::ShiftLeftLogical, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2605,8 +3070,11 @@ class OpBitwiseOr : public spv_inst { OpBitwiseOr(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::BitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2620,8 +3088,11 @@ class OpBitwiseXor : public spv_inst { OpBitwiseXor(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::BitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2635,8 +3106,11 @@ class OpBitwiseAnd : public spv_inst { OpBitwiseAnd(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::BitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -2649,7 +3123,9 @@ class OpNot : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Not; } OpNot(IdResultType type, IdRef op0) : spv_inst{Op::Not, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2664,10 +3140,15 @@ class OpBitFieldInsert : public spv_inst { OpBitFieldInsert(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::BitFieldInsert, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -2685,9 +3166,13 @@ class OpBitFieldSExtract : public spv_inst { OpBitFieldSExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::BitFieldSExtract, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -2704,9 +3189,13 @@ class OpBitFieldUExtract : public spv_inst { OpBitFieldUExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::BitFieldUExtract, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -2722,7 +3211,9 @@ class OpBitReverse : public spv_inst { Capability::Shader, Capability::BitInstructions}; OpBitReverse(IdResultType type, IdRef op0) : spv_inst{Op::BitReverse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2734,7 +3225,9 @@ class OpBitCount : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitCount; } OpBitCount(IdResultType type, IdRef op0) : spv_inst{Op::BitCount, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2747,7 +3240,9 @@ class OpDPdx : public spv_inst { constexpr static std::array required_capabilities = {Capability::Shader}; OpDPdx(IdResultType type, IdRef op0) : spv_inst{Op::DPdx, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2760,7 +3255,9 @@ class OpDPdy : public spv_inst { constexpr static std::array required_capabilities = {Capability::Shader}; OpDPdy(IdResultType type, IdRef op0) : spv_inst{Op::DPdy, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2773,7 +3270,9 @@ class OpFwidth : public spv_inst { constexpr static std::array required_capabilities = {Capability::Shader}; OpFwidth(IdResultType type, IdRef op0) : spv_inst{Op::Fwidth, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2787,7 +3286,9 @@ class OpDPdxFine : public spv_inst { Capability::DerivativeControl}; OpDPdxFine(IdResultType type, IdRef op0) : spv_inst{Op::DPdxFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2801,7 +3302,9 @@ class OpDPdyFine : public spv_inst { Capability::DerivativeControl}; OpDPdyFine(IdResultType type, IdRef op0) : spv_inst{Op::DPdyFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2815,7 +3318,9 @@ class OpFwidthFine : public spv_inst { Capability::DerivativeControl}; OpFwidthFine(IdResultType type, IdRef op0) : spv_inst{Op::FwidthFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2829,7 +3334,9 @@ class OpDPdxCoarse : public spv_inst { Capability::DerivativeControl}; OpDPdxCoarse(IdResultType type, IdRef op0) : spv_inst{Op::DPdxCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2843,7 +3350,9 @@ class OpDPdyCoarse : public spv_inst { Capability::DerivativeControl}; OpDPdyCoarse(IdResultType type, IdRef op0) : spv_inst{Op::DPdyCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2857,7 +3366,9 @@ class OpFwidthCoarse : public spv_inst { Capability::DerivativeControl}; OpFwidthCoarse(IdResultType type, IdRef op0) : spv_inst{Op::FwidthCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2886,6 +3397,7 @@ class OpEmitStreamVertex : public spv_inst { constexpr static std::array required_capabilities = { Capability::GeometryStreams}; OpEmitStreamVertex(IdRef op0) : spv_inst{Op::EmitStreamVertex, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2898,6 +3410,7 @@ class OpEndStreamPrimitive : public spv_inst { Capability::GeometryStreams}; OpEndStreamPrimitive(IdRef op0) : spv_inst{Op::EndStreamPrimitive, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -2909,8 +3422,11 @@ class OpControlBarrier : public spv_inst { OpControlBarrier(IdScope op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::ControlBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -2923,7 +3439,9 @@ class OpMemoryBarrier : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryBarrier; } OpMemoryBarrier(IdScope op0, IdMemorySemantics op1) : spv_inst{Op::MemoryBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdMemorySemantics & { return op1_; } inline auto op1() const -> IdMemorySemantics const & { return op1_; } private: @@ -2936,9 +3454,13 @@ class OpAtomicLoad : public spv_inst { OpAtomicLoad(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::AtomicLoad, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -2953,9 +3475,13 @@ class OpAtomicStore : public spv_inst { OpAtomicStore(IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicStore, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -2970,10 +3496,15 @@ class OpAtomicExchange : public spv_inst { OpAtomicExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicExchange, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -2993,12 +3524,19 @@ class OpAtomicCompareExchange : public spv_inst { : spv_inst{Op::AtomicCompareExchange, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } private: @@ -3021,12 +3559,19 @@ class OpAtomicCompareExchangeWeak : public spv_inst { : spv_inst{Op::AtomicCompareExchangeWeak, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } private: @@ -3044,9 +3589,13 @@ class OpAtomicIIncrement : public spv_inst { OpAtomicIIncrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::AtomicIIncrement, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -3061,9 +3610,13 @@ class OpAtomicIDecrement : public spv_inst { OpAtomicIDecrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::AtomicIDecrement, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -3078,10 +3631,15 @@ class OpAtomicIAdd : public spv_inst { OpAtomicIAdd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3097,10 +3655,15 @@ class OpAtomicISub : public spv_inst { OpAtomicISub(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicISub, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3116,10 +3679,15 @@ class OpAtomicSMin : public spv_inst { OpAtomicSMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicSMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3135,10 +3703,15 @@ class OpAtomicUMin : public spv_inst { OpAtomicUMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicUMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3154,10 +3727,15 @@ class OpAtomicSMax : public spv_inst { OpAtomicSMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicSMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3173,10 +3751,15 @@ class OpAtomicUMax : public spv_inst { OpAtomicUMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicUMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3192,10 +3775,15 @@ class OpAtomicAnd : public spv_inst { OpAtomicAnd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3211,10 +3799,15 @@ class OpAtomicOr : public spv_inst { OpAtomicOr(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3230,10 +3823,15 @@ class OpAtomicXor : public spv_inst { OpAtomicXor(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) : spv_inst{Op::AtomicXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3248,7 +3846,9 @@ class OpPhi : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Phi; } OpPhi(IdResultType type, std::vector op0) : spv_inst{Op::Phi, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } inline auto op0() const -> std::vector const & { return op0_; } private: @@ -3261,8 +3861,11 @@ class OpLoopMerge : public spv_inst { OpLoopMerge(IdRef op0, IdRef op1, LoopControl op2) : spv_inst{Op::LoopMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LoopControl & { return op2_; } inline auto op2() const -> LoopControl const & { return op2_; } private: @@ -3275,7 +3878,9 @@ class OpSelectionMerge : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SelectionMerge; } OpSelectionMerge(IdRef op0, SelectionControl op1) : spv_inst{Op::SelectionMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> SelectionControl & { return op1_; } inline auto op1() const -> SelectionControl const & { return op1_; } private: @@ -3293,6 +3898,7 @@ class OpBranch : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Branch; } OpBranch(IdRef op0) : spv_inst{Op::Branch, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -3304,9 +3910,13 @@ class OpBranchConditional : public spv_inst { OpBranchConditional(IdRef op0, IdRef op1, IdRef op2, std::vector op3) : spv_inst{Op::BranchConditional, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } inline auto op3() const -> std::vector const & { return op3_; } private: @@ -3321,8 +3931,11 @@ class OpSwitch : public spv_inst { OpSwitch(IdRef op0, IdRef op1, std::vector op2) : spv_inst{Op::Switch, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } inline auto op2() const -> std::vector const & { return op2_; } private: @@ -3349,6 +3962,7 @@ class OpReturnValue : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReturnValue; } OpReturnValue(IdRef op0) : spv_inst{Op::ReturnValue, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -3367,7 +3981,9 @@ class OpLifetimeStart : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpLifetimeStart(IdRef op0, LiteralInteger op1) : spv_inst{Op::LifetimeStart, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -3380,7 +3996,9 @@ class OpLifetimeStop : public spv_inst { constexpr static std::array required_capabilities = {Capability::Kernel}; OpLifetimeStop(IdRef op0, LiteralInteger op1) : spv_inst{Op::LifetimeStop, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } private: @@ -3396,12 +4014,19 @@ class OpGroupAsyncCopy : public spv_inst { : spv_inst{Op::GroupAsyncCopy, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } private: @@ -3420,8 +4045,11 @@ class OpGroupWaitEvents : public spv_inst { OpGroupWaitEvents(IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupWaitEvents, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3436,8 +4064,11 @@ class OpGroupAll : public spv_inst { OpGroupAll(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupAll, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -3452,8 +4083,11 @@ class OpGroupAny : public spv_inst { OpGroupAny(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupAny, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -3468,9 +4102,13 @@ class OpGroupBroadcast : public spv_inst { OpGroupBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3486,9 +4124,13 @@ class OpGroupIAdd : public spv_inst { OpGroupIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3504,9 +4146,13 @@ class OpGroupFAdd : public spv_inst { OpGroupFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3522,9 +4168,13 @@ class OpGroupFMin : public spv_inst { OpGroupFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupFMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3540,9 +4190,13 @@ class OpGroupUMin : public spv_inst { OpGroupUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupUMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3558,9 +4212,13 @@ class OpGroupSMin : public spv_inst { OpGroupSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupSMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3576,9 +4234,13 @@ class OpGroupFMax : public spv_inst { OpGroupFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupFMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3594,9 +4256,13 @@ class OpGroupUMax : public spv_inst { OpGroupUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupUMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3612,9 +4278,13 @@ class OpGroupSMax : public spv_inst { OpGroupSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupSMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3630,10 +4300,15 @@ class OpReadPipe : public spv_inst { OpReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::ReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3650,10 +4325,15 @@ class OpWritePipe : public spv_inst { OpWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::WritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3672,12 +4352,19 @@ class OpReservedReadPipe : public spv_inst { : spv_inst{Op::ReservedReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } private: @@ -3698,12 +4385,19 @@ class OpReservedWritePipe : public spv_inst { : spv_inst{Op::ReservedWritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } private: @@ -3724,10 +4418,15 @@ class OpReserveReadPipePackets : public spv_inst { OpReserveReadPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::ReserveReadPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3746,10 +4445,15 @@ class OpReserveWritePipePackets : public spv_inst { OpReserveWritePipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::ReserveWritePipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3766,9 +4470,13 @@ class OpCommitReadPipe : public spv_inst { OpCommitReadPipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::CommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3784,9 +4492,13 @@ class OpCommitWritePipe : public spv_inst { OpCommitWritePipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::CommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3801,7 +4513,9 @@ class OpIsValidReserveId : public spv_inst { constexpr static std::array required_capabilities = {Capability::Pipes}; OpIsValidReserveId(IdResultType type, IdRef op0) : spv_inst{Op::IsValidReserveId, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -3815,9 +4529,13 @@ class OpGetNumPipePackets : public spv_inst { OpGetNumPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::GetNumPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3833,9 +4551,13 @@ class OpGetMaxPipePackets : public spv_inst { OpGetMaxPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::GetMaxPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -3855,11 +4577,17 @@ class OpGroupReserveReadPipePackets : public spv_inst { : spv_inst{Op::GroupReserveReadPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -3881,11 +4609,17 @@ class OpGroupReserveWritePipePackets : public spv_inst { : spv_inst{Op::GroupReserveWritePipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -3903,10 +4637,15 @@ class OpGroupCommitReadPipe : public spv_inst { OpGroupCommitReadPipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) : spv_inst{Op::GroupCommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -3923,10 +4662,15 @@ class OpGroupCommitWritePipe : public spv_inst { OpGroupCommitWritePipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) : spv_inst{Op::GroupCommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -3943,10 +4687,15 @@ class OpEnqueueMarker : public spv_inst { OpEnqueueMarker(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::EnqueueMarker, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -3966,17 +4715,29 @@ class OpEnqueueKernel : public spv_inst { op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), op9_(std::move(op9)), op10_(std::move(op10)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> IdRef & { return op6_; } inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() -> IdRef & { return op7_; } inline auto op7() const -> IdRef const & { return op7_; } + inline auto op8() -> IdRef & { return op8_; } inline auto op8() const -> IdRef const & { return op8_; } + inline auto op9() -> IdRef & { return op9_; } inline auto op9() const -> IdRef const & { return op9_; } + inline auto op10() -> std::vector & { return op10_; } inline auto op10() const -> std::vector const & { return op10_; } private: @@ -4004,11 +4765,17 @@ class OpGetKernelNDrangeSubGroupCount : public spv_inst { : spv_inst{Op::GetKernelNDrangeSubGroupCount, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -4030,11 +4797,17 @@ class OpGetKernelNDrangeMaxSubGroupSize : public spv_inst { : spv_inst{Op::GetKernelNDrangeMaxSubGroupSize, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -4054,10 +4827,15 @@ class OpGetKernelWorkGroupSize : public spv_inst { OpGetKernelWorkGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::GetKernelWorkGroupSize, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -4077,10 +4855,15 @@ class OpGetKernelPreferredWorkGroupSizeMultiple : public spv_inst { IdRef op3) : spv_inst{Op::GetKernelPreferredWorkGroupSizeMultiple, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -4095,6 +4878,7 @@ class OpRetainEvent : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::RetainEvent; } constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpRetainEvent(IdRef op0) : spv_inst{Op::RetainEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4105,6 +4889,7 @@ class OpReleaseEvent : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReleaseEvent; } constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpReleaseEvent(IdRef op0) : spv_inst{Op::ReleaseEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4116,6 +4901,7 @@ class OpCreateUserEvent : public spv_inst { constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpCreateUserEvent(IdResultType type) : spv_inst{Op::CreateUserEvent, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -4127,7 +4913,9 @@ class OpIsValidEvent : public spv_inst { constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpIsValidEvent(IdResultType type, IdRef op0) : spv_inst{Op::IsValidEvent, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4140,7 +4928,9 @@ class OpSetUserEventStatus : public spv_inst { constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpSetUserEventStatus(IdRef op0, IdRef op1) : spv_inst{Op::SetUserEventStatus, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4156,8 +4946,11 @@ class OpCaptureEventProfilingInfo : public spv_inst { OpCaptureEventProfilingInfo(IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::CaptureEventProfilingInfo, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4171,6 +4964,7 @@ class OpGetDefaultQueue : public spv_inst { constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; OpGetDefaultQueue(IdResultType type) : spv_inst{Op::GetDefaultQueue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } private: @@ -4183,9 +4977,13 @@ class OpBuildNDRange : public spv_inst { OpBuildNDRange(IdResultType type, IdRef op0, IdRef op1, IdRef op2) : spv_inst{Op::BuildNDRange, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4205,9 +5003,13 @@ class OpImageSparseSampleImplicitLod : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -4226,9 +5028,13 @@ class OpImageSparseSampleExplicitLod : public spv_inst { OpImageSparseSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) : spv_inst{Op::ImageSparseSampleExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } inline auto op2() const -> ImageOperands const & { return op2_; } private: @@ -4248,10 +5054,15 @@ class OpImageSparseSampleDrefImplicitLod : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseSampleDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -4272,10 +5083,15 @@ class OpImageSparseSampleDrefExplicitLod : public spv_inst { ImageOperands op3) : spv_inst{Op::ImageSparseSampleDrefExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } inline auto op3() const -> ImageOperands const & { return op3_; } private: @@ -4296,9 +5112,13 @@ class OpImageSparseSampleProjImplicitLod : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseSampleProjImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -4317,9 +5137,13 @@ class OpImageSparseSampleProjExplicitLod : public spv_inst { OpImageSparseSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) : spv_inst{Op::ImageSparseSampleProjExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } inline auto op2() const -> ImageOperands const & { return op2_; } private: @@ -4339,10 +5163,15 @@ class OpImageSparseSampleProjDrefImplicitLod : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseSampleProjDrefImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -4363,10 +5192,15 @@ class OpImageSparseSampleProjDrefExplicitLod : public spv_inst { ImageOperands op3) : spv_inst{Op::ImageSparseSampleProjDrefExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } inline auto op3() const -> ImageOperands const & { return op3_; } private: @@ -4385,9 +5219,13 @@ class OpImageSparseFetch : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseFetch, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -4405,10 +5243,15 @@ class OpImageSparseGather : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -4429,10 +5272,15 @@ class OpImageSparseDrefGather : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::ImageSparseDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -4452,7 +5300,9 @@ class OpImageSparseTexelsResident : public spv_inst { OpImageSparseTexelsResident(IdResultType type, IdRef op0) : spv_inst{Op::ImageSparseTexelsResident, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4473,9 +5323,13 @@ class OpAtomicFlagTestAndSet : public spv_inst { OpAtomicFlagTestAndSet(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::AtomicFlagTestAndSet, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -4491,8 +5345,11 @@ class OpAtomicFlagClear : public spv_inst { OpAtomicFlagClear(IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::AtomicFlagClear, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -4509,9 +5366,13 @@ class OpImageSparseRead : public spv_inst { std::optional op2 = std::nullopt) : spv_inst{Op::ImageSparseRead, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } private: @@ -4526,7 +5387,9 @@ class OpSizeOf : public spv_inst { constexpr static std::array required_capabilities = {Capability::Addresses}; OpSizeOf(IdResultType type, IdRef op0) : spv_inst{Op::SizeOf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4549,9 +5412,13 @@ class OpConstantPipeStorage : public spv_inst { LiteralInteger op2) : spv_inst{Op::ConstantPipeStorage, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> LiteralInteger & { return op0_; } inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } inline auto op2() const -> LiteralInteger const & { return op2_; } private: @@ -4569,7 +5436,9 @@ class OpCreatePipeFromPipeStorage : public spv_inst { OpCreatePipeFromPipeStorage(IdResultType type, IdRef op0) : spv_inst{Op::CreatePipeFromPipeStorage, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4588,11 +5457,17 @@ class OpGetKernelLocalSizeForSubgroupCount : public spv_inst { : spv_inst{Op::GetKernelLocalSizeForSubgroupCount, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -4613,10 +5488,15 @@ class OpGetKernelMaxNumSubgroups : public spv_inst { OpGetKernelMaxNumSubgroups(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) : spv_inst{Op::GetKernelMaxNumSubgroups, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } private: @@ -4643,7 +5523,9 @@ class OpNamedBarrierInitialize : public spv_inst { OpNamedBarrierInitialize(IdResultType type, IdRef op0) : spv_inst{Op::NamedBarrierInitialize, true}, type_(std::move(type)), op0_(std::move(op0)) { } + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -4657,8 +5539,11 @@ class OpMemoryNamedBarrier : public spv_inst { OpMemoryNamedBarrier(IdRef op0, IdScope op1, IdMemorySemantics op2) : spv_inst{Op::MemoryNamedBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } inline auto op2() const -> IdMemorySemantics const & { return op2_; } private: @@ -4671,6 +5556,7 @@ class OpModuleProcessed : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ModuleProcessed; } OpModuleProcessed(LiteralString op0) : spv_inst{Op::ModuleProcessed, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } inline auto op0() const -> LiteralString const & { return op0_; } private: @@ -4681,7 +5567,9 @@ class OpExecutionModeId : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionModeId; } OpExecutionModeId(IdRef op0, ExecutionMode op1) : spv_inst{Op::ExecutionModeId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } inline auto op1() const -> ExecutionMode const & { return op1_; } private: @@ -4693,7 +5581,9 @@ class OpDecorateId : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorateId; } OpDecorateId(IdRef op0, Decoration op1) : spv_inst{Op::DecorateId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } inline auto op1() const -> Decoration const & { return op1_; } private: @@ -4707,7 +5597,9 @@ class OpGroupNonUniformElect : public spv_inst { Capability::GroupNonUniform}; OpGroupNonUniformElect(IdResultType type, IdScope op0) : spv_inst{Op::GroupNonUniformElect, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } private: @@ -4722,8 +5614,11 @@ class OpGroupNonUniformAll : public spv_inst { OpGroupNonUniformAll(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformAll, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4739,8 +5634,11 @@ class OpGroupNonUniformAny : public spv_inst { OpGroupNonUniformAny(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformAny, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4758,8 +5656,11 @@ class OpGroupNonUniformAllEqual : public spv_inst { OpGroupNonUniformAllEqual(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformAllEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4777,9 +5678,13 @@ class OpGroupNonUniformBroadcast : public spv_inst { OpGroupNonUniformBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4798,8 +5703,11 @@ class OpGroupNonUniformBroadcastFirst : public spv_inst { OpGroupNonUniformBroadcastFirst(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformBroadcastFirst, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4817,8 +5725,11 @@ class OpGroupNonUniformBallot : public spv_inst { OpGroupNonUniformBallot(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformBallot, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4836,8 +5747,11 @@ class OpGroupNonUniformInverseBallot : public spv_inst { OpGroupNonUniformInverseBallot(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformInverseBallot, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4855,9 +5769,13 @@ class OpGroupNonUniformBallotBitExtract : public spv_inst { OpGroupNonUniformBallotBitExtract(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformBallotBitExtract, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4876,9 +5794,13 @@ class OpGroupNonUniformBallotBitCount : public spv_inst { OpGroupNonUniformBallotBitCount(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) : spv_inst{Op::GroupNonUniformBallotBitCount, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4897,8 +5819,11 @@ class OpGroupNonUniformBallotFindLSB : public spv_inst { OpGroupNonUniformBallotFindLSB(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformBallotFindLSB, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4916,8 +5841,11 @@ class OpGroupNonUniformBallotFindMSB : public spv_inst { OpGroupNonUniformBallotFindMSB(IdResultType type, IdScope op0, IdRef op1) : spv_inst{Op::GroupNonUniformBallotFindMSB, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -4935,9 +5863,13 @@ class OpGroupNonUniformShuffle : public spv_inst { OpGroupNonUniformShuffle(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4956,9 +5888,13 @@ class OpGroupNonUniformShuffleXor : public spv_inst { OpGroupNonUniformShuffleXor(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformShuffleXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4977,9 +5913,13 @@ class OpGroupNonUniformShuffleUp : public spv_inst { OpGroupNonUniformShuffleUp(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformShuffleUp, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -4998,9 +5938,13 @@ class OpGroupNonUniformShuffleDown : public spv_inst { OpGroupNonUniformShuffleDown(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformShuffleDown, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -5019,10 +5963,15 @@ class OpGroupNonUniformIAdd : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5042,10 +5991,15 @@ class OpGroupNonUniformFAdd : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5065,10 +6019,15 @@ class OpGroupNonUniformIMul : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformIMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5088,10 +6047,15 @@ class OpGroupNonUniformFMul : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMul, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5111,10 +6075,15 @@ class OpGroupNonUniformSMin : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformSMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5134,10 +6103,15 @@ class OpGroupNonUniformUMin : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformUMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5157,10 +6131,15 @@ class OpGroupNonUniformFMin : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMin, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5180,10 +6159,15 @@ class OpGroupNonUniformSMax : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformSMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5203,10 +6187,15 @@ class OpGroupNonUniformUMax : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformUMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5226,10 +6215,15 @@ class OpGroupNonUniformFMax : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformFMax, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5251,10 +6245,15 @@ class OpGroupNonUniformBitwiseAnd : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5276,10 +6275,15 @@ class OpGroupNonUniformBitwiseOr : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5301,10 +6305,15 @@ class OpGroupNonUniformBitwiseXor : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformBitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5326,10 +6335,15 @@ class OpGroupNonUniformLogicalAnd : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5351,10 +6365,15 @@ class OpGroupNonUniformLogicalOr : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5376,10 +6395,15 @@ class OpGroupNonUniformLogicalXor : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::GroupNonUniformLogicalXor, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5399,9 +6423,13 @@ class OpGroupNonUniformQuadBroadcast : public spv_inst { OpGroupNonUniformQuadBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformQuadBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -5420,9 +6448,13 @@ class OpGroupNonUniformQuadSwap : public spv_inst { OpGroupNonUniformQuadSwap(IdResultType type, IdScope op0, IdRef op1, IdRef op2) : spv_inst{Op::GroupNonUniformQuadSwap, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } private: @@ -5436,7 +6468,9 @@ class OpCopyLogical : public spv_inst { inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyLogical; } OpCopyLogical(IdResultType type, IdRef op0) : spv_inst{Op::CopyLogical, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -5449,8 +6483,11 @@ class OpPtrEqual : public spv_inst { OpPtrEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::PtrEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -5464,8 +6501,11 @@ class OpPtrNotEqual : public spv_inst { OpPtrNotEqual(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::PtrNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -5482,8 +6522,11 @@ class OpPtrDiff : public spv_inst { OpPtrDiff(IdResultType type, IdRef op0, IdRef op1) : spv_inst{Op::PtrDiff, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } private: @@ -5501,10 +6544,15 @@ class OpTypeCooperativeMatrixKHR : public spv_inst { OpTypeCooperativeMatrixKHR(IdRef op0, IdScope op1, IdRef op2, IdRef op3, IdRef op4) : spv_inst{Op::TypeCooperativeMatrixKHR, true}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } inline auto op4() const -> IdRef const & { return op4_; } private: @@ -5528,11 +6576,17 @@ class OpCooperativeMatrixLoadKHR : public spv_inst { : spv_inst{Op::CooperativeMatrixLoadKHR, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } inline auto op4() const -> std::optional const & { return op4_; } private: @@ -5557,11 +6611,17 @@ class OpCooperativeMatrixStoreKHR : public spv_inst { : spv_inst{Op::CooperativeMatrixStoreKHR, false}, op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } inline auto op5() const -> std::optional const & { return op5_; } private: @@ -5583,10 +6643,15 @@ class OpCooperativeMatrixMulAddKHR : public spv_inst { std::optional op3 = std::nullopt) : spv_inst{Op::CooperativeMatrixMulAddKHR, true}, type_(std::move(type)), op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } inline auto op3() const -> std::optional const & { return op3_; } private: @@ -5606,7 +6671,9 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { OpCooperativeMatrixLengthKHR(IdResultType type, IdRef op0) : spv_inst{Op::CooperativeMatrixLengthKHR, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } inline auto op0() const -> IdRef const & { return op0_; } private: @@ -5616,4 +6683,4 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_2024117_HPP +#endif // GENERATED_INSTRUCTIONS_2024118_HPP diff --git a/src/spv/names.hpp b/src/spv/names.hpp index af084c81..856ecfe5 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_2024117_HPP -#define GENERATED_NAMES_2024117_HPP +#ifndef GENERATED_NAMES_2024118_HPP +#define GENERATED_NAMES_2024118_HPP #include "enums.hpp" @@ -68,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_2024117_HPP +#endif // GENERATED_NAMES_2024118_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 2dded4ac..577ffa55 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -11,7 +11,9 @@ #include "tinytc/types.hpp" #include +#include #include +#include #include #include #include @@ -63,18 +65,26 @@ void dump_asm_pass::operator()(DecorationAttr const &da) { da); } void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { - std::visit(overloaded{[&](std::int32_t const &a) { *os_ << " " << a; }, - [&](std::array const &a) { - for (auto const &s : a) { - *os_ << " " << s; - } - }}, - ea); + std::visit( + overloaded{[&](std::int32_t const &a) { *os_ << " " << static_cast(a); }, + [&](std::array const &a) { + for (auto const &s : a) { + *os_ << " " << static_cast(s); + } + }}, + ea); } void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { - std::visit(overloaded{[&](auto const &l) { *os_ << " " << l; }}, l); + std::visit(overloaded{[&](std::signed_integral auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }, + [&](auto const &l) { *os_ << " " << l; }}, + l); +} +void dump_asm_pass::operator()(LiteralInteger const &l) { + *os_ << " " << static_cast>(l); } -void dump_asm_pass::operator()(LiteralInteger const &l) { *os_ << " " << l; } void dump_asm_pass::operator()(LiteralString const &l) { *os_ << " \"" << l << '"'; } void dump_asm_pass::operator()(PairIdRefIdRef const &p) { @@ -86,7 +96,11 @@ void dump_asm_pass::operator()(PairIdRefLiteralInteger const &p) { this->operator()(p.second); } void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { - std::visit(overloaded{[&](auto const &l) { *os_ << " " << l; }}, p.first); + std::visit(overloaded{[&](auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }}, + p.first); this->operator()(p.second); } @@ -97,6 +111,8 @@ void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << declare(in); } else if (isa(*in)) { *os_ << " %" << declare(in); + } else if (isa(*in)) { + *os_ << " %" << declare(in); } else { throw status::spirv_forbidden_forward_declaration; } @@ -118,6 +134,16 @@ auto dump_asm_pass::operator()(OpExtInst const &in) { } } +auto dump_asm_pass::operator()(OpPhi const &in) { + pre_visit(in); + this->operator()(in.type()); + for (auto const &op : in.op0()) { + // Forward references are allowed in phi instructions + declare(op.first); + this->operator()(op); + } +} + void dump_asm_pass::run_on_module(mod const &m) { auto const visit_section = [&](section s) { for (auto const &i : m.insts(s)) { diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp index dee7128f..f1082090 100644 --- a/src/spv/pass/dump_asm.hpp +++ b/src/spv/pass/dump_asm.hpp @@ -42,6 +42,7 @@ class dump_asm_pass : public default_visitor { void operator()(spv_inst *const &in); auto operator()(OpExtInst const &in); + auto operator()(OpPhi const &in); void run_on_module(mod const &m); diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index e5a3027e..96b144d6 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -189,8 +189,9 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { const auto sz = size(ty.ty()); if (sz == 8) { capability(Capability::Int64); + return spv_ty(scalar_data_type::get(ctx_, scalar_type::i64)); } - return mod_->add_to(section::type_const_var, sz * 8, 0); + return spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); } case scalar_type::f32: case scalar_type::f64: diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index 1f3d992c..d9fa0726 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_2024117_HPP -#define GENERATED_VISIT_2024117_HPP +#ifndef GENERATED_VISIT_2024118_HPP +#define GENERATED_VISIT_2024118_HPP namespace tinytc::spv { @@ -2938,4 +2938,4 @@ template class default_visitor { } // namespace tinytc::spv -#endif // GENERATED_VISIT_2024117_HPP +#endif // GENERATED_VISIT_2024118_HPP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 14f2935d..97e0b177 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -57,6 +57,8 @@ if(SPIRVTools_FOUND) spv/builtin.ir spv/cast.ir spv/compare.ir + spv/for.ir + spv/if.ir spv/work_group.ir spv/unique_function_type.ir ) diff --git a/test/spv/for.ir b/test/spv/for.ir new file mode 100644 index 00000000..a1aef989 --- /dev/null +++ b/test/spv/for.ir @@ -0,0 +1,115 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; CHECK: %[[#I16:]] = OpTypeInt 16 0 +; CHECK: %[[#I16_C0:]] = OpConstant %[[#I16]] 0 +; CHECK: %[[#I16_C10:]] = OpConstant %[[#I16]] 10 +; CHECK: %[[#I16_C2:]] = OpConstant %[[#I16]] 2 +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#I32_C2:]] = OpConstant %[[#I32]] 2 +; CHECK: %[[#I32_C6:]] = OpConstant %[[#I32]] 6 +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 +; CHECK: %[[#I16_C2_0:]] = OpConstant %[[#I16]] 2 +; CHECK: %[[#I16_C6:]] = OpConstant %[[#I16]] 6 +; CHECK: %[[#I16_C1:]] = OpConstant %[[#I16]] 1 +; CHECK: %[[#I16_C1_0:]] = OpConstant %[[#I16]] 1 + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +func @for1() { + %lb = constant 0 -> i16 + %ub = constant 10 -> i16 + %step = constant 2 -> i16 + for %0 = %lb,%ub,%step : i16 { + } +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: OpLoopMerge %[[#MERGE_LABEL1:]] %[[#CONT_LABEL1:]] None +; CHECK-NEXT: OpBranch %[[#HEAD_LABEL1:]] +; CHECK-NEXT: %[[#HEAD_LABEL1]] = OpLabel +; CHECK-NEXT: %[[#CONDITION1_1:]] = OpSLessThan %[[#BOOL]] %[[#I16_C0]] %[[#I16_C10]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION1_1]] %[[#BODY_LABEL1:]] %[[#MERGE_LABEL1]] +; CHECK-NEXT: %[[#BODY_LABEL1]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR1:]] = OpPhi %[[#I16]] %[[#I16_C0]] %[[#HEAD_LABEL1]] %[[#LOOP_VAR_UPDATE1:]] %[[CONT_LABEL1]] +; CHECK-NEXT: OpBranch %[[#CONT_LABEL1]] +; CHECK-NEXT: %[[#CONT_LABEL1]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR_UPDATE1]] = OpIAdd %[[#I16]] %[[#LOOP_VAR1]] %[[#I16_C2]] +; CHECK-NEXT: %[[#CONDITION1_2:]] = OpSLessThan %[[#BOOL]] %[[#LOOP_VAR_UPDATE1]] %[[#I16_C10]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION1_2]] %[[#BODY_LABEL1]] %[[#MERGE_LABEL1]] +; CHECK-NEXT: %[[#MERGE_LABEL1]] = OpLabel +} + +func @for2() { + %from = constant 2 -> i32 + %to = constant 6 -> i32 + %f0 = constant 0 -> i64 + %f1 = constant 1 -> i64 + %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn = arith.add %fn_2, %fn_1 : i64 + yield %fn_1, %fn : i64, i64 + } + %neg_fn = arith.neg %fn : i64 +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: OpLoopMerge %[[#MERGE_LABEL2:]] %[[#CONT_LABEL2:]] None +; CHECK-NEXT: OpBranch %[[#HEAD_LABEL2:]] +; CHECK-NEXT: %[[#HEAD_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#CONDITION2_1:]] = OpSLessThan %[[#BOOL]] %[[#I32_C2]] %[[#I32_C6]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION2_1]] %[[#BODY_LABEL2:]] %[[#MERGE_LABEL2]] +; CHECK-NEXT: %[[#BODY_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR2:]] = OpPhi %[[#I32]] %[[#I32_C2]] %[[#HEAD_LABEL2]] %[[#LOOP_VAR_UPDATE2:]] %[[CONT_LABEL2]] +; CHECK-NEXT: %[[#PHI2_1:]] = OpPhi %[[#I64]] %[[#I64_C0]] %[[#HEAD_LABEL2]] %[[#PHI2_2:]] %[[CONT_LABEL2]] +; CHECK-NEXT: %[[#PHI2_2]] = OpPhi %[[#I64]] %[[#I64_C1]] %[[#HEAD_LABEL2]] %[[#FN_UPDATE2:]] %[[CONT_LABEL2]] +; CHECK-NEXT: %[[#FN_UPDATE2]] = OpIAdd %[[#I64]] %[[#PHI2_1]] %[[#PHI2_2]] +; CHECK-NEXT: OpBranch %[[#CONT_LABEL2]] +; CHECK-NEXT: %[[#CONT_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR_UPDATE2]] = OpIAdd %[[#I32]] %[[#LOOP_VAR2]] %[[#I32_C1]] +; CHECK-NEXT: %[[#CONDITION2_2:]] = OpSLessThan %[[#BOOL]] %[[#LOOP_VAR_UPDATE2]] %[[#I32_C6]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION2_2]] %[[#BODY_LABEL2]] %[[#MERGE_LABEL2]] +; CHECK-NEXT: %[[#MERGE_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#]] = OpPhi %[[#I64]] %[[#I64_C0]] %[[#HEAD_LABEL2]] %[[#PHI2_2]] %[[CONT_LABEL2]] +; CHECK-NEXT: %[[#RESULT2_2:]] = OpPhi %[[#I64]] %[[#I64_C1]] %[[#HEAD_LABEL2]] %[[#FN_UPDATE2]] %[[CONT_LABEL2]] +; CHECK-NEXT: %[[#]] = OpSNegate %[[#I64]] %[[#RESULT2_2]] +} + +func @for3() subgroup_size(16) { + %from = constant 2 -> i16 + %to = constant 6 -> i16 + %m_init = constant 1 -> coopmatrix + %m = for %n=%from,%to init(%m_iter=%m_init) -> (coopmatrix) : i16 { + %m_update = arith.add %m_iter, %m_init : coopmatrix + yield %m_update : coopmatrix + } + %neg_m = arith.neg %m : coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: OpLoopMerge %[[#MERGE_LABEL3:]] %[[#CONT_LABEL3:]] None +; CHECK-NEXT: OpBranch %[[#HEAD_LABEL3:]] +; CHECK-NEXT: %[[#HEAD_LABEL3]] = OpLabel +; CHECK-NEXT: %[[#CONDITION3_1:]] = OpSLessThan %[[#BOOL]] %[[#I16_C2_0]] %[[#I16_C6]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION3_1]] %[[#BODY_LABEL3:]] %[[#MERGE_LABEL3]] +; CHECK-NEXT: %[[#BODY_LABEL3]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR3:]] = OpPhi %[[#I16]] %[[#I16_C2_0]] %[[#HEAD_LABEL3]] %[[#LOOP_VAR_UPDATE3:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#PHI3_1:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_1:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#PHI3_2:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_2:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#PHI3_3:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_3:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#PHI3_4:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_4:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#PHI_UPDATE3_1:]] = OpIAdd %[[#I16]] %[[#PHI3_1]] %[[#I16_C1]] +; CHECK-NEXT: %[[#PHI_UPDATE3_2:]] = OpIAdd %[[#I16]] %[[#PHI3_2]] %[[#I16_C1]] +; CHECK-NEXT: %[[#PHI_UPDATE3_3:]] = OpIAdd %[[#I16]] %[[#PHI3_3]] %[[#I16_C1]] +; CHECK-NEXT: %[[#PHI_UPDATE3_4:]] = OpIAdd %[[#I16]] %[[#PHI3_4]] %[[#I16_C1]] +; CHECK-NEXT: OpBranch %[[#CONT_LABEL3]] +; CHECK-NEXT: %[[#CONT_LABEL3]] = OpLabel +; CHECK-NEXT: %[[#LOOP_VAR_UPDATE3]] = OpIAdd %[[#I16]] %[[#LOOP_VAR3]] %[[#I16_C1_0]] +; CHECK-NEXT: %[[#CONDITION3_2:]] = OpSLessThan %[[#BOOL]] %[[#LOOP_VAR_UPDATE3]] %[[#I16_C6]] +; CHECK-NEXT: OpBranchConditional %[[#CONDITION3_2]] %[[#BODY_LABEL3]] %[[#MERGE_LABEL3]] +; CHECK-NEXT: %[[#MERGE_LABEL3]] = OpLabel +; CHECK-NEXT: %[[#RESULT3_1:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_1]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#RESULT3_2:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_2]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#RESULT3_3:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_3]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#RESULT3_4:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_4]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#]] = OpSNegate %[[#I16]] %[[#RESULT3_1]] +; CHECK-NEXT: %[[#]] = OpSNegate %[[#I16]] %[[#RESULT3_2]] +; CHECK-NEXT: %[[#]] = OpSNegate %[[#I16]] %[[#RESULT3_3]] +; CHECK-NEXT: %[[#]] = OpSNegate %[[#I16]] %[[#RESULT3_4]] +} diff --git a/test/spv/if.ir b/test/spv/if.ir new file mode 100644 index 00000000..18852109 --- /dev/null +++ b/test/spv/if.ir @@ -0,0 +1,108 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#BOOL_TRUE:]] = OpConstantTrue +; CHECK: %[[#I32_0:]] = OpConstant %[[#I32]] 0 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#CST1:]] = OpConstant %[[#F32]] 1 +; CHECK: %[[#CST0:]] = OpConstant %[[#F32]] 0 + +func @if0(%0: i32) { + %c42 = constant 42 -> i32 + %1 = cmp.lt %0, %c42 : i32 + if %1 { + %2 = arith.neg %0 : i32 + } else { + %3 = arith.not %0 : i32 + } +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#CONDITION0:]] = OpSLessThan %[[#]] %[[#]] %[[#]] +; CHECK-NEXT: OpSelectionMerge %[[#MERGE_LABEL0:]] None +; CHECK-NEXT: OpBranchConditional %[[#CONDITION0]] %[[#TRUE_LABEL0:]] %[[#FALSE_LABEL0:]] +; CHECK-NEXT: %[[#TRUE_LABEL0]] = OpLabel +; CHECK-NEXT: %[[#]] = OpSNegate %[[#]] %[[#]] +; CHECK-NEXT: OpBranch %[[#MERGE_LABEL0]] +; CHECK-NEXT: %[[#FALSE_LABEL0]] = OpLabel +; CHECK-NEXT: %[[#]] = OpNot %[[#]] %[[#]] +; CHECK-NEXT: OpBranch %[[#MERGE_LABEL0]] +; CHECK-NEXT: %[[#MERGE_LABEL0]] = OpLabel +} + +func @if1() { + %c1 = constant true -> bool + if %c1 -> (){ + yield : + } else { + yield : + } +; Just check that it does not crash +; CHECK: %[[#]] = OpFunction {{.*}} +} + +func @if2(%0: i32) { + %c1 = constant true -> bool + %x = if %c1 -> (i32) { + %1 = if %c1 -> (i32) { + yield %0 : i32 + } else { + %c0 = constant 0 -> i32 + yield %c0 : i32 + } + yield %1 : i32 + } else { + %1 = arith.not %0 : i32 + yield %1 : i32 + } + %y = arith.not %x : i32 +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#PARAM2:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#]] = OpLabel +; CHECK-NEXT: OpSelectionMerge %[[#MERGE_LABEL2:]] None +; CHECK-NEXT: OpBranchConditional %[[#BOOL_TRUE]] %[[#THEN_LABEL2:]] %[[#OTHER_LABEL2:]] +; CHECK-NEXT: %[[#THEN_LABEL2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge %[[#NESTED_MERGE_LABEL2:]] None +; CHECK-NEXT: OpBranchConditional %[[#BOOL_TRUE]] %[[#NESTED_THEN_LABEL2:]] %[[#NESTED_OTHER_LABEL2:]] +; CHECK-NEXT: %[[#NESTED_THEN_LABEL2]] = OpLabel +; CHECK-NEXT: OpBranch %[[#NESTED_MERGE_LABEL2]] +; CHECK-NEXT: %[[#NESTED_OTHER_LABEL2]] = OpLabel +; CHECK-NEXT: OpBranch %[[#NESTED_MERGE_LABEL2]] +; CHECK-NEXT: %[[#NESTED_MERGE_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#NESTED_PHI2:]] = OpPhi %[[#I32]] %[[#PARAM2]] %[[#NESTED_THEN_LABEL2]] %[[#I32_0]] %[[#NESTED_OTHER_LABEL2]] +; CHECK-NEXT: OpBranch %[[#MERGE_LABEL2]] +; CHECK-NEXT: %[[#OTHER_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#NOT_PARAM2:]] = OpNot %[[#I32]] %[[#PARAM2]] +; CHECK-NEXT: OpBranch %[[#MERGE_LABEL2]] +; CHECK-NEXT: %[[#MERGE_LABEL2]] = OpLabel +; CHECK-NEXT: %[[#PHI2:]] = OpPhi %[[#I32]] %[[#NESTED_PHI2]] %[[#NESTED_MERGE_LABEL2]] %[[#NOT_PARAM2]] %[[#OTHER_LABEL2]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I32]] %[[#PHI2]] +} + +func @if3() subgroup_size(16) { + %c1 = constant true -> bool + %y, %x = if %c1 -> (bool,coopmatrix) { + %0 = constant 1.0 -> coopmatrix + yield %c1, %0 : bool, coopmatrix + } else { + %1 = constant 0.0 -> coopmatrix + yield %c1, %1 : bool, coopmatrix + } + %z = arith.neg %x : coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpLabel +; CHECK: %[[#TRUE_LABEL3:]] = OpLabel +; CHECK: %[[#FALSE_LABEL3:]] = OpLabel +; CHECK: %[[#]] = OpLabel +; CHECK-NEXT: %[[#PHI3_0:]] = OpPhi %[[#BOOL]] %[[#BOOL_TRUE]] %[[#TRUE_LABEL3]] %[[#BOOL_TRUE]] %[[#FALSE_LABEL3]] +; CHECK-NEXT: %[[#PHI3_1:]] = OpPhi %[[#F32]] %[[#CST1]] %[[#TRUE_LABEL3]] %[[#CST0]] %[[#FALSE_LABEL3]] +; CHECK-NEXT: %[[#PHI3_2:]] = OpPhi %[[#F32]] %[[#CST1]] %[[#TRUE_LABEL3]] %[[#CST0]] %[[#FALSE_LABEL3]] +; CHECK-NEXT: %[[#PHI3_3:]] = OpPhi %[[#F32]] %[[#CST1]] %[[#TRUE_LABEL3]] %[[#CST0]] %[[#FALSE_LABEL3]] +; CHECK-NEXT: %[[#PHI3_4:]] = OpPhi %[[#F32]] %[[#CST1]] %[[#TRUE_LABEL3]] %[[#CST0]] %[[#FALSE_LABEL3]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#PHI3_1]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#PHI3_2]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#PHI3_3]] +; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#PHI3_4]] +} diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 2a38f3c0..5ccdb25c 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -207,6 +207,9 @@ def generate_op_classes(f, grammar): f.write(','.join(initializer_list)) f.write('{}') for o in operands: + print( + f'inline auto {o.name}() -> {o.kind}& {{ return {o.name}_; }}', + file=f) print( f'inline auto {o.name}() const -> {o.kind} const& {{ return {o.name}_; }}', file=f) From bb80a1374a850cfe1c072176c9d4fefb6f09dc73 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 8 Nov 2024 15:15:41 +0100 Subject: [PATCH 096/297] SPIR-V: Unique constants Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 43 ++++++++---------- src/spv/converter.hpp | 8 +++- src/spv/defs.hpp | 58 ++++++++++++++++++++++++ src/spv/instructions.hpp | 37 +--------------- src/spv/module.cpp | 2 +- src/spv/module.hpp | 1 - src/spv/pass/dump_asm.hpp | 1 + src/spv/uniquifier.cpp | 14 +++--- src/spv/uniquifier.hpp | 7 +-- test/spv/for.ir | 8 ++-- tools/spirvgen/spirvgen.py | 91 +++++++++++++++++++++----------------- 11 files changed, 151 insertions(+), 119 deletions(-) create mode 100644 src/spv/defs.hpp diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index a686ec30..79beedb6 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -16,12 +16,14 @@ #include "spv/opencl.std.hpp" #include "spv/uniquifier.hpp" #include "spv/visit.hpp" +#include "support/ilist.hpp" #include "support/ilist_base.hpp" #include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include @@ -131,13 +133,9 @@ auto inst_converter::multi_val(value_node const &v) -> std::vector & auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst * { - auto const add_constant = [this, &spv_ty](auto val) -> spv_inst * { - return mod_->add_to(section::type_const_var, spv_ty, val); - }; - auto const add_constant_complex = [this, &spv_ty](spv_inst *spv_float_ty, auto re, - auto im) -> spv_inst * { - auto c_re = mod_->add_to(section::type_const_var, spv_float_ty, re); - auto c_im = mod_->add_to(section::type_const_var, spv_float_ty, im); + auto const add_constant_complex = [this, &spv_ty](auto cst) -> spv_inst * { + auto c_re = unique_.constant(cst.real()); + auto c_im = unique_.constant(cst.imag()); return mod_->add_to(section::type_const_var, spv_ty, std::vector{c_re, c_im}); }; @@ -146,14 +144,14 @@ auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, [&](std::int64_t i) -> spv_inst * { switch (sty) { case scalar_type::i8: - return add_constant(static_cast(i)); + return unique_.constant(static_cast(i)); case scalar_type::i16: - return add_constant(static_cast(i)); + return unique_.constant(static_cast(i)); case scalar_type::i32: - return add_constant(static_cast(i)); + return unique_.constant(static_cast(i)); case scalar_type::i64: case scalar_type::index: - return add_constant(i); + return unique_.constant(i); default: return nullptr; } @@ -161,24 +159,19 @@ auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, [&](double d) -> spv_inst * { switch (sty) { case scalar_type::f32: - return add_constant(static_cast(d)); + return unique_.constant(static_cast(d)); case scalar_type::f64: - return add_constant(d); + return unique_.constant(d); default: return nullptr; } }, [&](std::complex d) -> spv_inst * { switch (sty) { - case scalar_type::c32: { - auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); - return add_constant_complex(spv_float_ty, static_cast(d.real()), - static_cast(d.imag())); - } - case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); - return add_constant_complex(spv_float_ty, d.real(), d.imag()); - } + case scalar_type::c32: + return add_constant_complex(static_cast>(d)); + case scalar_type::c64: + return add_constant_complex(d); default: return nullptr; } @@ -417,8 +410,8 @@ void inst_converter::operator()(barrier_inst const &in) { fence = fence | static_cast(MemorySemantics::WorkgroupMemory) | static_cast(MemorySemantics::SequentiallyConsistent); } - auto scope = unique_.i32_constant(static_cast(Scope::Workgroup)); - auto memory_semantics = unique_.i32_constant(fence); + auto scope = unique_.constant(static_cast(Scope::Workgroup)); + auto memory_semantics = unique_.constant(fence); mod_->add(scope, scope, memory_semantics); } @@ -832,7 +825,7 @@ void inst_converter::operator()(subgroup_size_inst const &in) { void inst_converter::operator()(work_group_inst const &in) { auto const make = [&](scalar_type sty, work_group_operation operation, spv_inst *spv_ty, spv_inst *operand) -> spv_inst * { - auto scope = unique_.i32_constant(static_cast(Scope::Workgroup)); + auto scope = unique_.constant(static_cast(Scope::Workgroup)); if (operation == work_group_operation::reduce_add) { switch (sty) { case scalar_type::i8: diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 4f23b200..8310ec13 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -3,13 +3,19 @@ #include "compiler_context.hpp" #include "device_info.hpp" +#include "node/data_type_node.hpp" #include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "spv/defs.hpp" +#include "spv/enums.hpp" #include "spv/module.hpp" #include "spv/uniquifier.hpp" #include "support/casting.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include @@ -17,8 +23,6 @@ namespace tinytc::spv { -class spv_inst; - auto convert_prog_to_spirv(tinytc_prog const &p, tinytc_core_info const &info) -> std::unique_ptr; diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp new file mode 100644 index 00000000..d01f558b --- /dev/null +++ b/src/spv/defs.hpp @@ -0,0 +1,58 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_DEFS_2024118_HPP +#define GENERATED_DEFS_2024118_HPP + +#include "enums.hpp" +#include "support/ilist_base.hpp" + +#include +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) + : opcode_{opcode}, has_result_id_{has_result_id} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + + inline auto opcode() const -> Op { return opcode_; } + inline auto has_result_id() const -> bool { return has_result_id_; } + + private: + Op opcode_; + bool has_result_id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber = + std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst *; +using IdRef = spv_inst *; +using IdScope = spv_inst *; +using IdMemorySemantics = spv_inst *; +using MemoryAccessAttr = std::int32_t; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef = + std::pair, spv_inst *>; +using PairIdRefLiteralInteger = std::pair; + +} // namespace tinytc::spv + +#endif // GENERATED_DEFS_2024118_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index 21db47a7..f9a9b2d8 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -7,6 +7,7 @@ #ifndef GENERATED_INSTRUCTIONS_2024118_HPP #define GENERATED_INSTRUCTIONS_2024118_HPP +#include "defs.hpp" #include "enums.hpp" #include "error.hpp" #include "support/ilist_base.hpp" @@ -21,42 +22,6 @@ namespace tinytc::spv { -class spv_inst : public ilist_node { - public: - inline spv_inst(Op opcode, bool has_result_id) - : opcode_{opcode}, has_result_id_{has_result_id} {} - virtual ~spv_inst() = default; - - spv_inst(spv_inst const &other) = delete; - spv_inst(spv_inst &&other) = delete; - spv_inst &operator=(spv_inst const &other) = delete; - spv_inst &operator=(spv_inst &&other) = delete; - - inline auto opcode() const -> Op { return opcode_; } - inline auto has_result_id() const -> bool { return has_result_id_; } - - private: - Op opcode_; - bool has_result_id_; -}; - -using DecorationAttr = std::variant>; -using ExecutionModeAttr = std::variant>; -using LiteralContextDependentNumber = - std::variant; -using LiteralString = std::string; -using LiteralInteger = std::int32_t; -using LiteralExtInstInteger = std::int32_t; -using IdResultType = spv_inst *; -using IdRef = spv_inst *; -using IdScope = spv_inst *; -using IdMemorySemantics = spv_inst *; -using MemoryAccessAttr = std::int32_t; -using PairIdRefIdRef = std::pair; -using PairLiteralIntegerIdRef = - std::pair, spv_inst *>; -using PairIdRefLiteralInteger = std::pair; - class OpNop : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Nop; } diff --git a/src/spv/module.cpp b/src/spv/module.cpp index 03bb9511..60a83361 100644 --- a/src/spv/module.cpp +++ b/src/spv/module.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "spv/module.hpp" -#include "spv/instructions.hpp" +#include "spv/defs.hpp" #include "support/ilist_base.hpp" namespace tinytc { diff --git a/src/spv/module.hpp b/src/spv/module.hpp index 1d061688..e7c2ab7a 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -7,7 +7,6 @@ #include "support/ilist.hpp" #include -#include #include #include #include diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp index f1082090..9b4de2f2 100644 --- a/src/spv/pass/dump_asm.hpp +++ b/src/spv/pass/dump_asm.hpp @@ -4,6 +4,7 @@ #ifndef DUMP_ASM_20241029_HPP #define DUMP_ASM_20241029_HPP +#include "spv/defs.hpp" #include "spv/instructions.hpp" #include "spv/names.hpp" #include "spv/visit.hpp" diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 96b144d6..51882ede 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include namespace tinytc::spv { @@ -113,11 +115,13 @@ void uniquifier::capability(Capability cap) { } } -auto uniquifier::i32_constant(std::int32_t cst) -> spv_inst * { - return lookup(i32_cst_, cst, [&](std::int32_t cst) { - auto i32_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); - return mod_->add_to(section::type_const_var, i32_ty, - LiteralContextDependentNumber{cst}); +auto uniquifier::constant(LiteralContextDependentNumber cst) -> spv_inst * { + return lookup(cst_map_, cst, [&](LiteralContextDependentNumber const &cst) { + scalar_type sty = std::visit( + overloaded{[](auto const &c) { return to_scalar_type_v>; }}, + cst); + auto ty = spv_ty(scalar_data_type::get(ctx_, sty)); + return mod_->add_to(section::type_const_var, ty, cst); }); } diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index 51c387b7..f65d634a 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -4,12 +4,14 @@ #ifndef UNIQUIFIER_20241107_HPP #define UNIQUIFIER_20241107_HPP +#include "spv/defs.hpp" #include "spv/enums.hpp" #include "spv/module.hpp" #include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include #include #include #include @@ -17,7 +19,6 @@ namespace tinytc::spv { -class spv_inst; class OpTypeFunction; class uniquifier { @@ -30,7 +31,7 @@ class uniquifier { auto builtin_pointee_ty(BuiltIn b) -> spv_inst *; auto builtin_var(BuiltIn b) -> spv_inst *; void capability(Capability cap); - auto i32_constant(std::int32_t cst) -> spv_inst *; + auto constant(LiteralContextDependentNumber cst) -> spv_inst *; auto index3_ty() -> spv_inst *; auto null_constant(spv_inst *spv_ty) -> spv_inst *; auto opencl_ext() -> spv_inst *; @@ -70,7 +71,7 @@ class uniquifier { spv_inst *opencl_ext_ = nullptr; std::unordered_map builtin_; std::unordered_set capabilities_; - std::unordered_map i32_cst_; + std::unordered_map cst_map_; std::unordered_map null_cst_; std::unordered_multimap spv_function_tys_; std::unordered_map, spv_inst *, pointer_key_hash> diff --git a/test/spv/for.ir b/test/spv/for.ir index a1aef989..6b803d28 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -13,10 +13,8 @@ ; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 ; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 ; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 -; CHECK: %[[#I16_C2_0:]] = OpConstant %[[#I16]] 2 ; CHECK: %[[#I16_C6:]] = OpConstant %[[#I16]] 6 ; CHECK: %[[#I16_C1:]] = OpConstant %[[#I16]] 1 -; CHECK: %[[#I16_C1_0:]] = OpConstant %[[#I16]] 1 ; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s func @for1() { @@ -86,10 +84,10 @@ func @for3() subgroup_size(16) { ; CHECK: OpLoopMerge %[[#MERGE_LABEL3:]] %[[#CONT_LABEL3:]] None ; CHECK-NEXT: OpBranch %[[#HEAD_LABEL3:]] ; CHECK-NEXT: %[[#HEAD_LABEL3]] = OpLabel -; CHECK-NEXT: %[[#CONDITION3_1:]] = OpSLessThan %[[#BOOL]] %[[#I16_C2_0]] %[[#I16_C6]] +; CHECK-NEXT: %[[#CONDITION3_1:]] = OpSLessThan %[[#BOOL]] %[[#I16_C2]] %[[#I16_C6]] ; CHECK-NEXT: OpBranchConditional %[[#CONDITION3_1]] %[[#BODY_LABEL3:]] %[[#MERGE_LABEL3]] ; CHECK-NEXT: %[[#BODY_LABEL3]] = OpLabel -; CHECK-NEXT: %[[#LOOP_VAR3:]] = OpPhi %[[#I16]] %[[#I16_C2_0]] %[[#HEAD_LABEL3]] %[[#LOOP_VAR_UPDATE3:]] %[[CONT_LABEL3]] +; CHECK-NEXT: %[[#LOOP_VAR3:]] = OpPhi %[[#I16]] %[[#I16_C2]] %[[#HEAD_LABEL3]] %[[#LOOP_VAR_UPDATE3:]] %[[CONT_LABEL3]] ; CHECK-NEXT: %[[#PHI3_1:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_1:]] %[[CONT_LABEL3]] ; CHECK-NEXT: %[[#PHI3_2:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_2:]] %[[CONT_LABEL3]] ; CHECK-NEXT: %[[#PHI3_3:]] = OpPhi %[[#I16]] %[[#I16_C1]] %[[#HEAD_LABEL3]] %[[#PHI_UPDATE3_3:]] %[[CONT_LABEL3]] @@ -100,7 +98,7 @@ func @for3() subgroup_size(16) { ; CHECK-NEXT: %[[#PHI_UPDATE3_4:]] = OpIAdd %[[#I16]] %[[#PHI3_4]] %[[#I16_C1]] ; CHECK-NEXT: OpBranch %[[#CONT_LABEL3]] ; CHECK-NEXT: %[[#CONT_LABEL3]] = OpLabel -; CHECK-NEXT: %[[#LOOP_VAR_UPDATE3]] = OpIAdd %[[#I16]] %[[#LOOP_VAR3]] %[[#I16_C1_0]] +; CHECK-NEXT: %[[#LOOP_VAR_UPDATE3]] = OpIAdd %[[#I16]] %[[#LOOP_VAR3]] %[[#I16_C1]] ; CHECK-NEXT: %[[#CONDITION3_2:]] = OpSLessThan %[[#BOOL]] %[[#LOOP_VAR_UPDATE3]] %[[#I16_C6]] ; CHECK-NEXT: OpBranchConditional %[[#CONDITION3_2]] %[[#BODY_LABEL3]] %[[#MERGE_LABEL3]] ; CHECK-NEXT: %[[#MERGE_LABEL3]] = OpLabel diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 5ccdb25c..42c6771b 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -18,11 +18,17 @@ spv_names_includes = [spv_enums] spv_names_cpp = 'names.cpp' spv_names_cpp_includes = [spv_names, spv_enums] +spv_defs = 'defs.hpp' +spv_defs_includes = [ + spv_enums, 'support/ilist_base.hpp', None, '', '', + '', '' +] spv_ops = 'instructions.hpp' spv_visitor = 'visit.hpp' spv_ops_includes = [ - spv_enums, 'error.hpp', 'support/ilist_base.hpp', None, '', - '', '', '', '', '', '' + spv_defs, spv_enums, 'error.hpp', 'support/ilist_base.hpp', None, + '', '', '', '', '', '', + '' ] enumerant_subs = { @@ -32,43 +38,6 @@ '2x2': 'CooperativeMatrixReduce2x2' } -spv_inst_class = """ -class spv_inst : public ilist_node { - public: - inline spv_inst(Op opcode, bool has_result_id) : opcode_{opcode}, has_result_id_{has_result_id} {} - virtual ~spv_inst() = default; - - spv_inst(spv_inst const &other) = delete; - spv_inst(spv_inst &&other) = delete; - spv_inst &operator=(spv_inst const &other) = delete; - spv_inst &operator=(spv_inst &&other) = delete; - - inline auto opcode() const -> Op { return opcode_; } - inline auto has_result_id() const -> bool { return has_result_id_; } - - private: - Op opcode_; - bool has_result_id_; -}; - -using DecorationAttr = std::variant>; -using ExecutionModeAttr = std::variant>; -using LiteralContextDependentNumber - = std::variant; -using LiteralString = std::string; -using LiteralInteger = std::int32_t; -using LiteralExtInstInteger = std::int32_t; -using IdResultType = spv_inst*; -using IdRef = spv_inst*; -using IdScope = spv_inst*; -using IdMemorySemantics = spv_inst*; -using MemoryAccessAttr = std::int32_t; -using PairIdRefIdRef = std::pair; -using PairLiteralIntegerIdRef - = std::pair, spv_inst*>; -using PairIdRefLiteralInteger = std::pair; -""" - def get_opcode_name(instruction): return instruction['opname'][2:] @@ -174,9 +143,47 @@ def get_operands(instruction): return operands -def generate_op_classes(f, grammar): - print(spv_inst_class, file=f) +def generate_defs(f, grammar): + print(""" +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) : opcode_{opcode}, has_result_id_{has_result_id} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + inline auto opcode() const -> Op { return opcode_; } + inline auto has_result_id() const -> bool { return has_result_id_; } + + private: + Op opcode_; + bool has_result_id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber + = std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst*; +using IdRef = spv_inst*; +using IdScope = spv_inst*; +using IdMemorySemantics = spv_inst*; +using MemoryAccessAttr = std::int32_t; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef + = std::pair, spv_inst*>; +using PairIdRefLiteralInteger = std::pair; +""", + file=f) + + +def generate_op_classes(f, grammar): for instruction in grammar['instructions']: operands = get_operands(instruction) @@ -326,6 +333,8 @@ def patch_grammar(grammar): spv_names_includes) generate_cpp(args, spv_names_cpp, grammar, generate_names_cpp, spv_names_cpp_includes) + generate_header(args, spv_defs, grammar, generate_defs, + spv_defs_includes) generate_header(args, spv_ops, grammar, generate_op_classes, spv_ops_includes) generate_header(args, spv_visitor, grammar, generate_visitor) From f73d0877ff9e2ff117c5b6a7e34493936ef75928 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 8 Nov 2024 18:04:14 +0100 Subject: [PATCH 097/297] SPIR-V: Calling convention and dope vector Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 137 ++++++++++++++++++++++++++++++++- src/spv/converter.hpp | 40 +++++++++- src/spv/defs.hpp | 2 +- src/spv/pass/dump_asm.cpp | 2 + src/spv/uniquifier.cpp | 37 +++++++-- src/spv/uniquifier.hpp | 14 ++-- test/CMakeLists.txt | 1 + test/spv/calling_convention.ir | 50 ++++++++++++ tools/spirvgen/spirvgen.py | 2 +- 9 files changed, 266 insertions(+), 19 deletions(-) create mode 100644 test/spv/calling_convention.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 79beedb6..ed1fa337 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -20,6 +20,7 @@ #include "support/ilist_base.hpp" #include "support/util.hpp" #include "support/visit.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -27,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +75,30 @@ auto convert_prog_to_spirv(tinytc_prog const &p, return m; } +dope_vector::dope_vector(spv_inst *ty, std::vector static_shape, + std::vector static_stride, spv_inst *offset_ty, + std::int64_t static_offset) + : ty_(ty), static_shape_(std::move(static_shape)), static_stride_(std::move(static_stride)), + shape_(dim(), nullptr), stride_(dim(), nullptr), offset_ty_(offset_ty), + static_offset_(static_offset) { + if (static_shape_.size() != static_stride_.size()) { + throw status::internal_compiler_error; + } +} + +auto dope_vector::num_dynamic() const -> std::int64_t { + auto const sum_dynamic = [](std::vector const &vec) { + std::int64_t num_dynamic = 0; + for (auto &v : vec) { + if (is_dynamic_value(v)) { + ++num_dynamic; + } + } + return num_dynamic; + }; + return sum_dynamic(static_shape_) + sum_dynamic(static_stride_); +} + inst_converter::inst_converter(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m), unique_(ctx, m) {} @@ -88,14 +114,21 @@ auto inst_converter::get_last_label() -> spv_inst * { return nullptr; } -auto inst_converter::get_scalar_type(value_node const &v) -> scalar_type { +auto inst_converter::get_dope_vector(tinytc_value const &v) -> dope_vector * { + if (auto it = dope_vec_.find(&v); it != dope_vec_.end()) { + return &it->second; + } + return nullptr; +} + +auto inst_converter::get_scalar_type(value_node const &v) const -> scalar_type { auto st = dyn_cast(v.ty()); if (!st) { throw compilation_error(v.loc(), status::ir_expected_scalar); } return st->ty(); } -auto inst_converter::get_coopmatrix_type(value_node const &v) -> scalar_type { +auto inst_converter::get_coopmatrix_type(value_node const &v) const -> scalar_type { auto ct = dyn_cast(v.ty()); if (!ct) { throw compilation_error(v.loc(), status::ir_expected_coopmatrix); @@ -178,7 +211,33 @@ auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, }, }; return std::visit(visitor, val); -}; +} + +auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { + if (dope_vec_.contains(&v)) { + throw compilation_error(v.loc(), status::internal_compiler_error); + } + + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + return ::tinytc::visit( + overloaded{[&](memref_data_type const &mr) -> dope_vector * { + return &(dope_vec_[&v] = dope_vector{spv_index_ty, mr.shape(), mr.stride()}); + }, + [&](group_data_type const &g) -> dope_vector * { + if (auto mt = dyn_cast(g.ty()); mt) { + auto pointer_ty = + unique_.spv_pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, + alignment(scalar_type::i64)); + return &(dope_vec_[&v] = + dope_vector{pointer_ty, mt->shape(), mt->stride(), + spv_index_ty, g.offset()}); + } else { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + }, + [](auto const &) -> dope_vector * { return nullptr; }}, + *v.ty()); +} void inst_converter::operator()(inst_node const &in) { // @todo @@ -806,6 +865,53 @@ void inst_converter::operator()(if_inst const &in) { } } +void inst_converter::operator()(load_inst const &in) { + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto spv_pointer_index_ty = unique_.spv_pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, + alignment(scalar_type::i64)); + auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); + auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + if (auto group_ty = dyn_cast(in.operand().ty()); group_ty) { + auto offset = mod_->add(spv_index_ty, dv->offset(), val(in.index_list()[0])); + auto pointer = mod_->add(spv_pointer_ty, val(in.operand()), + offset, std::vector{}); + declare(in.result(0), mod_->add(spv_result_ty, pointer)); + auto rdv = make_dope_vector(in.result(0)); + + auto const make_dope_par = [&](std::int64_t static_s, spv_inst *s) -> spv_inst * { + if (is_dynamic_value(static_s)) { + auto pointer = mod_->add(spv_pointer_index_ty, s, offset, + std::vector{}); + return mod_->add(spv_index_ty, pointer); + } + return s; + }; + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, make_dope_par(dv->static_shape(i), dv->shape(i))); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, make_dope_par(dv->static_stride(i), dv->stride(i))); + } + } else if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + auto offset = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + + auto pointer = mod_->add(spv_pointer_ty, val(in.operand()), + offset, std::vector{}); + declare(in.result(0), mod_->add(spv_result_ty, pointer)); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + } +} + void inst_converter::operator()(num_subgroups_inst const &in) { declare(in.result(0), load_builtin(BuiltIn::NumSubgroups)); } @@ -902,6 +1008,15 @@ void inst_converter::run_on_function(function_node const &fn, core_config const params.reserve(fn.num_params()); for (auto const &p : fn.params()) { params.emplace_back(unique_.spv_ty(p.ty())); + auto dv = make_dope_vector(p); + if (dv) { + for (std::int64_t i = 0; i < dv->num_dynamic(); ++i) { + params.emplace_back(dv->ty()); + } + if (is_dynamic_value(dv->static_offset())) { + params.emplace_back(dv->offset_ty()); + } + } } return params; }()); @@ -911,6 +1026,22 @@ void inst_converter::run_on_function(function_node const &fn, core_config const auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); for (auto const &p : fn.params()) { declare(p, mod_->add(unique_.spv_ty(p.ty()))); + auto dv = get_dope_vector(p); + if (dv) { + auto const make_dope_par = [&](spv_inst *ty, std::int64_t s) { + return is_dynamic_value(s) ? mod_->add(ty) + : unique_.constant(s); + }; + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->shape(i, make_dope_par(dv->ty(), dv->static_shape(i))); + } + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->stride(i, make_dope_par(dv->ty(), dv->static_stride(i))); + } + if (dv->offset_ty()) { + dv->offset(make_dope_par(dv->offset_ty(), dv->static_offset())); + } + } } mod_->add(); run_on_region(fn.body()); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 8310ec13..9be1ec3f 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -26,6 +26,38 @@ namespace tinytc::spv { auto convert_prog_to_spirv(tinytc_prog const &p, tinytc_core_info const &info) -> std::unique_ptr; +class dope_vector { + public: + dope_vector() = default; + dope_vector(spv_inst *ty, std::vector static_shape, + std::vector static_stride, spv_inst *offset_ty = nullptr, + std::int64_t static_offset = 0); + + inline auto dim() const -> std::int64_t { return static_shape_.size(); } + inline auto ty() const -> spv_inst * { return ty_; } + inline auto static_shape(std::int64_t i) const -> std::int64_t { return static_shape_[i]; } + inline auto static_stride(std::int64_t i) const -> std::int64_t { return static_stride_[i]; } + inline auto shape(std::int64_t i) const -> spv_inst * { return shape_[i]; } + inline auto stride(std::int64_t i) const -> spv_inst * { return stride_[i]; } + inline void shape(std::int64_t i, spv_inst *s) { shape_[i] = s; } + inline void stride(std::int64_t i, spv_inst *s) { stride_[i] = s; } + + inline auto offset_ty() const -> spv_inst * { return offset_ty_; } + inline auto static_offset() const -> std::int64_t { return static_offset_; } + inline auto offset() -> spv_inst * { return offset_; } + inline void offset(spv_inst *offset) { offset_ = offset; } + + auto num_dynamic() const -> std::int64_t; + + private: + spv_inst *ty_ = nullptr; + std::vector static_shape_, static_stride_; + std::vector shape_, stride_; + spv_inst *offset_ty_ = nullptr; + std::int64_t static_offset_; + spv_inst *offset_ = nullptr; +}; + class inst_converter { public: inst_converter(tinytc_compiler_context_t ctx, mod &m); @@ -42,6 +74,7 @@ class inst_converter { void operator()(group_id_inst const &in); void operator()(group_size_inst const &in); void operator()(if_inst const &in); + void operator()(load_inst const &in); void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); void operator()(subgroup_id_inst const &in); @@ -71,8 +104,9 @@ class inst_converter { return num_results; } auto get_last_label() -> spv_inst *; - auto get_scalar_type(tinytc_value const &v) -> scalar_type; - auto get_coopmatrix_type(tinytc_value const &v) -> scalar_type; + auto get_dope_vector(tinytc_value const &v) -> dope_vector *; + auto get_scalar_type(tinytc_value const &v) const -> scalar_type; + auto get_coopmatrix_type(tinytc_value const &v) const -> scalar_type; auto load_builtin(BuiltIn b) -> spv_inst *; auto declare(tinytc_value const &v, spv_inst *in); auto val(tinytc_value const &v) -> spv_inst *; @@ -80,10 +114,12 @@ class inst_converter { auto multi_val(tinytc_value const &v) -> std::vector &; auto make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst *; + auto make_dope_vector(tinytc_value const &v) -> dope_vector *; tinytc_compiler_context_t ctx_; mod *mod_; uniquifier unique_; + std::unordered_map dope_vec_; std::unordered_map vals_; std::unordered_map> multi_vals_; std::stack> yielded_vals_; diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp index d01f558b..b101e998 100644 --- a/src/spv/defs.hpp +++ b/src/spv/defs.hpp @@ -36,7 +36,7 @@ class spv_inst : public ilist_node { bool has_result_id_; }; -using DecorationAttr = std::variant>; +using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber = std::variant; diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 577ffa55..acabfaa7 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -113,6 +113,8 @@ void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << declare(in); } else if (isa(*in)) { *os_ << " %" << declare(in); + } else if (isa(*in)) { + *os_ << " %" << declare(in); } else { throw status::spirv_forbidden_forward_declaration; } diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 51882ede..83cd9c32 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -99,7 +100,8 @@ auto uniquifier::builtin_pointee_ty(BuiltIn b) -> spv_inst * { auto uniquifier::builtin_var(BuiltIn b) -> spv_inst * { return lookup(builtin_, b, [&](BuiltIn b) { - auto pointer_ty = spv_pointer_ty(StorageClass::Input, builtin_pointee_ty(b)); + auto pointer_ty = + spv_pointer_ty(StorageClass::Input, builtin_pointee_ty(b), builtin_alignment(b)); auto var = mod_->add_to(section::type_const_var, pointer_ty, StorageClass::Input, std::nullopt); mod_->add_to(section::decoration, var, Decoration::Constant); @@ -159,11 +161,17 @@ auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { ->second; } -auto uniquifier::spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty) -> spv_inst * { - auto key = std::make_pair(cls, pointee_ty); - return lookup(spv_pointer_tys_, key, [&](std::pair const &key) { - return mod_->add_to(section::type_const_var, key.first, key.second); - }); +auto uniquifier::spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty, + std::int32_t alignment) -> spv_inst * { + auto key = std::make_tuple(cls, pointee_ty, alignment); + return lookup( + spv_pointer_tys_, key, [&](std::tuple const &key) { + auto pointer_ty = mod_->add_to(section::type_const_var, std::get<0>(key), + std::get<1>(key)); + mod_->add_to(section::decoration, pointer_ty, Decoration::Alignment, + DecorationAttr{std::get<2>(key)}); + return pointer_ty; + }); } auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { @@ -176,6 +184,23 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { [&](boolean_data_type const &) -> spv_inst * { return mod_->add_to(section::type_const_var); }, + [&](group_data_type const &g) -> spv_inst * { + return spv_pointer_ty(StorageClass::CrossWorkgroup, spv_ty(g.ty()), + alignment(scalar_type::i64)); + }, + [&](memref_data_type const &mr) -> spv_inst * { + const auto storage_cls = mr.addrspace() == address_space::local + ? StorageClass::Workgroup + : StorageClass::CrossWorkgroup; + auto spv_element_ty = spv_ty(mr.element_data_ty()); + const std::int32_t align = [&](scalar_type sty) -> std::int32_t { + if (is_complex_type(sty)) { + return alignment(element_type(sty), component_count::v2); + } + return alignment(sty); + }(mr.element_ty()); + return spv_pointer_ty(storage_cls, spv_element_ty, align); + }, [&](scalar_data_type const &ty) -> spv_inst * { switch (ty.ty()) { case scalar_type::i8: diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index f65d634a..eca69dd6 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -13,9 +13,9 @@ #include #include +#include #include #include -#include namespace tinytc::spv { @@ -36,7 +36,8 @@ class uniquifier { auto null_constant(spv_inst *spv_ty) -> spv_inst *; auto opencl_ext() -> spv_inst *; auto spv_function_ty(array_view params) -> spv_inst *; - auto spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty) -> spv_inst *; + auto spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty, + std::int32_t alignment) -> spv_inst *; auto spv_ty(const_tinytc_data_type_t ty) -> spv_inst *; private: @@ -57,9 +58,9 @@ class uniquifier { } struct pointer_key_hash { - inline auto - operator()(std::pair const &key) const -> std::size_t { - return fnv1a_combine(key.first, key.second); + inline auto operator()(std::tuple const &key) const + -> std::size_t { + return fnv1a_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key)); } }; @@ -74,7 +75,8 @@ class uniquifier { std::unordered_map cst_map_; std::unordered_map null_cst_; std::unordered_multimap spv_function_tys_; - std::unordered_map, spv_inst *, pointer_key_hash> + std::unordered_map, spv_inst *, + pointer_key_hash> spv_pointer_tys_; std::unordered_map spv_tys_; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 97e0b177..18b0d670 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -56,6 +56,7 @@ if(SPIRVTools_FOUND) spv/barrier.ir spv/builtin.ir spv/cast.ir + spv/calling_convention.ir spv/compare.ir spv/for.ir spv/if.ir diff --git a/test/spv/calling_convention.ir b/test/spv/calling_convention.ir new file mode 100644 index 00000000..f70fd5d7 --- /dev/null +++ b/test/spv/calling_convention.ir @@ -0,0 +1,50 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpDecorate %[[#PTR_F32:]] Alignment 4 +; CHECK: OpDecorate %[[#PTR_I16:]] Alignment 2 +; CHECK: OpDecorate %[[#PTR_C32:]] Alignment 8 +; CHECK: OpDecorate %[[#PTR_PTR_C32:]] Alignment 8 +; CHECK: OpDecorate %[[#PTR_I64:]] Alignment 8 +; CHECK: OpDecorate %[[#PTR_C64:]] Alignment 16 +; CHECK: OpDecorate %[[#PTR_PTR_C64:]] Alignment 8 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#PTR_F32]] = OpTypePointer CrossWorkgroup %[[#F32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I16:]] = OpTypeInt 16 0 +; CHECK: %[[#PTR_I16]] = OpTypePointer CrossWorkgroup %[[#I16]] +; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 +; CHECK: %[[#PTR_C32]] = OpTypePointer CrossWorkgroup %[[#C32]] +; CHECK: %[[#PTR_PTR_C32]] = OpTypePointer CrossWorkgroup %[[#PTR_C32]] +; CHECK: %[[#PTR_I64]] = OpTypePointer CrossWorkgroup %[[#I64]] +; CHECK: %[[#F64:]] = OpTypeFloat 64 +; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 +; CHECK: %[[#PTR_C64]] = OpTypePointer CrossWorkgroup %[[#C64]] +; CHECK: %[[#PTR_PTR_C64]] = OpTypePointer CrossWorkgroup %[[#PTR_C64]] + +func @cc1(%0: memref, %1: memref>) { +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_F32]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_I16]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +} + +func @cc2(%0: group>) { +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_PTR_C32]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_I64]] +} + +func @cc3(%0: group, offset: ?>) { +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_PTR_C64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#PTR_I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +} diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 42c6771b..c2003da3 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -163,7 +163,7 @@ class spv_inst : public ilist_node { bool has_result_id_; }; -using DecorationAttr = std::variant>; +using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber = std::variant; From 0623096c85d10f619577229d98a6d74b369b888f Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 11 Nov 2024 12:25:30 +0100 Subject: [PATCH 098/297] SPIR-V: Store inst Signed-off-by: Carsten Uphoff --- include/tinytc/types.h | 5 +- include/tinytc/types.hpp | 1 + src/error.cpp | 2 + src/spv/converter.cpp | 158 ++++++++++++++++++++++++++++++++++--- src/spv/converter.hpp | 5 ++ src/spv/defs.hpp | 6 +- src/spv/enums.hpp | 9 ++- src/spv/instructions.hpp | 91 ++++++++++++++++++++- src/spv/module.hpp | 17 ++-- src/spv/names.cpp | 6 ++ src/spv/names.hpp | 6 +- src/spv/pass/dump_asm.cpp | 18 ++--- src/spv/uniquifier.cpp | 15 +++- src/spv/uniquifier.hpp | 4 + src/spv/visit.hpp | 36 ++++++++- test/CMakeLists.txt | 2 + test/spv/load.ir | 51 ++++++++++++ test/spv/store.ir | 79 +++++++++++++++++++ tools/spirvgen/filter.json | 4 +- tools/spirvgen/spirvgen.py | 6 ++ 20 files changed, 472 insertions(+), 49 deletions(-) create mode 100644 test/spv/load.ir create mode 100644 test/spv/store.ir diff --git a/include/tinytc/types.h b/include/tinytc/types.h index a67d2223..67f76cd5 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -88,8 +88,9 @@ typedef enum { tinytc_status_ir_constant_mismatch = 0x127, ///< Constant mismatch // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = - 0x1000, ///< Forward declaration of id is forbidden - tinytc_status_spirv_undefined_value = 0x1001, ///< Undefined value + 0x1000, ///< Forward declaration of id is forbidden + tinytc_status_spirv_undefined_value = 0x1001, ///< Undefined value + tinytc_status_spirv_missing_dope_vector = 0x1002, ///< Missing dope vector // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index f03f1ec0..49a55508 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -96,6 +96,7 @@ enum class status { ir_constant_mismatch = tinytc_status_ir_constant_mismatch, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, + spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, diff --git a/src/error.cpp b/src/error.cpp index e9535fd7..9df6d52f 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -207,6 +207,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Forward declaration of id is forbidden"; case tinytc_status_spirv_undefined_value: return "Undefined SPIR-V value"; + case tinytc_status_spirv_missing_dope_vector: + return "Dope vector missing (internal compiler error)"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index ed1fa337..781f6454 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -59,16 +59,28 @@ auto convert_prog_to_spirv(tinytc_prog const &p, } } - // Add missing capabilites + // Add missing capabilites and extensions for (std::int32_t s = 0; s < num_module_sections; ++s) { for (auto const &i : m->insts(enum_cast
(s))) { - visit(overloaded{[&](I const &) { + visit(overloaded{[&](I const &in) { + if (isa(static_cast(in))) { + // We manage OpAtomicFAddExt manually as the required + // capabilitites depend on the data type + return; + } for (auto const &cap : I::required_capabilities) { conv.unique().capability(cap); } }, [&](auto const &) {}}, i); + visit(overloaded{[&](I const &) { + for (auto const &ext_name : I::required_extensions) { + conv.unique().extension(ext_name); + } + }, + [&](auto const &) {}}, + i); } } @@ -239,6 +251,95 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { *v.ty()); } +void inst_converter::make_store(store_flag flag, scalar_type sty, address_space as, + spv_inst *pointer, spv_inst *value) { + auto const add_fadd_caps = [&] { + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + break; + case scalar_type::f32: + case scalar_type::c32: + unique_.capability(Capability::AtomicFloat32AddEXT); + break; + case scalar_type::f64: + case scalar_type::c64: + unique_.capability(Capability::AtomicFloat64AddEXT); + break; + } + }; + auto const split_re_im = [&]() -> std::array, 2u> { + auto component_sty = element_type(sty); + auto float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, component_sty)); + const auto storage_cls = address_space_to_storage_class(as); + auto pointer_ty = unique_.spv_pointer_ty(storage_cls, float_ty, alignment(component_sty)); + auto c0 = unique_.constant(std::int32_t{0}); + auto c1 = unique_.constant(std::int32_t{1}); + auto re_ptr = mod_->add(pointer_ty, pointer, std::vector{c0}); + auto im_ptr = mod_->add(pointer_ty, pointer, std::vector{c1}); + auto re_val = + mod_->add(float_ty, value, std::vector{0}); + auto im_val = + mod_->add(float_ty, value, std::vector{1}); + return {{{re_ptr, re_val}, {im_ptr, im_val}}}; + }; + switch (flag) { + case store_flag::regular: + mod_->add(pointer, value); + break; + case store_flag::atomic: { + auto scope = unique_.constant(static_cast(Scope::Workgroup)); + auto semantics = unique_.constant(static_cast(MemorySemantics::Relaxed)); + switch (sty) { + case scalar_type::c32: + case scalar_type::c64: { + auto re_im = split_re_im(); + mod_->add(re_im[0][0], scope, semantics, re_im[0][1]); + mod_->add(re_im[1][0], scope, semantics, re_im[1][1]); + break; + } + default: + mod_->add(pointer, scope, semantics, value); + break; + } + break; + } + case store_flag::atomic_add: { + auto result_ty = unique_.spv_ty(scalar_data_type::get(ctx_, sty)); + auto scope = unique_.constant(static_cast(Scope::Workgroup)); + auto semantics = unique_.constant(static_cast(MemorySemantics::Relaxed)); + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + mod_->add(result_ty, pointer, scope, semantics, value); + break; + case scalar_type::f32: + case scalar_type::f64: + add_fadd_caps(); + mod_->add(result_ty, pointer, scope, semantics, value); + break; + case scalar_type::c32: + case scalar_type::c64: { + add_fadd_caps(); + auto re_im = split_re_im(); + auto component_sty = element_type(sty); + auto float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, component_sty)); + mod_->add(float_ty, re_im[0][0], scope, semantics, re_im[0][1]); + mod_->add(float_ty, re_im[1][0], scope, semantics, re_im[1][1]); + break; + } + } + break; + } break; + } +} + void inst_converter::operator()(inst_node const &in) { // @todo throw compilation_error(in.loc(), status::not_implemented); @@ -873,7 +974,7 @@ void inst_converter::operator()(load_inst const &in) { auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); auto dv = get_dope_vector(in.operand()); if (!dv) { - throw compilation_error(in.loc(), status::internal_compiler_error); + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); } if (auto group_ty = dyn_cast(in.operand().ty()); group_ty) { @@ -898,15 +999,20 @@ void inst_converter::operator()(load_inst const &in) { rdv->stride(i, make_dope_par(dv->static_stride(i), dv->stride(i))); } } else if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { - auto offset = unique_.null_constant(spv_index_ty); - for (std::int64_t i = 0; i < memref_ty->dim(); ++i) { - auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); - offset = mod_->add(spv_index_ty, offset, tmp); - } + const auto pointer = [&]() -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } - auto pointer = mod_->add(spv_pointer_ty, val(in.operand()), - offset, std::vector{}); - declare(in.result(0), mod_->add(spv_result_ty, pointer)); + auto offset = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + declare(in.result(0), mod_->add(spv_result_ty, pointer())); } else { throw compilation_error(in.loc(), status::ir_expected_memref_or_group); } @@ -918,6 +1024,36 @@ void inst_converter::operator()(num_subgroups_inst const &in) { void inst_converter::operator()(parallel_inst const &in) { run_on_region(in.body()); } +void inst_converter::operator()(store_inst const &in) { + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + const auto pointer = [&]() -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } + + auto offset = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + + make_store(in.flag(), memref_ty->element_ty(), memref_ty->addrspace(), pointer(), + val(in.val())); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref); + } +} + void inst_converter::operator()(subgroup_id_inst const &in) { declare(in.result(0), load_builtin(BuiltIn::SubgroupId)); } diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 9be1ec3f..c5fc354c 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -77,6 +77,7 @@ class inst_converter { void operator()(load_inst const &in); void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); + void operator()(store_inst const &in); void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); void operator()(subgroup_size_inst const &in); @@ -115,6 +116,8 @@ class inst_converter { auto make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst *; auto make_dope_vector(tinytc_value const &v) -> dope_vector *; + void make_store(store_flag flag, scalar_type sty, address_space as, spv_inst *pointer, + spv_inst *value); tinytc_compiler_context_t ctx_; mod *mod_; @@ -129,5 +132,7 @@ class inst_converter { template concept spv_inst_with_required_capabilities = requires() { T::required_capabilities; }; +template +concept spv_inst_with_required_extensions = requires() { T::required_extensions; }; } // namespace tinytc::spv diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp index b101e998..feba658f 100644 --- a/src/spv/defs.hpp +++ b/src/spv/defs.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_DEFS_2024118_HPP -#define GENERATED_DEFS_2024118_HPP +#ifndef GENERATED_DEFS_20241111_HPP +#define GENERATED_DEFS_20241111_HPP #include "enums.hpp" #include "support/ilist_base.hpp" @@ -55,4 +55,4 @@ using PairIdRefLiteralInteger = std::pair; } // namespace tinytc::spv -#endif // GENERATED_DEFS_2024118_HPP +#endif // GENERATED_DEFS_20241111_HPP diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 2ee1718b..98df5440 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_2024118_HPP -#define GENERATED_ENUMS_2024118_HPP +#ifndef GENERATED_ENUMS_20241111_HPP +#define GENERATED_ENUMS_20241111_HPP namespace tinytc::spv { @@ -354,6 +354,9 @@ enum class Op { CooperativeMatrixStoreKHR = 4458, CooperativeMatrixMulAddKHR = 4459, CooperativeMatrixLengthKHR = 4460, + AtomicFMinEXT = 5614, + AtomicFMaxEXT = 5615, + AtomicFAddEXT = 6035, }; enum class ImageOperands { None = 0x0000, @@ -1422,4 +1425,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_2024118_HPP +#endif // GENERATED_ENUMS_20241111_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index f9a9b2d8..c9f84ec9 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_2024118_HPP -#define GENERATED_INSTRUCTIONS_2024118_HPP +#ifndef GENERATED_INSTRUCTIONS_20241111_HPP +#define GENERATED_INSTRUCTIONS_20241111_HPP #include "defs.hpp" #include "enums.hpp" @@ -5544,6 +5544,8 @@ class OpExecutionModeId : public spv_inst { class OpDecorateId : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorateId; } + constexpr static std::array required_extensions = { + "SPV_GOOGLE_hlsl_functionality1"}; OpDecorateId(IdRef op0, Decoration op1) : spv_inst{Op::DecorateId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} inline auto op0() -> IdRef & { return op0_; } @@ -6645,7 +6647,90 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { IdResultType type_; IdRef op0_; }; +class OpAtomicFMinEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMinEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMinEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMinEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFMaxEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMaxEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMaxEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMaxEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFAddEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFAddEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16AddEXT, Capability::AtomicFloat32AddEXT, + Capability::AtomicFloat64AddEXT, Capability::AtomicFloat16VectorNV}; + constexpr static std::array required_extensions = { + "SPV_EXT_shader_atomic_float_add"}; + OpAtomicFAddEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFAddEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_2024118_HPP +#endif // GENERATED_INSTRUCTIONS_20241111_HPP diff --git a/src/spv/module.hpp b/src/spv/module.hpp index e7c2ab7a..f4917c3f 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -26,15 +26,16 @@ namespace spv { enum class section { capability = 0, - ext_inst = 1, - memory_model = 2, - entry_point = 3, - execution_mode = 4, - decoration = 5, - type_const_var = 6, - function = 7 + extension = 1, + ext_inst = 2, + memory_model = 3, + entry_point = 4, + execution_mode = 5, + decoration = 6, + type_const_var = 7, + function = 8 }; -inline constexpr std::int32_t num_module_sections = 8; +inline constexpr std::int32_t num_module_sections = 9; class mod final { public: diff --git a/src/spv/names.cpp b/src/spv/names.cpp index ad48e41e..77e06035 100644 --- a/src/spv/names.cpp +++ b/src/spv/names.cpp @@ -699,6 +699,12 @@ auto to_string(Op op) -> char const * { return "CooperativeMatrixMulAddKHR"; case Op::CooperativeMatrixLengthKHR: return "CooperativeMatrixLengthKHR"; + case Op::AtomicFMinEXT: + return "AtomicFMinEXT"; + case Op::AtomicFMaxEXT: + return "AtomicFMaxEXT"; + case Op::AtomicFAddEXT: + return "AtomicFAddEXT"; } return "unknown"; } diff --git a/src/spv/names.hpp b/src/spv/names.hpp index 856ecfe5..3fc09b91 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_2024118_HPP -#define GENERATED_NAMES_2024118_HPP +#ifndef GENERATED_NAMES_20241111_HPP +#define GENERATED_NAMES_20241111_HPP #include "enums.hpp" @@ -68,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_2024118_HPP +#endif // GENERATED_NAMES_20241111_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index acabfaa7..f837117f 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -8,6 +8,7 @@ #include "support/casting.hpp" #include "support/ilist.hpp" #include "support/ilist_base.hpp" +#include "support/util.hpp" #include "tinytc/types.hpp" #include @@ -75,7 +76,11 @@ void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { ea); } void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { - std::visit(overloaded{[&](std::signed_integral auto const &l) { + std::visit(overloaded{[&](std::int8_t const &l) { + *os_ << " " + << static_cast(static_cast(l)); + }, + [&](std::signed_integral auto const &l) { using unsigned_t = std::make_unsigned_t>; *os_ << " " << static_cast(l); }, @@ -157,14 +162,9 @@ void dump_asm_pass::run_on_module(mod const &m) { << "; Generator: Tiny Tensor Compiler" << std::endl << "; Bound: " << m.bound() << std::endl << "; Schema: 0"; - visit_section(section::capability); - visit_section(section::ext_inst); - visit_section(section::memory_model); - visit_section(section::entry_point); - visit_section(section::execution_mode); - visit_section(section::decoration); - visit_section(section::type_const_var); - visit_section(section::function); + for (std::int32_t s = 0; s < num_module_sections; ++s) { + visit_section(enum_cast
(s)); + } *os_ << std::endl; } diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 83cd9c32..6d39c633 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -20,6 +20,10 @@ namespace tinytc::spv { +auto address_space_to_storage_class(address_space as) -> StorageClass { + return as == address_space::local ? StorageClass::Workgroup : StorageClass::CrossWorkgroup; +} + uniquifier::uniquifier(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m) {} auto uniquifier::bool2_ty() -> spv_inst * { @@ -127,6 +131,13 @@ auto uniquifier::constant(LiteralContextDependentNumber cst) -> spv_inst * { }); } +void uniquifier::extension(char const *ext_name) { + if (!extensions_.contains(ext_name)) { + mod_->add_to(section::extension, ext_name); + extensions_.insert(ext_name); + } +} + auto uniquifier::index3_ty() -> spv_inst * { return lookup(index3_ty_, [&] { auto index_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); @@ -189,9 +200,7 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { alignment(scalar_type::i64)); }, [&](memref_data_type const &mr) -> spv_inst * { - const auto storage_cls = mr.addrspace() == address_space::local - ? StorageClass::Workgroup - : StorageClass::CrossWorkgroup; + const auto storage_cls = address_space_to_storage_class(mr.addrspace()); auto spv_element_ty = spv_ty(mr.element_data_ty()); const std::int32_t align = [&](scalar_type sty) -> std::int32_t { if (is_complex_type(sty)) { diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index eca69dd6..88e73c31 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -21,6 +21,8 @@ namespace tinytc::spv { class OpTypeFunction; +auto address_space_to_storage_class(address_space as) -> StorageClass; + class uniquifier { public: uniquifier(tinytc_compiler_context_t ctx, mod &m); @@ -32,6 +34,7 @@ class uniquifier { auto builtin_var(BuiltIn b) -> spv_inst *; void capability(Capability cap); auto constant(LiteralContextDependentNumber cst) -> spv_inst *; + void extension(char const *ext_name); auto index3_ty() -> spv_inst *; auto null_constant(spv_inst *spv_ty) -> spv_inst *; auto opencl_ext() -> spv_inst *; @@ -73,6 +76,7 @@ class uniquifier { std::unordered_map builtin_; std::unordered_set capabilities_; std::unordered_map cst_map_; + std::unordered_set extensions_; std::unordered_map null_cst_; std::unordered_multimap spv_function_tys_; std::unordered_map, spv_inst *, diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index d9fa0726..10e14694 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_2024118_HPP -#define GENERATED_VISIT_2024118_HPP +#ifndef GENERATED_VISIT_20241111_HPP +#define GENERATED_VISIT_20241111_HPP namespace tinytc::spv { @@ -704,6 +704,12 @@ template auto visit(Visitor &&visitor, spv_inst const &inst) return visitor(static_cast(inst)); case Op::CooperativeMatrixLengthKHR: return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); } throw internal_compiler_error(); } @@ -2934,8 +2940,32 @@ template class default_visitor { static_cast(this)->operator()(in.type()); static_cast(this)->operator()(in.op0()); } + auto operator()(OpAtomicFMinEXT const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicFMaxEXT const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } + auto operator()(OpAtomicFAddEXT const &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + } }; } // namespace tinytc::spv -#endif // GENERATED_VISIT_2024118_HPP +#endif // GENERATED_VISIT_20241111_HPP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 18b0d670..64b600c1 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,6 +60,8 @@ if(SPIRVTools_FOUND) spv/compare.ir spv/for.ir spv/if.ir + spv/load.ir + spv/store.ir spv/work_group.ir spv/unique_function_type.ir ) diff --git a/test/spv/load.ir b/test/spv/load.ir new file mode 100644 index 00000000..56508478 --- /dev/null +++ b/test/spv/load.ir @@ -0,0 +1,51 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#F32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#PTR_I64:]] = OpTypePointer CrossWorkgroup %[[#I64]] +; CHECK: %[[#INDEX_NULL:]] = OpConstantNull %[[#I64]] +; CHECK: %[[#PTR_PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#PTR_F32]] + +func @l1(%0: memref) { + %2 = constant 0 -> index + %3 = load %0[%2,%2] : memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#L1_MR:]] = OpFunctionParameter %[[#PTR_F32]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#L1_STRIDE:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#L1_TMP1:]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#I64_C1]] +; CHECK-NEXT: %[[#L1_OFFSET1:]] = OpIAdd %[[#I64]] %[[#INDEX_NULL]] %[[#L1_TMP1]] +; CHECK-NEXT: %[[#L1_TMP2:]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#L1_STRIDE]] +; CHECK-NEXT: %[[#L1_OFFSET2:]] = OpIAdd %[[#I64]] %[[#L1_OFFSET1]] %[[#L1_TMP2]] +; CHECK-NEXT: %[[#L1_MRSUB:]] = OpInBoundsPtrAccessChain %[[#PTR_F32]] %[[#L1_MR]] %[[#L1_OFFSET2]] +; CHECK-NEXT: %[[#]] = OpLoad %[[#F32]] %[[#L1_MRSUB]] +} + +func @l2(%0: group>, offset: ?>) { + %1 = constant 0 -> index + %2 = load %0[%1] : group>, offset: ?> + %3 = load %2[%1] : memref> +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#L2_GROUP:]] = OpFunctionParameter %[[#PTR_PTR_F32]] +; CHECK-NEXT: %[[#L2_PTR_SHAPE:]] = OpFunctionParameter %[[#PTR_I64]] +; CHECK-NEXT: %[[#L2_PTR_STRIDE:]] = OpFunctionParameter %[[#PTR_I64]] +; CHECK-NEXT: %[[#L2_GROUP_OFFSET:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#L2_OFFSET1:]] = OpIAdd %[[#I64]] %[[#L2_GROUP_OFFSET]] %[[#I64_C0]] +; CHECK-NEXT: %[[#L2_MR:]] = OpInBoundsPtrAccessChain %[[#PTR_PTR_F32]] %[[#L2_GROUP]] %[[#L2_OFFSET1]] +; CHECK-NEXT: %[[#L2_MR2:]] = OpLoad %[[#PTR_F32]] %[[#L2_MR]] +; CHECK-NEXT: %[[#L2_SUBPTR_SHAPE:]] = OpInBoundsPtrAccessChain %[[#PTR_I64]] %[[#L2_PTR_SHAPE]] %[[#L2_OFFSET1]] +; CHECK-NEXT: %[[#]] = OpLoad %[[#I64]] %[[#L2_SUBPTR_SHAPE]] +; CHECK-NEXT: %[[#L2_SUBPTR_STRIDE:]] = OpInBoundsPtrAccessChain %[[#PTR_I64]] %[[#L2_PTR_STRIDE]] %[[#L2_OFFSET1]] +; CHECK-NEXT: %[[#L2_STRIDE:]] = OpLoad %[[#I64]] %[[#L2_SUBPTR_STRIDE]] +; CHECK-NEXT: %[[#L2_TMP:]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#L2_STRIDE]] +; CHECK-NEXT: %[[#L2_OFFSET2:]] = OpIAdd %[[#I64]] %[[#INDEX_NULL]] %[[L2_TMP]] +; CHECK-NEXT: %[[#L2_SUBMR:]] = OpInBoundsPtrAccessChain %[[#PTR_F32]] %[[#L2_MR2]] %[[#L2_OFFSET2]] +; CHECK-NEXT: %[[#]] = OpLoad %[[#F32]] %[[#L2_SUBMR]] +} diff --git a/test/spv/store.ir b/test/spv/store.ir new file mode 100644 index 00000000..5168140e --- /dev/null +++ b/test/spv/store.ir @@ -0,0 +1,79 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: OpCapability AtomicFloat32AddEXT +; CHECK: OpCapability AtomicFloat64AddEXT +; CHECK: OpExtension "SPV_EXT_shader_atomic_float_add" +; CHECK: %[[#I8:]] = OpTypeInt 8 0 +; CHECK: %[[#PTR_I8:]] = OpTypePointer CrossWorkgroup %[[#I8]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#I8_C214:]] = OpConstant %[[#I8]] 214 +; CHECK: %[[#INDEX_NULL:]] = OpConstantNull %[[#I64]] +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#I32_C2:]] = OpConstant %[[#I32]] 2 +; CHECK: %[[#I32_C0:]] = OpConstant %[[#I32]] 0 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#F32]] +; CHECK: %[[#F32_C42:]] = OpConstant %[[#F32]] 42 +; CHECK: %[[#F64:]] = OpTypeFloat 64 +; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 +; CHECK: %[[#PTR_C64:]] = OpTypePointer CrossWorkgroup %[[#C64]] +; CHECK: %[[#F64_C42:]] = OpConstant %[[#F64]] 42 +; CHECK: %[[#F64_C1:]] = OpConstant %[[#F64]] 1 +; CHECK: %[[#C64_C42_1:]] = OpConstantComposite %[[#C64]] %[[#F64_C42]] %[[#F64_C1]] +; CHECK: %[[#PTR_F64:]] = OpTypePointer CrossWorkgroup %[[#F64]] +; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 + +func @si8(%0: memref, %1: memref) { + %2 = constant 0 -> index + %3 = constant -42 -> i8 + store %3, %0[%2,%2] : memref + store.atomic %3, %1[] : memref + store.atomic_add %3, %1[] : memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#SI8_MR1:]] = OpFunctionParameter %[[#PTR_I8]] +; CHECK-NEXT: %[[#SI8_MR2:]] = OpFunctionParameter %[[#PTR_I8]] +; CHECK: %[[#SI8_TMP1:]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#I64_C1]] +; CHECK-NEXT: %[[#SI8_OFFSET1:]] = OpIAdd %[[#I64]] %[[#INDEX_NULL]] %[[#SI8_TMP1]] +; CHECK-NEXT: %[[#SI8_TMP2:]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#I64_C4]] +; CHECK-NEXT: %[[#SI8_OFFSET2:]] = OpIAdd %[[#I64]] %[[#SI8_OFFSET1]] %[[#SI8_TMP2]] +; CHECK-NEXT: %[[#SI8_MR1SUB:]] = OpInBoundsPtrAccessChain %[[#PTR_I8]] %[[#SI8_MR1]] %[[#SI8_OFFSET2]] +; CHECK-NEXT: OpStore %[[#SI8_MR1SUB]] %[[#I8_C214]] +; CHECK: OpAtomicStore %[[#SI8_MR2]] %[[#I32_C2]] %[[#I32_C0]] %[[#I8_C214]] +; CHECK: %[[#]] = OpAtomicIAdd %[[#I8]] %[[#SI8_MR2]] %[[#I32_C2]] %[[#I32_C0]] %[[#I8_C214]] +} + +func @sf32(%0: memref) { + %1 = constant 42.0 -> f32 + store.atomic %1, %0[] : memref + store.atomic_add %1, %0[] : memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#SF32_MR:]] = OpFunctionParameter %[[#PTR_F32]] +; CHECK: OpAtomicStore %[[#SF32_MR]] %[[#I32_C2]] %[[#I32_C0]] %[[#F32_C42]] +; CHECK: %[[#]] = OpAtomicFAddEXT %[[#F32]] %[[#SF32_MR]] %[[#I32_C2]] %[[#I32_C0]] %[[#F32_C42]] +} + +func @sc64(%0: memref) { + %1 = constant [42.0, 1.0] -> c64 + store.atomic %1, %0[] : memref + store.atomic_add %1, %0[] : memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#SC64_MR:]] = OpFunctionParameter %[[#PTR_C64]] +; CHECK: %[[#SC64_RE_PTR1:]] = OpInBoundsAccessChain %[[#PTR_F64]] %[[#SC64_MR]] %[[#I32_C0]] +; CHECK-NEXT: %[[#SC64_IM_PTR1:]] = OpInBoundsAccessChain %[[#PTR_F64]] %[[#SC64_MR]] %[[#I32_C1]] +; CHECK-NEXT: %[[#SC64_RE_VAL1:]] = OpCompositeExtract %[[#F64]] %[[#C64_C42_1]] 0 +; CHECK-NEXT: %[[#SC64_IM_VAL1:]] = OpCompositeExtract %[[#F64]] %[[#C64_C42_1]] 1 +; CHECK-NEXT: OpAtomicStore %[[#SC64_RE_PTR1]] %[[#I32_C2]] %[[#I32_C0]] %[[#SC64_RE_VAL1]] +; CHECK-NEXT: OpAtomicStore %[[#SC64_IM_PTR1]] %[[#I32_C2]] %[[#I32_C0]] %[[#SC64_IM_VAL1]] +; CHECK: %[[#SC64_RE_PTR2:]] = OpInBoundsAccessChain %[[#PTR_F64]] %[[#SC64_MR]] %[[#I32_C0]] +; CHECK-NEXT: %[[#SC64_IM_PTR2:]] = OpInBoundsAccessChain %[[#PTR_F64]] %[[#SC64_MR]] %[[#I32_C1]] +; CHECK-NEXT: %[[#SC64_RE_VAL2:]] = OpCompositeExtract %[[#F64]] %[[#C64_C42_1]] 0 +; CHECK-NEXT: %[[#SC64_IM_VAL2:]] = OpCompositeExtract %[[#F64]] %[[#C64_C42_1]] 1 +; CHECK-NEXT: %[[#]] = OpAtomicFAddEXT %[[#F64]] %[[#SC64_RE_PTR2]] %[[#I32_C2]] %[[#I32_C0]] %[[#SC64_RE_VAL2]] +; CHECK-NEXT: %[[#]] = OpAtomicFAddEXT %[[#F64]] %[[#SC64_IM_PTR2]] %[[#I32_C2]] %[[#I32_C0]] %[[#SC64_IM_VAL2]] +} diff --git a/tools/spirvgen/filter.json b/tools/spirvgen/filter.json index 78f52aef..38b632ae 100644 --- a/tools/spirvgen/filter.json +++ b/tools/spirvgen/filter.json @@ -6,6 +6,8 @@ "include" : [ [0, 47], [53, 999], - [4456, 4460] + [4456, 4460], + [5614, 5615], + [6035, 6035] ] } diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index c2003da3..e512d5fb 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -199,6 +199,12 @@ def generate_op_classes(f, grammar): print( f'constexpr static std::array required_capabilities = {{{cap_str}}};', file=f) + if 'extensions' in instruction: + exts = instruction['extensions'] + ext_str = ','.join([f'\"{ext}\"' for ext in exts]) + print( + f'constexpr static std::array required_extensions = {{{ext_str}}};', + file=f) f.write(f'{get_class_name(instruction)}(') f.write(','.join([ f'{o.kind} {o.name}{f" = {o.init}" if o.init else ""}' From b7343898830783a5af065fed309cccdb787706eb Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 11 Nov 2024 12:35:50 +0100 Subject: [PATCH 099/297] SPIR-V: size inst Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 8 ++++++++ src/spv/converter.hpp | 1 + test/CMakeLists.txt | 1 + test/spv/size.ir | 27 +++++++++++++++++++++++++++ 4 files changed, 37 insertions(+) create mode 100644 test/spv/size.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 781f6454..ba8bddb1 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -1024,6 +1024,14 @@ void inst_converter::operator()(num_subgroups_inst const &in) { void inst_converter::operator()(parallel_inst const &in) { run_on_region(in.body()); } +void inst_converter::operator()(size_inst const &in) { + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + declare(in.result(0), dv->shape(in.mode())); +} + void inst_converter::operator()(store_inst const &in) { auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index c5fc354c..9bd08cb3 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -77,6 +77,7 @@ class inst_converter { void operator()(load_inst const &in); void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); + void operator()(size_inst const &in); void operator()(store_inst const &in); void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 64b600c1..4ff6cf56 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -61,6 +61,7 @@ if(SPIRVTools_FOUND) spv/for.ir spv/if.ir spv/load.ir + spv/size.ir spv/store.ir spv/work_group.ir spv/unique_function_type.ir diff --git a/test/spv/size.ir b/test/spv/size.ir new file mode 100644 index 00000000..ee463cf5 --- /dev/null +++ b/test/spv/size.ir @@ -0,0 +1,27 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s + +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C8:]] = OpConstant %[[#I64]] 8 +; CHECK: %[[#I64_C32:]] = OpConstant %[[#I64]] 32 + +func @size(%0: memref) { + %1 = size %0[0] : memref + %2 = size %0[1] : memref + %3 = size %0[2] : memref + %4 = size %0[3] : memref + %5 = arith.add %1, %2 : index + %6 = arith.add %3, %4 : index +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#SHAPE1:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#SHAPE2:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#STRIDE1:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#STRIDE2:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#STRIDE3:]] = OpFunctionParameter %[[#]] +; CHECK: %[[#]] = OpIAdd %[[#]] %[[#SHAPE1]] %[[#SHAPE2]] +; CHECK: %[[#]] = OpIAdd %[[#]] %[[#I64_C8]] %[[#I64_C32]] +} + From 56dd9020207f4ab393218e71eb5ec0880f6e902d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 11 Nov 2024 21:21:44 +0100 Subject: [PATCH 100/297] SPIR-V: Assembler Signed-off-by: Carsten Uphoff --- docs/api/core_capi.rst | 36 + docs/api/core_capi.yaml | 11 + docs/api/core_cxxapi.rst | 15 + docs/api/core_cxxapi.yaml | 5 + include/tinytc/tinytc.h | 86 +- include/tinytc/tinytc.hpp | 77 +- include/tinytc/types.h | 8 + src/CMakeLists.txt | 3 + src/compiler.cpp | 35 +- src/pass/convert_to_spirv.cpp | 2 +- src/pass/convert_to_spirv.hpp | 6 +- src/spv/converter.cpp | 44 +- src/spv/converter.hpp | 16 +- src/spv/defs.hpp | 13 +- src/spv/enums.hpp | 4 + src/spv/inst_assembler.cpp | 67 + src/spv/inst_assembler.hpp | 92 ++ src/spv/module.cpp | 88 +- src/spv/module.hpp | 54 +- src/spv/pass/assemble.cpp | 49 + src/spv/pass/assemble.hpp | 19 + src/spv/pass/assign_ids.cpp | 59 + src/spv/pass/assign_ids.hpp | 40 + src/spv/pass/dump_asm.cpp | 50 +- src/spv/pass/dump_asm.hpp | 15 +- src/spv/uniquifier.cpp | 30 +- src/spv/uniquifier.hpp | 7 +- src/spv/visit.hpp | 2186 +++++++++++++++++++++++++----- test/CMakeLists.txt | 2 +- test/spv/arith.ir | 2 +- test/spv/arith_unary.ir | 2 +- test/spv/barrier.ir | 2 +- test/spv/builtin.ir | 2 +- test/spv/calling_convention.ir | 2 +- test/spv/cast.ir | 2 +- test/spv/compare.ir | 2 +- test/spv/for.ir | 2 +- test/spv/if.ir | 2 +- test/spv/load.ir | 2 +- test/spv/size.ir | 2 +- test/spv/store.ir | 2 +- test/spv/unique_function_type.ir | 2 +- test/spv/work_group.ir | 2 +- tools/offline_compiler/main.cpp | 12 +- tools/spirvgen/spirvgen.py | 78 +- 45 files changed, 2677 insertions(+), 560 deletions(-) create mode 100644 src/spv/inst_assembler.cpp create mode 100644 src/spv/inst_assembler.hpp create mode 100644 src/spv/pass/assemble.cpp create mode 100644 src/spv/pass/assemble.hpp create mode 100644 src/spv/pass/assign_ids.cpp create mode 100644 src/spv/pass/assign_ids.hpp diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index fe81e9cc..056d5316 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -48,6 +48,8 @@ Common * :ref:`tinytc_source_t` + * :ref:`tinytc_spv_mod_t` + * :ref:`tinytc_compiler_context_t` * :ref:`const_tinytc_binary_t` @@ -60,6 +62,8 @@ Common * :ref:`const_tinytc_source_t` + * :ref:`const_tinytc_spv_mod_t` + * :ref:`const_tinytc_compiler_context_t` * :ref:`tinytc_error_reporter_t` @@ -156,6 +160,11 @@ tinytc_source_t .. doxygentypedef:: tinytc_source_t +tinytc_spv_mod_t +................ + +.. doxygentypedef:: tinytc_spv_mod_t + tinytc_compiler_context_t ......................... @@ -186,6 +195,11 @@ const_tinytc_source_t .. doxygentypedef:: const_tinytc_source_t +const_tinytc_spv_mod_t +...................... + +.. doxygentypedef:: const_tinytc_spv_mod_t + const_tinytc_compiler_context_t ............................... @@ -614,6 +628,28 @@ tinytc_recipe_handler_retain .. doxygenfunction:: tinytc_recipe_handler_retain +SPIR-V module +============= + +* Functions + + * :ref:`tinytc_spv_mod_release` + + * :ref:`tinytc_spv_mod_retain` + +SPIR-V module Functions +----------------------- + +tinytc_spv_mod_release +...................... + +.. doxygenfunction:: tinytc_spv_mod_release + +tinytc_spv_mod_retain +..................... + +.. doxygenfunction:: tinytc_spv_mod_retain + Source ====== diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index f2790094..5b0fe1f6 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -22,12 +22,14 @@ Core C-API: - tinytc_recipe_t - tinytc_recipe_handler_t - tinytc_source_t + - tinytc_spv_mod_t - tinytc_compiler_context_t - const_tinytc_binary_t - const_tinytc_core_info_t - const_tinytc_recipe_t - const_tinytc_recipe_handler_t - const_tinytc_source_t + - const_tinytc_spv_mod_t - const_tinytc_compiler_context_t - tinytc_error_reporter_t Binary: @@ -47,6 +49,8 @@ Core C-API: - tinytc_list_function_passes - tinytc_prog_compile_to_opencl - tinytc_prog_compile_to_spirv + - tinytc_prog_compile_to_spirv_and_assemble + - tinytc_spirv_assemble Compiler Context: function: - tinytc_compiler_context_create @@ -96,6 +100,13 @@ Core C-API: - tinytc_recipe_retain - tinytc_recipe_handler_release - tinytc_recipe_handler_retain + SPIR-V module: + function: + - tinytc_spv_mod_dump + - tinytc_spv_mod_print_to_file + - tinytc_spv_mod_print_to_string + - tinytc_spv_mod_release + - tinytc_spv_mod_retain Source: function: - tinytc_source_get_code diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 5038b0af..dccbd1a0 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -436,6 +436,21 @@ is_usm_pointer_type .. doxygenvariable:: tinytc::is_usm_pointer_type +SPIR-V module +============= + +* Classes + + * :ref:`spv_mod` + +SPIR-V module Classes +--------------------- + +spv_mod +....... + +.. doxygenclass:: tinytc::spv_mod + Source ====== diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index b329c907..f8cb1c6a 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -31,6 +31,8 @@ Core C++-API: - tinytc::list_function_passes - tinytc::compile_to_opencl - tinytc::compile_to_spirv + - tinytc::compile_to_spirv_and_assemble + - tinytc::spirv_assemble Compiler Context: function: - tinytc::make_compiler_context @@ -72,6 +74,9 @@ Core C++-API: - tinytc::auto_mem_type_v - tinytc::is_supported_scalar_type - tinytc::is_usm_pointer_type + SPIR-V module: + class: + - tinytc::spv_mod Source: class: - tinytc::source diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index aaeca1d7..027129cb 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1134,6 +1134,30 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_release(tinytc_prog_t prg); */ TINYTC_EXPORT tinytc_status_t tinytc_prog_retain(tinytc_prog_t prg); +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +/** + * @brief Release SPIR-V module + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t mod); + +/** + * @brief Increase reference count of SPIR-V module by 1 + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t mod); + //////////////////////////// // Visitors and transforms / //////////////////////////// @@ -1170,6 +1194,39 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, */ TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str); +/** + * @brief Dump SPIR-V module to stderr + * + * @param mod [in] module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod); + +/** + * @brief Print SPIR-V module to file + * + * @param mod [in] module + * @param filename [in] filename + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, + char const *filename); + +/** + * @brief Print SPIR-V module to string + * + * The user is responsible to dispose the string with tinytc_string_destroy. + * + * @param mod [in] module + * @param str [out] pointer to string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, + char **str); + /** * @brief Delete a (non-const) string returned from tinytc API * @@ -1486,17 +1543,40 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src const_tinytc_core_info_t info); /** - * @brief Compiler tensor language to SPIR-V + * @brief Compile tensor language to SPIR-V * - * @param bin [out] pointer to the binary object created + * @param mod [out] pointer to the SPIR-V module created * @param prg [inout] tensor program; modified as compiler passes are run * @param info [in] core info object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_binary_t *bin, tinytc_prog_t prg, +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, const_tinytc_core_info_t info); +/** + * @brief Compiler tensor language to SPIR-V and assemble + * + * @param bin [out] pointer to the binary object created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble( + tinytc_binary_t *bin, tinytc_prog_t prg, const_tinytc_core_info_t info); + +/** + * @brief Assemble SPIR-V module + * + * @param bin [out] pointer to the binary object created + * @param mod [in] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, + const_tinytc_spv_mod_t mod); + /** * @brief Get source text * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 6ce71fb2..988db08e 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1722,6 +1722,50 @@ inline prog make_prog(compiler_context const &ctx, location const &loc = {}) { return prog{prg}; } +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_retain(handle); + } + static auto release(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_release(handle); + } +}; +} // namespace internal + +//! @brief Reference-counting wrapper for tinytc_spv_mod_t +class spv_mod : public shared_handle { + public: + using shared_handle::shared_handle; + + /** + * @brief Dump module to stderr + */ + void dump() const { CHECK_STATUS(tinytc_spv_mod_dump(obj_)); } + /** + * @brief Dump module to file + * + * @param filename Path to file + */ + void print_to_file(char const *filename) const { + CHECK_STATUS(tinytc_spv_mod_print_to_file(obj_, filename)); + } + /** + * @brief Dump module to string + * + * @return C-string (unique handle) + */ + auto print_to_string() const -> unique_handle { + char *str; + CHECK_STATUS(tinytc_spv_mod_print_to_string(obj_, &str)); + return unique_handle{str}; + } +}; + //////////////////////////// ////////// Builder ///////// //////////////////////////// @@ -2279,16 +2323,43 @@ inline auto compile_to_opencl(prog prg, core_info const &info) -> source { } /** - * @brief Compile program to SPIR-V + * @brief Convert tensor language to SPIR-V * * @param prg Program * @param info Core info * + * @return SPIR-V module + */ +inline auto compile_to_spirv(prog prg, core_info const &info) -> spv_mod { + tinytc_spv_mod_t mod; + CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg.get(), info.get())); + return spv_mod{mod}; +} + +/** + * @brief Compile program to SPIR-V and assemble + * + * @param prg Program + * @param info Core info + * + * @return Binary + */ +inline auto compile_to_spirv_and_assemble(prog prg, core_info const &info) -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, prg.get(), info.get())); + return binary{bin}; +} + +/** + * @brief Assemble SPIR-V module + * + * @param mod [in] SPIR-V module + * * @return Binary */ -inline auto compile_to_spirv(prog prg, core_info const &info) -> binary { +inline auto spirv_assemble(spv_mod const &mod) -> binary { tinytc_binary_t bin; - CHECK_STATUS(tinytc_prog_compile_to_spirv(&bin, prg.get(), info.get())); + CHECK_STATUS(tinytc_spirv_assemble(&bin, mod.get())); return binary{bin}; } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 67f76cd5..18592549 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -442,6 +442,14 @@ typedef struct tinytc_prog *tinytc_prog_t; //! @brief const prog handle typedef const struct tinytc_prog *const_tinytc_prog_t; +//! @struct tinytc_spv_mod +//! @brief Opaque struct for a SPIR-V module +struct tinytc_spv_mod; +//! @brief spv_mod handle +typedef struct tinytc_spv_mod *tinytc_spv_mod_t; +//! @brief const spv_mod handle +typedef const struct tinytc_spv_mod *const_tinytc_spv_mod_t; + //! @struct tinytc_core_info; //! @brief Opaque struct for core information struct tinytc_core_info; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c7fe5e1c..20bdf755 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -64,9 +64,12 @@ set(SOURCES required_extensions.cpp scalar_type.cpp spv/converter.cpp + spv/inst_assembler.cpp spv/module.cpp spv/names.cpp spv/opencl.std.cpp + spv/pass/assemble.cpp + spv/pass/assign_ids.cpp spv/pass/dump_asm.cpp spv/uniquifier.cpp source.cpp diff --git a/src/compiler.cpp b/src/compiler.cpp index 5b1a359b..49db9efc 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -21,7 +21,8 @@ #include "reference_counted.hpp" #include "required_extensions.hpp" #include "source.hpp" -#include "spv/pass/dump_asm.hpp" +#include "spv/pass/assemble.hpp" +#include "spv/pass/assign_ids.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -152,23 +153,37 @@ tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_ prg->context()); } -tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_binary_t *bin, tinytc_prog_t prg, +tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, const_tinytc_core_info_t info) { - if (bin == nullptr || prg == nullptr || info == nullptr) { + if (mod == nullptr || prg == nullptr || info == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( [&] { apply_default_optimization_pipeline(prg, info); - // opencl - auto mod = convert_to_spirv_pass{info}.run_on_program(*prg); - spv::dump_asm_pass{std::cout}.run_on_module(*mod); - - //*bin = std::make_unique<::tinytc_binary>(prg->share_context(), mod.to_binary(), - // bundle_format::spirv, info->core_features()) - //.release(); + *mod = convert_to_spirv_pass{info}.run_on_program(*prg).release(); + spv::id_assigner{}.run_on_module(**mod); }, prg->context()); } + +tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble(tinytc_binary_t *bin, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (bin == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + tinytc_spv_mod_t mod; + TINYTC_CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg, info)); + auto mod_ = spv_mod{mod}; // For clean-up + TINYTC_CHECK_STATUS(tinytc_spirv_assemble(bin, mod_.get())); + return tinytc_status_success; +} + +tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, const_tinytc_spv_mod_t mod) { + if (bin == nullptr || mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *bin = spv::assembler{}.run_on_module(*mod).release(); }); +} } diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp index 5b221b13..815697ab 100644 --- a/src/pass/convert_to_spirv.cpp +++ b/src/pass/convert_to_spirv.cpp @@ -16,7 +16,7 @@ convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) } } -auto convert_to_spirv_pass::run_on_program(program_node const &p) -> std::unique_ptr { +auto convert_to_spirv_pass::run_on_program(program_node const &p) -> spv_mod { return spv::convert_prog_to_spirv(p, *info_); } diff --git a/src/pass/convert_to_spirv.hpp b/src/pass/convert_to_spirv.hpp index 7559511a..57b23049 100644 --- a/src/pass/convert_to_spirv.hpp +++ b/src/pass/convert_to_spirv.hpp @@ -6,9 +6,7 @@ #include "device_info.hpp" #include "node/program_node.hpp" -#include "spv/module.hpp" - -#include +#include "tinytc/tinytc.hpp" namespace tinytc { @@ -16,7 +14,7 @@ class convert_to_spirv_pass { public: convert_to_spirv_pass(::tinytc_core_info const *info); - auto run_on_program(program_node const &p) -> std::unique_ptr; + auto run_on_program(program_node const &p) -> spv_mod; private: ::tinytc_core_info const *info_; diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index ba8bddb1..75fb58b8 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "spv/converter.hpp" -#include "compiler_context.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/function_node.hpp" @@ -21,13 +20,13 @@ #include "support/util.hpp" #include "support/visit.hpp" #include "tinytc/tinytc.hpp" -#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include #include #include +#include #include #include #include @@ -39,10 +38,11 @@ namespace tinytc::spv { auto convert_prog_to_spirv(tinytc_prog const &p, - tinytc_core_info const &info) -> std::unique_ptr { - auto m = std::make_unique(); + tinytc_core_info const &info) -> ::tinytc::spv_mod { + auto m = ::tinytc::spv_mod{ + std::make_unique(p.share_context(), info.core_features()).release()}; - auto conv = inst_converter{p.context(), *m}; + auto conv = inst_converter{*m}; conv.unique().capability(Capability::Addresses); conv.unique().capability(Capability::Kernel); @@ -111,8 +111,7 @@ auto dope_vector::num_dynamic() const -> std::int64_t { return sum_dynamic(static_shape_) + sum_dynamic(static_stride_); } -inst_converter::inst_converter(tinytc_compiler_context_t ctx, mod &m) - : ctx_(ctx), mod_(&m), unique_(ctx, m) {} +inst_converter::inst_converter(tinytc_spv_mod &m) : mod_(&m), unique_(m) {} auto inst_converter::get_last_label() -> spv_inst * { auto &insts = mod_->insts(section::function); @@ -230,7 +229,7 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { throw compilation_error(v.loc(), status::internal_compiler_error); } - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); return ::tinytc::visit( overloaded{[&](memref_data_type const &mr) -> dope_vector * { return &(dope_vec_[&v] = dope_vector{spv_index_ty, mr.shape(), mr.stride()}); @@ -273,7 +272,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space }; auto const split_re_im = [&]() -> std::array, 2u> { auto component_sty = element_type(sty); - auto float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, component_sty)); + auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); const auto storage_cls = address_space_to_storage_class(as); auto pointer_ty = unique_.spv_pointer_ty(storage_cls, float_ty, alignment(component_sty)); auto c0 = unique_.constant(std::int32_t{0}); @@ -308,7 +307,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space break; } case store_flag::atomic_add: { - auto result_ty = unique_.spv_ty(scalar_data_type::get(ctx_, sty)); + auto result_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), sty)); auto scope = unique_.constant(static_cast(Scope::Workgroup)); auto semantics = unique_.constant(static_cast(MemorySemantics::Relaxed)); switch (sty) { @@ -329,7 +328,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space add_fadd_caps(); auto re_im = split_re_im(); auto component_sty = element_type(sty); - auto float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, component_sty)); + auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); mod_->add(float_ty, re_im[0][0], scope, semantics, re_im[0][1]); mod_->add(float_ty, re_im[1][0], scope, semantics, re_im[1][1]); break; @@ -490,7 +489,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { spv_inst *a) -> spv_inst * { switch (op) { case arithmetic_unary::abs: { - auto spv_a_ty = unique_.spv_ty(scalar_data_type::get(ctx_, sty)); + auto spv_a_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), sty)); auto a2 = mod_->add(spv_a_ty, a, a); auto a2_0 = mod_->add(ty, a2, std::vector{0}); auto a2_1 = mod_->add(ty, a2, std::vector{1}); @@ -502,7 +501,8 @@ void inst_converter::operator()(arith_unary_inst const &in) { case arithmetic_unary::neg: return mod_->add(ty, a); case arithmetic_unary::conj: { - auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(sty))); + auto spv_float_ty = + unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(sty))); auto a_im = mod_->add(spv_float_ty, a, std::vector{1}); auto neg_a_im = mod_->add(spv_float_ty, a_im); @@ -590,7 +590,8 @@ void inst_converter::operator()(cast_inst const &in) { return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(to_ty))); + auto spv_float_ty = + unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(to_ty))); auto re = mod_->add(spv_float_ty, a); return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); @@ -612,7 +613,8 @@ void inst_converter::operator()(cast_inst const &in) { return mod_->add(spv_to_ty, a); case scalar_type::c32: case scalar_type::c64: { - auto spv_float_ty = unique_.spv_ty(scalar_data_type::get(ctx_, element_type(to_ty))); + auto spv_float_ty = + unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(to_ty))); auto re = mod_->add(spv_float_ty, a); return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), std::vector{0}); @@ -793,7 +795,7 @@ void inst_converter::operator()(for_inst const &in) { mod_->add(header_label.get()); // Header block - auto spv_bool_ty = unique_.spv_ty(boolean_data_type::get(ctx_)); + auto spv_bool_ty = unique_.spv_ty(boolean_data_type::get(mod_->context())); auto spv_loop_var_ty = unique_.spv_ty(in.loop_var().ty()); auto header_block_last_label = header_label.get(); mod_->insts(section::function).push_back(header_label.release()); @@ -902,13 +904,13 @@ void inst_converter::operator()(for_inst const &in) { void inst_converter::operator()(group_id_inst const &in) { auto gid = load_builtin(BuiltIn::GlobalInvocationId); - auto index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); declare(in.result(0), mod_->add(index_ty, gid, std::vector{2})); } void inst_converter::operator()(group_size_inst const &in) { auto gs = load_builtin(BuiltIn::GlobalSize); - auto index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); declare(in.result(0), mod_->add(index_ty, gs, std::vector{2})); } @@ -967,7 +969,7 @@ void inst_converter::operator()(if_inst const &in) { } void inst_converter::operator()(load_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); auto spv_pointer_index_ty = unique_.spv_pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, alignment(scalar_type::i64)); auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); @@ -1033,7 +1035,7 @@ void inst_converter::operator()(size_inst const &in) { } void inst_converter::operator()(store_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); auto dv = get_dope_vector(in.operand()); if (!dv) { @@ -1166,7 +1168,7 @@ void inst_converter::run_on_function(function_node const &fn, core_config const }()); // Function - auto void_ty = unique_.spv_ty(void_data_type::get(ctx_)); + auto void_ty = unique_.spv_ty(void_data_type::get(mod_->context())); auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); for (auto const &p : fn.params()) { declare(p, mod_->add(unique_.spv_ty(p.ty()))); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 9bd08cb3..7710baca 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -1,7 +1,9 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "compiler_context.hpp" +#ifndef CONVERTER_20241111_HPP +#define CONVERTER_20241111_HPP + #include "device_info.hpp" #include "node/data_type_node.hpp" #include "node/inst_node.hpp" @@ -12,19 +14,18 @@ #include "spv/module.hpp" #include "spv/uniquifier.hpp" #include "support/casting.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" #include -#include #include #include #include namespace tinytc::spv { -auto convert_prog_to_spirv(tinytc_prog const &p, - tinytc_core_info const &info) -> std::unique_ptr; +auto convert_prog_to_spirv(tinytc_prog const &p, tinytc_core_info const &info) -> ::tinytc::spv_mod; class dope_vector { public: @@ -60,7 +61,7 @@ class dope_vector { class inst_converter { public: - inst_converter(tinytc_compiler_context_t ctx, mod &m); + inst_converter(tinytc_spv_mod &m); // Instruction nodes void operator()(inst_node const &in); @@ -120,8 +121,7 @@ class inst_converter { void make_store(store_flag flag, scalar_type sty, address_space as, spv_inst *pointer, spv_inst *value); - tinytc_compiler_context_t ctx_; - mod *mod_; + tinytc_spv_mod_t mod_; uniquifier unique_; std::unordered_map dope_vec_; std::unordered_map vals_; @@ -137,3 +137,5 @@ template concept spv_inst_with_required_extensions = requires() { T::required_extensions; }; } // namespace tinytc::spv + +#endif // CONVERTER_20241111_HPP diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp index feba658f..24e7a271 100644 --- a/src/spv/defs.hpp +++ b/src/spv/defs.hpp @@ -11,6 +11,7 @@ #include "support/ilist_base.hpp" #include +#include #include #include #include @@ -20,7 +21,7 @@ namespace tinytc::spv { class spv_inst : public ilist_node { public: inline spv_inst(Op opcode, bool has_result_id) - : opcode_{opcode}, has_result_id_{has_result_id} {} + : opcode_{opcode}, id_{has_result_id ? 0 : std::numeric_limits::max()} {} virtual ~spv_inst() = default; spv_inst(spv_inst const &other) = delete; @@ -29,11 +30,17 @@ class spv_inst : public ilist_node { spv_inst &operator=(spv_inst &&other) = delete; inline auto opcode() const -> Op { return opcode_; } - inline auto has_result_id() const -> bool { return has_result_id_; } + // SPIR-V requires 0 < id < Bound, therefore, we can reserve 0 for encoding "produces result; id + // not yet assigned" and uint32_max for encoding "does not produce result" + inline auto has_result_id() const -> bool { + return id_ != std::numeric_limits::max(); + } + inline auto id() const -> std::uint32_t { return id_; } + inline void id(std::uint32_t id) { id_ = id; } private: Op opcode_; - bool has_result_id_; + std::uint32_t id_; }; using DecorationAttr = std::variant>; diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 98df5440..1557607b 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -7,8 +7,12 @@ #ifndef GENERATED_ENUMS_20241111_HPP #define GENERATED_ENUMS_20241111_HPP +#include + namespace tinytc::spv { +constexpr std::int32_t magic_number = 0x07230203; + enum class Op { Nop = 0, Undef = 1, diff --git a/src/spv/inst_assembler.cpp b/src/spv/inst_assembler.cpp new file mode 100644 index 00000000..b99a840d --- /dev/null +++ b/src/spv/inst_assembler.cpp @@ -0,0 +1,67 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/inst_assembler.hpp" +#include "spv/enums.hpp" + +#include +#include +#include + +namespace tinytc::spv { + +inst_assembler::inst_assembler(word_stream &stream) : stream_{&stream} {} + +void inst_assembler::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair const &a) { + *stream_ << a.first; + this->operator()(a.second); + }}, + da); +} +void inst_assembler::operator()(ExecutionModeAttr const &ea) { + std::visit(overloaded{[&](std::int32_t const &a) { *stream_ << a; }, + [&](std::array const &a) { + for (auto const &s : a) { + *stream_ << s; + } + }}, + ea); +} +void inst_assembler::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, l); +} +void inst_assembler::operator()(LiteralInteger const &l) { *stream_ << l; } +void inst_assembler::operator()(LiteralString const &l) { *stream_ << l; } + +void inst_assembler::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, p.first); + this->operator()(p.second); +} + +void inst_assembler::pre_visit(spv_inst const &) { + *stream_ << std::int32_t{0}; + last_opcode_pos_ = stream_->tell(); +} + +void inst_assembler::visit_result(spv_inst const &in) { *stream_ << in.id(); } + +void inst_assembler::post_visit(spv_inst const &in) { + const std::int32_t word_count = stream_->tell() - last_opcode_pos_ + 1; + const auto ophead = (word_count << 16) | static_cast(in.opcode()); + stream_->update(last_opcode_pos_, ophead); +} + +void inst_assembler::operator()(spv_inst *const &in) { *stream_ << in->id(); } + +} // namespace tinytc::spv + diff --git a/src/spv/inst_assembler.hpp b/src/spv/inst_assembler.hpp new file mode 100644 index 00000000..cd322269 --- /dev/null +++ b/src/spv/inst_assembler.hpp @@ -0,0 +1,92 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INST_ASSEMBLER_20241111_HPP +#define INST_ASSEMBLER_20241111_HPP + +#include "spv/defs.hpp" +#include "spv/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc::spv { + +template class word_stream { + public: + word_stream(std::vector &vec) : vec_{&vec} {} + + template auto operator<<(T const &t) -> word_stream & { + const std::size_t insert_pos = vec_->size() / sizeof(WordT); + vec_->resize(vec_->size() + word_count(t) * sizeof(WordT)); + update(insert_pos, t); + return *this; + } + + template static auto word_count(T const &) -> std::size_t { + return 1 + (sizeof(T) - 1) / sizeof(WordT); // ceil(sizeof(T)/sizeof(WordT)) + } + + static auto word_count(std::string const &s) -> std::size_t { + return 1 + s.size() / sizeof(WordT); // ceil((s.size()+1)/sizeof(WordT)) + } + + template auto update(std::size_t word, T const &t) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, &t, sizeof(T)); + return *this; + } + + auto update(std::size_t word, std::string const &s) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, s.c_str(), s.size() + 1); + return *this; + } + + //! Returns last word position + auto tell() const -> std::size_t { + const auto size = vec_->size(); + return size > 0 ? size / sizeof(WordT) - 1 : 0; + } + + private: + std::vector *vec_; +}; + +class inst_assembler : public default_visitor { + public: + using default_visitor::operator(); + + inst_assembler(word_stream &stream); + + void pre_visit(spv_inst const &in); + void visit_result(spv_inst const &in); + void post_visit(spv_inst const &in); + + template + requires std::is_enum_v + void operator()(T const &t) { + *stream_ << static_cast(t); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void operator()(spv_inst *const &in); + + private: + word_stream *stream_; + std::size_t last_opcode_pos_ = 0; +}; + +} // namespace tinytc::spv + +#endif // INST_ASSEMBLER_20241111_HPP diff --git a/src/spv/module.cpp b/src/spv/module.cpp index 60a83361..71156b62 100644 --- a/src/spv/module.cpp +++ b/src/spv/module.cpp @@ -1,31 +1,99 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "spv/module.hpp" +#include "tinytc/types.h" + #include "spv/defs.hpp" +#include "spv/module.hpp" +#include "spv/pass/dump_asm.hpp" #include "support/ilist_base.hpp" +#include +#include +#include +#include + namespace tinytc { void ilist_callbacks::node_added(spv::spv_inst *) {} void ilist_callbacks::node_removed(spv::spv_inst *node) { delete node; } } // namespace tinytc -namespace tinytc::spv { -mod::mod(std::int32_t major_version, std::int32_t minor_version) - : major_version_{major_version}, minor_version_{minor_version} {} -mod::~mod() {} +using namespace tinytc; + +tinytc_spv_mod::tinytc_spv_mod(compiler_context ctx, tinytc_core_feature_flags_t core_features, + std::int32_t major_version, std::int32_t minor_version) + : ctx_{std::move(ctx)}, core_features_(core_features), major_version_{major_version}, + minor_version_{minor_version} {} +tinytc_spv_mod::~tinytc_spv_mod() {} -auto mod::bound() const -> std::int32_t { - std::int32_t bnd = 0; +auto tinytc_spv_mod::bound() const -> std::uint32_t { + std::uint32_t bnd = 0; for (auto const &sec : insts_) { for (auto const &i : sec) { if (i.has_result_id()) { - ++bnd; + bnd = std::max(bnd, i.id()); } } } - return bnd; + return bnd + 1; +} + +extern "C" { + +tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod) { + if (mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { spv::dump_asm_pass{std::cerr}.run_on_module(*mod); }); +} + +tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, char const *filename) { + if (mod == nullptr || filename == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto stream = std::ofstream(filename); + if (!stream.good()) { + throw status::file_io_error; + } + spv::dump_asm_pass{stream}.run_on_module(*mod); + }); } -} // namespace tinytc::spv +tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, char **str) { + if (mod == nullptr || str == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const text = [&] { + auto oss = std::ostringstream{}; + spv::dump_asm_pass{oss}.run_on_module(*mod); + return std::move(oss).str(); + }(); + auto const length = text.size() + 1; // Need to include terminating null character + *str = (char *)malloc(length * sizeof(char)); + if (!str) { + throw status::bad_alloc; + } + std::strncpy(*str, text.c_str(), length); + }); +} +tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} +tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/spv/module.hpp b/src/spv/module.hpp index f4917c3f..7955ef56 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -4,7 +4,12 @@ #ifndef MODULE_20241029_HPP #define MODULE_20241029_HPP +#include "compiler_context.hpp" +#include "reference_counted.hpp" +#include "spv/defs.hpp" #include "support/ilist.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include #include @@ -13,10 +18,6 @@ namespace tinytc { -namespace spv { -class spv_inst; -} - template <> struct ilist_callbacks { void node_added(spv::spv_inst *node); void node_removed(spv::spv_inst *node); @@ -37,40 +38,57 @@ enum class section { }; inline constexpr std::int32_t num_module_sections = 9; -class mod final { +} // namespace spv +} // namespace tinytc + +struct tinytc_spv_mod final : tinytc::reference_counted { public: - using iterator = ilist::iterator; - using const_iterator = ilist::const_iterator; + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; + + tinytc_spv_mod(tinytc::compiler_context ctx, tinytc_core_feature_flags_t core_features, + std::int32_t major_version = 1, std::int32_t minor_version = 6); + ~tinytc_spv_mod(); - mod(std::int32_t major_version = 1, std::int32_t minor_version = 6); - ~mod(); + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } + inline auto core_features() const -> tinytc_core_feature_flags_t { return core_features_; } - auto bound() const -> std::int32_t; + auto bound() const -> std::uint32_t; - inline auto insts(section s) -> ilist & { return insts_[static_cast(s)]; } - inline auto insts(section s) const -> ilist const & { + inline auto insts(tinytc::spv::section s) -> tinytc::ilist & { return insts_[static_cast(s)]; } - inline auto empty(section s) const -> bool { return insts_[static_cast(s)].empty(); } + inline auto + insts(tinytc::spv::section s) const -> tinytc::ilist const & { + return insts_[static_cast(s)]; + } + inline auto empty(tinytc::spv::section s) const -> bool { + return insts_[static_cast(s)].empty(); + } inline auto major_version() const -> std::int32_t { return major_version_; } inline auto minor_version() const -> std::int32_t { return minor_version_; } - template auto add_to(section s, Args &&...args) -> T * { + template + auto add_to(tinytc::spv::section s, Args &&...args) -> T * { auto ptr = std::make_unique(std::forward(args)...).release(); insts(s).push_back(ptr); return ptr; } template auto add(Args &&...args) -> T * { - return add_to(section::function, std::forward(args)...); + return add_to(tinytc::spv::section::function, std::forward(args)...); } private: - std::array, num_module_sections> insts_; + tinytc::compiler_context ctx_; + tinytc_core_feature_flags_t core_features_; + std::array, tinytc::spv::num_module_sections> insts_; std::int32_t major_version_, minor_version_; }; -} // namespace spv -} // namespace tinytc +namespace tinytc::spv { +// using mod = ::tinytc_spv_mod; +} // namespace tinytc::spv #endif // MODULE_20241029_HPP diff --git a/src/spv/pass/assemble.cpp b/src/spv/pass/assemble.cpp new file mode 100644 index 00000000..77cf0d86 --- /dev/null +++ b/src/spv/pass/assemble.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assemble.hpp" +#include "spv/enums.hpp" +#include "spv/inst_assembler.hpp" +#include "spv/module.hpp" +#include "spv/visit.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" +#include "support/util.hpp" +#include "tinytc/tinytc.h" + +#include +#include + +namespace tinytc::spv { + +auto assembler::run_on_module(tinytc_spv_mod const &mod) -> binary { + auto data = std::vector{}; + auto stream = word_stream{data}; + + const std::int32_t bound = mod.bound(); + // Guess instruction stream by using 5 words per instruction that produces a result + // Not really important, but could be improved + data.reserve(5 * sizeof(std::int32_t) * bound); + + // Make header + const std::int32_t version = (mod.major_version() << 16) | (mod.minor_version() << 8); + const std::int32_t generator_number = 0; + stream << magic_number << version << generator_number << bound << std::int32_t{0}; + + // Assemble instructions + auto ia = inst_assembler{stream}; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(ia, i); + } + } + + // Create binary + tinytc_binary_t bin; + CHECK_STATUS(tinytc_binary_create(&bin, mod.context(), tinytc_bundle_format_spirv, data.size(), + data.data(), mod.core_features())); + return binary{bin}; +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assemble.hpp b/src/spv/pass/assemble.hpp new file mode 100644 index 00000000..e4f6172b --- /dev/null +++ b/src/spv/pass/assemble.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSEMBLE_20241111_HPP +#define ASSEMBLE_20241111_HPP + +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +namespace tinytc::spv { + +class assembler { + public: + auto run_on_module(tinytc_spv_mod const &mod) -> binary; +}; + +} // namespace tinytc::spv + +#endif // ASSEMBLE_20241111_HPP diff --git a/src/spv/pass/assign_ids.cpp b/src/spv/pass/assign_ids.cpp new file mode 100644 index 00000000..7d1a3359 --- /dev/null +++ b/src/spv/pass/assign_ids.cpp @@ -0,0 +1,59 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assign_ids.hpp" +#include "spv/module.hpp" +#include "support/casting.hpp" +#include "support/ilist.hpp" +#include "support/ilist_base.hpp" +#include "support/util.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc::spv { + +void id_assigner::declare(spv_inst *in) { + if (!slot_map_.contains(in)) { + const auto slot = slot_++; + slot_map_[in] = slot; + in->id(slot); + } +} + +void id_assigner::visit_result(spv_inst &in) { declare(&in); } + +void id_assigner::operator()(spv_inst *&in) { + if (!slot_map_.contains(in)) { + if (isa(*in) || isa(*in) || isa(*in) || + isa(*in)) { + declare(in); + } else { + throw status::spirv_forbidden_forward_declaration; + } + } +} + +void id_assigner::operator()(OpPhi &in) { + pre_visit(in); + this->operator()(in.type()); + this->visit_result(in); + for (auto &op : in.op0()) { + // Forward references are allowed in phi instructions + declare(op.first); + this->operator()(op); + } + post_visit(in); +} + +void id_assigner::run_on_module(tinytc_spv_mod &m) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto &i : m.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assign_ids.hpp b/src/spv/pass/assign_ids.hpp new file mode 100644 index 00000000..7eaa50d3 --- /dev/null +++ b/src/spv/pass/assign_ids.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSIGN_IDS_20241111_HPP +#define ASSIGN_IDS_20241111_HPP + +#include "spv/defs.hpp" +#include "spv/instructions.hpp" +#include "spv/visit.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc::spv { + +class id_assigner : public default_visitor { + public: + using default_visitor::operator(); + + void visit_result(spv_inst &in); + + // Do nothing by default + template void operator()(T &) {} + + void operator()(spv_inst *&in); + void operator()(OpPhi &in); + + void run_on_module(tinytc_spv_mod &m); + + private: + void declare(spv_inst *in); + + std::uint32_t slot_ = 1; + std::unordered_map slot_map_; +}; + +} // namespace tinytc::spv + +#endif // ASSIGN_IDS_20241111_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index f837117f..10b80cd5 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -9,10 +9,10 @@ #include "support/ilist.hpp" #include "support/ilist_base.hpp" #include "support/util.hpp" -#include "tinytc/types.hpp" #include #include +#include #include #include #include @@ -23,16 +23,6 @@ namespace tinytc::spv { dump_asm_pass::dump_asm_pass(std::ostream &os) : os_(&os) {} -auto dump_asm_pass::declare(spv_inst const *in) -> std::int64_t { - auto s = slot_map_.find(in); - if (s == slot_map_.end()) { - const auto slot = slot_++; - slot_map_[in] = slot; - return slot; - } - return s->second; -} - void dump_asm_pass::pre_visit(spv_inst const &in) { auto const num_digits = [](std::int64_t number) { std::int64_t d = 1; @@ -43,12 +33,12 @@ void dump_asm_pass::pre_visit(spv_inst const &in) { }; *os_ << std::endl; if (in.has_result_id()) { - const auto slot = declare(&in); + const auto id = in.id(); - for (int i = 0; i < rhs_indent - 4 - num_digits(slot); ++i) { + for (int i = 0; i < rhs_indent - 4 - num_digits(id); ++i) { *os_ << ' '; } - *os_ << "%" << slot << " = "; + *os_ << "%" << id << " = "; } else { for (int i = 0; i < rhs_indent; ++i) { *os_ << ' '; @@ -109,24 +99,11 @@ void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { this->operator()(p.second); } -void dump_asm_pass::operator()(spv_inst *const &in) { - if (auto s = slot_map_.find(in); s != slot_map_.end()) { - *os_ << " %" << s->second; - } else if (isa(*in)) { - *os_ << " %" << declare(in); - } else if (isa(*in)) { - *os_ << " %" << declare(in); - } else if (isa(*in)) { - *os_ << " %" << declare(in); - } else if (isa(*in)) { - *os_ << " %" << declare(in); - } else { - throw status::spirv_forbidden_forward_declaration; - } -} -auto dump_asm_pass::operator()(OpExtInst const &in) { +void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << in->id(); } +void dump_asm_pass::operator()(OpExtInst const &in) { pre_visit(in); this->operator()(in.type()); + visit_result(in); this->operator()(in.op0()); if (auto extimport = dyn_cast(in.op0()); @@ -139,19 +116,10 @@ auto dump_asm_pass::operator()(OpExtInst const &in) { for (auto const &op : in.op2()) { this->operator()(op); } + post_visit(in); } -auto dump_asm_pass::operator()(OpPhi const &in) { - pre_visit(in); - this->operator()(in.type()); - for (auto const &op : in.op0()) { - // Forward references are allowed in phi instructions - declare(op.first); - this->operator()(op); - } -} - -void dump_asm_pass::run_on_module(mod const &m) { +void dump_asm_pass::run_on_module(tinytc_spv_mod const &m) { auto const visit_section = [&](section s) { for (auto const &i : m.insts(s)) { visit(*this, i); diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp index 9b4de2f2..a037e336 100644 --- a/src/spv/pass/dump_asm.hpp +++ b/src/spv/pass/dump_asm.hpp @@ -8,15 +8,12 @@ #include "spv/instructions.hpp" #include "spv/names.hpp" #include "spv/visit.hpp" +#include "tinytc/types.h" -#include #include -#include namespace tinytc::spv { -class mod; - class dump_asm_pass : public default_visitor { public: using default_visitor::operator(); @@ -42,18 +39,12 @@ class dump_asm_pass : public default_visitor { void operator()(PairLiteralIntegerIdRef const &p); void operator()(spv_inst *const &in); - auto operator()(OpExtInst const &in); - auto operator()(OpPhi const &in); + void operator()(OpExtInst const &in); - void run_on_module(mod const &m); + void run_on_module(tinytc_spv_mod const &m); private: - auto declare(spv_inst const *in) -> std::int64_t; - std::ostream *os_; - - std::int64_t slot_ = 0; - std::unordered_map slot_map_; }; } // namespace tinytc::spv diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 6d39c633..397f57b8 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -2,8 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "spv/uniquifier.hpp" -#include "compiler_context.hpp" -#include "node/data_type_node.hpp" #include "scalar_type.hpp" #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" @@ -24,11 +22,11 @@ auto address_space_to_storage_class(address_space as) -> StorageClass { return as == address_space::local ? StorageClass::Workgroup : StorageClass::CrossWorkgroup; } -uniquifier::uniquifier(tinytc_compiler_context_t ctx, mod &m) : ctx_(ctx), mod_(&m) {} +uniquifier::uniquifier(tinytc_spv_mod &m) : mod_(&m) {} auto uniquifier::bool2_ty() -> spv_inst * { return lookup(bool2_ty_, [&] { - auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + auto bool_ty = spv_ty(boolean_data_type::get(mod_->context())); return mod_->add_to(section::type_const_var, bool_ty, 2); }); } @@ -36,12 +34,12 @@ auto uniquifier::bool2_ty() -> spv_inst * { auto uniquifier::bool_constant(bool b) -> spv_inst * { if (b) { return lookup(bool_true_, [&] { - auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + auto bool_ty = spv_ty(boolean_data_type::get(mod_->context())); return mod_->add_to(section::type_const_var, bool_ty); }); } return lookup(bool_false_, [&] { - auto bool_ty = spv_ty(boolean_data_type::get(ctx_)); + auto bool_ty = spv_ty(boolean_data_type::get(mod_->context())); return mod_->add_to(section::type_const_var, bool_ty); }); } @@ -83,10 +81,10 @@ auto uniquifier::builtin_pointee_ty(BuiltIn b) -> spv_inst * { case BuiltIn::NumEnqueuedSubgroups: case BuiltIn::SubgroupId: case BuiltIn::SubgroupLocalInvocationId: - return spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); + return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i32)); case BuiltIn::GlobalLinearId: case BuiltIn::LocalInvocationIndex: - return spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); case BuiltIn::GlobalSize: case BuiltIn::GlobalInvocationId: case BuiltIn::WorkgroupSize: @@ -126,7 +124,7 @@ auto uniquifier::constant(LiteralContextDependentNumber cst) -> spv_inst * { scalar_type sty = std::visit( overloaded{[](auto const &c) { return to_scalar_type_v>; }}, cst); - auto ty = spv_ty(scalar_data_type::get(ctx_, sty)); + auto ty = spv_ty(scalar_data_type::get(mod_->context(), sty)); return mod_->add_to(section::type_const_var, ty, cst); }); } @@ -140,7 +138,7 @@ void uniquifier::extension(char const *ext_name) { auto uniquifier::index3_ty() -> spv_inst * { return lookup(index3_ty_, [&] { - auto index_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::index)); + auto index_ty = spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); return mod_->add_to(section::type_const_var, index_ty, 3); }); } @@ -165,7 +163,7 @@ auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { return it->second; } } - auto void_ty = spv_ty(void_data_type::get(ctx_)); + auto void_ty = spv_ty(void_data_type::get(mod_->context())); return spv_function_tys_ .emplace(map_key, mod_->add_to(section::type_const_var, void_ty, std::move(params))) @@ -227,9 +225,9 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { const auto sz = size(ty.ty()); if (sz == 8) { capability(Capability::Int64); - return spv_ty(scalar_data_type::get(ctx_, scalar_type::i64)); + return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i64)); } - return spv_ty(scalar_data_type::get(ctx_, scalar_type::i32)); + return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i32)); } case scalar_type::f32: case scalar_type::f64: @@ -237,11 +235,13 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { return mod_->add_to(section::type_const_var, size(ty.ty()) * 8, std::nullopt); case scalar_type::c32: { - auto float_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::f32)); + auto float_ty = + spv_ty(scalar_data_type::get(mod_->context(), scalar_type::f32)); return mod_->add_to(section::type_const_var, float_ty, 2); } case scalar_type::c64: { - auto float_ty = spv_ty(scalar_data_type::get(ctx_, scalar_type::f64)); + auto float_ty = + spv_ty(scalar_data_type::get(mod_->context(), scalar_type::f64)); return mod_->add_to(section::type_const_var, float_ty, 2); } } diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index 88e73c31..bc1521a9 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -4,12 +4,14 @@ #ifndef UNIQUIFIER_20241107_HPP #define UNIQUIFIER_20241107_HPP +#include "node/data_type_node.hpp" #include "spv/defs.hpp" #include "spv/enums.hpp" #include "spv/module.hpp" #include "support/fnv1a.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" #include #include @@ -25,7 +27,7 @@ auto address_space_to_storage_class(address_space as) -> StorageClass; class uniquifier { public: - uniquifier(tinytc_compiler_context_t ctx, mod &m); + uniquifier(tinytc_spv_mod &m); auto bool2_ty() -> spv_inst *; auto bool_constant(bool b) -> spv_inst *; @@ -67,8 +69,7 @@ class uniquifier { } }; - tinytc_compiler_context_t ctx_; - mod *mod_; + tinytc_spv_mod_t mod_; spv_inst *bool2_ty_ = nullptr; spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; spv_inst *index3_ty_ = nullptr; diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index 10e14694..dc09c379 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -7,12 +7,715 @@ #ifndef GENERATED_VISIT_20241111_HPP #define GENERATED_VISIT_20241111_HPP +#include "defs.hpp" +#include "enums.hpp" +#include "instructions.hpp" + namespace tinytc::spv { template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; +template auto visit(Visitor &&visitor, spv_inst &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); + case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); + case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} template auto visit(Visitor &&visitor, spv_inst const &inst) { switch (inst.opcode()) { @@ -713,19 +1416,29 @@ template auto visit(Visitor &&visitor, spv_inst const &inst) } throw internal_compiler_error(); } -template class default_visitor { + +template class default_visitor { public: - auto pre_visit(spv_inst const &) {} - auto operator()(OpNop const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpUndef const &in) { + template using const_t = std::conditional_t, T>; + auto pre_visit(const_t &) {} + auto visit_result(const_t &) {} + auto post_visit(const_t &) {} + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpSourceContinued const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSource const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -736,99 +1449,136 @@ template class default_visitor { if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpSourceExtension const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpName const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpMemberName const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpString const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpLine const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpExtension const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpExtInstImport const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpExtInst const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpMemoryModel const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpEntryPoint const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); - for (auto const &op : in.op3()) { + for (auto &op : in.op3()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpExecutionMode const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpCapability const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeVoid const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeBool const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeInt const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeFloat const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); if (in.op1()) { static_cast(this)->operator()(*in.op1()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpTypeVector const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeMatrix const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeImage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); @@ -839,124 +1589,194 @@ template class default_visitor { if (in.op7()) { static_cast(this)->operator()(*in.op7()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpTypeSampler const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeSampledImage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeArray const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeRuntimeArray const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeStruct const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); - for (auto const &op : in.op0()) { + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpTypeOpaque const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypePointer const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeFunction const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpTypeEvent const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeDeviceEvent const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeReserveId const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypeQueue const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpTypePipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeForwardPointer const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpConstantTrue const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpConstantFalse const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpConstant const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConstantComposite const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); - for (auto const &op : in.op0()) { + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpConstantSampler const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpConstantNull const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpFunction const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFunctionParameter const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); } - auto operator()(OpFunctionEnd const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpFunctionCall const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpVariable const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); if (in.op1()) { static_cast(this)->operator()(*in.op1()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageTexelPointer const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpLoad const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); if (in.op1()) { static_cast(this)->operator()(*in.op1()); @@ -965,8 +1785,10 @@ template class default_visitor { if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpStore const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -977,8 +1799,10 @@ template class default_visitor { if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCopyMemory const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -993,8 +1817,10 @@ template class default_visitor { if (in.op4()) { static_cast(this)->operator()(*in.op4()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCopyMemorySized const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -1010,250 +1836,333 @@ template class default_visitor { if (in.op5()) { static_cast(this)->operator()(*in.op5()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpAccessChain const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpInBoundsAccessChain const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpPtrAccessChain const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpArrayLength const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGenericPtrMemSemantics const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpInBoundsPtrAccessChain const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpDecorate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpMemberDecorate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpDecorationGroup const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpGroupDecorate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupMemberDecorate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpVectorExtractDynamic const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpVectorInsertDynamic const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpVectorShuffle const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCompositeConstruct const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); - for (auto const &op : in.op0()) { + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCompositeExtract const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); - for (auto const &op : in.op1()) { + for (auto &op : in.op1()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCompositeInsert const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCopyObject const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpTranspose const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSampledImage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleDrefImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleDrefExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleProjImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleProjExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleProjDrefImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSampleProjDrefExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageFetch const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageGather const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageDrefGather const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageRead const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageWrite const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -1261,1056 +2170,1399 @@ template class default_visitor { if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQueryFormat const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQueryOrder const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQuerySizeLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQuerySize const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQueryLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQueryLevels const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageQuerySamples const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertFToU const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertFToS const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertSToF const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertUToF const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpUConvert const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSConvert const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpFConvert const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpQuantizeToF16 const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertPtrToU const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSatConvertSToU const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSatConvertUToS const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpConvertUToPtr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpPtrCastToGeneric const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpGenericCastToPtr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpGenericCastToPtrExplicit const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitcast const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSNegate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpFNegate const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpIAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpISub const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFSub const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpIMul const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFMul const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUDiv const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSDiv const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFDiv const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUMod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSRem const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSMod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFRem const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFMod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpVectorTimesScalar const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpMatrixTimesScalar const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpVectorTimesMatrix const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpMatrixTimesVector const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpMatrixTimesMatrix const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpOuterProduct const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpDot const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpIAddCarry const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpISubBorrow const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUMulExtended const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSMulExtended const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpAny const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpAll const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpIsNan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpIsInf const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpIsFinite const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpIsNormal const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSignBitSet const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpLessOrGreater const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpOrdered const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUnordered const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLogicalEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLogicalNotEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLogicalOr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLogicalAnd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLogicalNot const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSelect const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpIEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpINotEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUGreaterThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSGreaterThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpUGreaterThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSGreaterThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpULessThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSLessThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpULessThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpSLessThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdNotEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordNotEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdLessThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordLessThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdGreaterThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordGreaterThan const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdLessThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordLessThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFOrdGreaterThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpFUnordGreaterThanEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpShiftRightLogical const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpShiftRightArithmetic const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpShiftLeftLogical const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitwiseOr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitwiseXor const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitwiseAnd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpNot const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitFieldInsert const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitFieldSExtract const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitFieldUExtract const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitReverse const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpBitCount const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdx const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdy const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpFwidth const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdxFine const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdyFine const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpFwidthFine const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdxCoarse const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpDPdyCoarse const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpFwidthCoarse const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpEmitVertex const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpEndPrimitive const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpEmitStreamVertex const &in) { + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpEndStreamPrimitive const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpControlBarrier const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpMemoryBarrier const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicLoad const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicStore const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicExchange const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicCompareExchange const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicCompareExchangeWeak const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicIIncrement const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicIDecrement const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicIAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicISub const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicSMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicUMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicSMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicUMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicAnd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicOr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicXor const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpPhi const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); - for (auto const &op : in.op0()) { + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpLoopMerge const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpSelectionMerge const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpLabel const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpBranch const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpBranchConditional const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); - for (auto const &op : in.op3()) { + for (auto &op : in.op3()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpSwitch const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); - for (auto const &op : in.op2()) { + for (auto &op : in.op2()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); } - auto operator()(OpKill const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpReturn const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpReturnValue const &in) { + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); } - auto operator()(OpUnreachable const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpLifetimeStart const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpLifetimeStop const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupAsyncCopy const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupWaitEvents const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupAll const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupAny const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupBroadcast const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupIAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupFAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupFMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupUMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupSMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupFMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupUMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupSMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpReadPipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpWritePipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpReservedReadPipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); } - auto operator()(OpReservedWritePipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); } - auto operator()(OpReserveReadPipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpReserveWritePipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpCommitReadPipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpCommitWritePipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpIsValidReserveId const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetNumPipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetMaxPipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupReserveReadPipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupReserveWritePipePackets const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupCommitReadPipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupCommitWritePipe const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpEnqueueMarker const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpEnqueueKernel const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); @@ -2321,579 +3573,760 @@ template class default_visitor { static_cast(this)->operator()(in.op7()); static_cast(this)->operator()(in.op8()); static_cast(this)->operator()(in.op9()); - for (auto const &op : in.op10()) { + for (auto &op : in.op10()) { static_cast(this)->operator()(op); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelNDrangeSubGroupCount const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelNDrangeMaxSubGroupSize const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelWorkGroupSize const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelPreferredWorkGroupSizeMultiple const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpRetainEvent const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpReleaseEvent const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpCreateUserEvent const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpIsValidEvent const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpSetUserEventStatus const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpCaptureEventProfilingInfo const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetDefaultQueue const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpBuildNDRange const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleDrefImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleDrefExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleProjImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleProjExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleProjDrefImplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseSampleProjDrefExplicitLod const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseFetch const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseGather const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseDrefGather const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseTexelsResident const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); } - auto operator()(OpNoLine const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpAtomicFlagTestAndSet const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicFlagClear const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpImageSparseRead const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { static_cast(this)->operator()(*in.op2()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpSizeOf const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); } - auto operator()(OpTypePipeStorage const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpConstantPipeStorage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpCreatePipeFromPipeStorage const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelLocalSizeForSubgroupCount const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpGetKernelMaxNumSubgroups const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeNamedBarrier const &in) { static_cast(this)->pre_visit(in); } - auto operator()(OpNamedBarrierInitialize const &in) { + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpMemoryNamedBarrier const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpModuleProcessed const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpExecutionModeId const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpDecorateId const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformElect const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformAll const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformAny const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformAllEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBroadcast const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBroadcastFirst const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBallot const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformInverseBallot const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBallotBitExtract const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBallotBitCount const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBallotFindLSB const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBallotFindMSB const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformShuffle const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformShuffleXor const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformShuffleUp const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformShuffleDown const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformIAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformFAdd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformIMul const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformFMul const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformSMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformUMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformFMin const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformSMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformUMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformFMax const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBitwiseAnd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBitwiseOr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformBitwiseXor const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformLogicalAnd const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformLogicalOr const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformLogicalXor const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformQuadBroadcast const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpGroupNonUniformQuadSwap const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); } - auto operator()(OpCopyLogical const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpPtrEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpPtrNotEqual const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpPtrDiff const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); } - auto operator()(OpTypeCooperativeMatrixKHR const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); } - auto operator()(OpCooperativeMatrixLoadKHR const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); if (in.op2()) { @@ -2907,8 +4340,10 @@ template class default_visitor { if (in.op4()) { static_cast(this)->operator()(*in.op4()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCooperativeMatrixStoreKHR const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); @@ -2924,45 +4359,58 @@ template class default_visitor { if (in.op5()) { static_cast(this)->operator()(*in.op5()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCooperativeMatrixMulAddKHR const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); if (in.op3()) { static_cast(this)->operator()(*in.op3()); } + + static_cast(this)->post_visit(in); } - auto operator()(OpCooperativeMatrixLengthKHR const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicFMinEXT const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicFMaxEXT const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } - auto operator()(OpAtomicFAddEXT const &in) { + auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); static_cast(this)->operator()(in.op0()); static_cast(this)->operator()(in.op1()); static_cast(this)->operator()(in.op2()); static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); } }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4ff6cf56..d18065ac 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -68,7 +68,7 @@ if(SPIRVTools_FOUND) ) foreach(SOURCE IN LISTS SPIRV_VAL_SOURCES) get_filename_component(TEST_NAME ${SOURCE} NAME_WE) - set(CHECK_COMMAND $ -O0 -gspirv ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE} | spirv-as - -o - | ${SPIRVTools_SPIRV_VAL} -) + set(CHECK_COMMAND $ -O0 -gspirv ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE} | ${SPIRVTools_SPIRV_VAL} -) list(JOIN CHECK_COMMAND " " CHECK_COMMAND_STR) add_test(NAME spirv-val-${TEST_NAME} COMMAND bash -c "${CHECK_COMMAND_STR}") add_custom_target(spirv-val-${TEST_NAME} COMMAND ${CHECK_COMMAND}) diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 21b4f9cd..62cf63cc 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: %[[#BOOL:]] = OpTypeBool diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index 30640db4..f4bafe53 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: %[[#EXT:]] = OpExtInstImport "OpenCL.std" diff --git a/test/spv/barrier.ir b/test/spv/barrier.ir index dbb6e6dd..28654897 100644 --- a/test/spv/barrier.ir +++ b/test/spv/barrier.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#SCOPE:]] = OpConstant %[[#I32]] 2 diff --git a/test/spv/builtin.ir b/test/spv/builtin.ir index 400240fc..d1214d52 100644 --- a/test/spv/builtin.ir +++ b/test/spv/builtin.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpEntryPoint Kernel %[[#]] "tbuiltin" %[[#VAR1:]] %[[#VAR2:]] %[[#VAR3:]] %[[#VAR4:]] %[[#VAR5:]] %[[#VAR6:]] diff --git a/test/spv/calling_convention.ir b/test/spv/calling_convention.ir index f70fd5d7..05b4c608 100644 --- a/test/spv/calling_convention.ir +++ b/test/spv/calling_convention.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpDecorate %[[#PTR_F32:]] Alignment 4 ; CHECK: OpDecorate %[[#PTR_I16:]] Alignment 2 diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 264ccc36..64fa56a7 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: OpCapability Int8 diff --git a/test/spv/compare.ir b/test/spv/compare.ir index ffafb14a..7c28f699 100644 --- a/test/spv/compare.ir +++ b/test/spv/compare.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: %[[#BOOL:]] = OpTypeBool ; CHECK: %[[#BOOL2:]] = OpTypeVector %[[#BOOL]] 2 diff --git a/test/spv/for.ir b/test/spv/for.ir index 6b803d28..c7666930 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -16,7 +16,7 @@ ; CHECK: %[[#I16_C6:]] = OpConstant %[[#I16]] 6 ; CHECK: %[[#I16_C1:]] = OpConstant %[[#I16]] 1 -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s func @for1() { %lb = constant 0 -> i16 %ub = constant 10 -> i16 diff --git a/test/spv/if.ir b/test/spv/if.ir index 18852109..222004d6 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#BOOL:]] = OpTypeBool diff --git a/test/spv/load.ir b/test/spv/load.ir index 56508478..a17c3da8 100644 --- a/test/spv/load.ir +++ b/test/spv/load.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#F32]] diff --git a/test/spv/size.ir b/test/spv/size.ir index ee463cf5..a2a19a1c 100644 --- a/test/spv/size.ir +++ b/test/spv/size.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#I64_C8:]] = OpConstant %[[#I64]] 8 diff --git a/test/spv/store.ir b/test/spv/store.ir index 5168140e..10107cea 100644 --- a/test/spv/store.ir +++ b/test/spv/store.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpCapability AtomicFloat32AddEXT ; CHECK: OpCapability AtomicFloat64AddEXT diff --git a/test/spv/unique_function_type.ir b/test/spv/unique_function_type.ir index e6bf62a5..c0649d5c 100644 --- a/test/spv/unique_function_type.ir +++ b/test/spv/unique_function_type.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S < %s | filecheck %s func @f1() {} func @f2() {} func @f3(%a: i32, %b: f32) {} diff --git a/test/spv/work_group.ir b/test/spv/work_group.ir index d6259db3..c3748978 100644 --- a/test/spv/work_group.ir +++ b/test/spv/work_group.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -O0 < %s | filecheck %s +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s ; CHECK: OpCapability Group ; CHECK: %[[#I16:]] = OpTypeInt 16 0 diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 0123e5eb..6375df75 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -26,6 +26,7 @@ int main(int argc, char **argv) { std::int32_t opt_level = 2; auto flags = cmd::optflag_states{}; auto gen = generator::opencl; + bool emit_asm = false; bool help = false; auto const convert_string_to_generator = [](char const *str, generator &val) { @@ -60,6 +61,7 @@ int main(int argc, char **argv) { }); parser.set_short_opt('g', &gen, "Code generation backend (opencl or spirv)") .converter(convert_string_to_generator); + parser.set_short_opt('S', &emit_asm, "Compile only; do not assemble"); parser.set_short_opt('h', &help, "Show help"); parser.set_long_opt("help", &help, "Show help"); parser.add_positional_arg("file-name", &filename, @@ -105,7 +107,15 @@ int main(int argc, char **argv) { std::cout << compile_to_opencl(std::move(p), info).get_code(); break; case generator::spirv: - compile_to_spirv(std::move(p), info); + if (emit_asm) { + auto mod = compile_to_spirv(std::move(p), info); + auto spvasm = mod.print_to_string(); + std::cout << spvasm.get(); + } else { + auto bin = compile_to_spirv_and_assemble(std::move(p), info); + auto raw_data = bin.get_raw(); + std::cout.write(reinterpret_cast(raw_data.data), raw_data.data_size); + } break; } } catch (status const &st) { diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index e512d5fb..50ceaf71 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -14,17 +14,19 @@ from gen import generate_cpp, generate_header spv_enums = 'enums.hpp' +spv_enums_includes = [''] spv_names = 'names.hpp' spv_names_includes = [spv_enums] spv_names_cpp = 'names.cpp' spv_names_cpp_includes = [spv_names, spv_enums] spv_defs = 'defs.hpp' spv_defs_includes = [ - spv_enums, 'support/ilist_base.hpp', None, '', '', - '', '' + spv_enums, 'support/ilist_base.hpp', None, '', '', + '', '', '' ] spv_ops = 'instructions.hpp' spv_visitor = 'visit.hpp' +spv_visitor_includes = [spv_defs, spv_enums, spv_ops] spv_ops_includes = [ spv_defs, spv_enums, 'error.hpp', 'support/ilist_base.hpp', None, '', '', '', '', '', '', @@ -48,6 +50,10 @@ def get_class_name(instruction): def generate_enums(f, grammar): + print(f'constexpr std::int32_t magic_number = {grammar["magic_number"]};', + file=f) + print(file=f) + print('enum class Op {', file=f) for inst in grammar['instructions']: print(f'{get_opcode_name(inst)} = {inst["opcode"]},', file=f) @@ -147,7 +153,8 @@ def generate_defs(f, grammar): print(""" class spv_inst : public ilist_node { public: - inline spv_inst(Op opcode, bool has_result_id) : opcode_{opcode}, has_result_id_{has_result_id} {} + inline spv_inst(Op opcode, bool has_result_id) + : opcode_{opcode}, id_{has_result_id ? 0 : std::numeric_limits::max()} {} virtual ~spv_inst() = default; spv_inst(spv_inst const &other) = delete; @@ -156,11 +163,15 @@ class spv_inst : public ilist_node { spv_inst &operator=(spv_inst &&other) = delete; inline auto opcode() const -> Op { return opcode_; } - inline auto has_result_id() const -> bool { return has_result_id_; } + // SPIR-V requires 0 < id < Bound, therefore, we can reserve 0 for encoding "produces result; id not yet assigned" + // and uint32_max for encoding "does not produce result" + inline auto has_result_id() const -> bool { return id_ != std::numeric_limits::max(); } + inline auto id() const -> std::uint32_t { return id_; } + inline void id(std::uint32_t id) { id_ = id; } private: Op opcode_; - bool has_result_id_; + std::uint32_t id_; }; using DecorationAttr = std::variant>; @@ -238,30 +249,46 @@ def generate_visitor(f, grammar): print("""template struct overloaded : Ts... { using Ts::operator()...; }; -template overloaded(Ts...) -> overloaded; - - template auto visit(Visitor&& visitor, spv_inst const& inst) { - switch (inst.opcode()) {""", +template overloaded(Ts...) -> overloaded;""", file=f) - for instruction in grammar['instructions']: - print(f"""case Op::{get_opcode_name(instruction)}: - return visitor(static_cast<{get_class_name(instruction)} const&>(inst));""", - file=f) - print("""} - throw internal_compiler_error(); -}""", file=f) - print('template class default_visitor { public:', - file=f) - print('auto pre_visit(spv_inst const&) {}', file=f) + for const in ["", "const"]: + print('template ', file=f) + print( + f'auto visit(Visitor&& visitor, spv_inst {const}& inst) {{ switch(inst.opcode()) {{', + file=f) + for instruction in grammar['instructions']: + print(f"""case Op::{get_opcode_name(instruction)}: + return visitor(static_cast<{get_class_name(instruction)} {const}&>(inst));""", + file=f) + print("""} + throw internal_compiler_error(); +} +""", file=f) + + print( + 'template class default_visitor { public:', + file=f) + print( + 'template using const_t = std::conditional_t, T>;', + file=f) + print('auto pre_visit(const_t&) {}', file=f) + print('auto visit_result(const_t&) {}', file=f) + print('auto post_visit(const_t&) {}', file=f) for instruction in grammar['instructions']: print( - f"""auto operator()({get_class_name(instruction)} const& in) {{""", + f"""auto operator()(const_t<{get_class_name(instruction)}>& in) {{""", file=f) print(f'static_cast(this)->pre_visit(in);', file=f) - for o in get_operands(instruction): + operands = get_operands(instruction) + if len(operands) > 0 and operands[0].name == 'type': + print(format_call('in.type()'), file=f) + operands.pop(0) + if has_result_id(instruction): + print(f'static_cast(this)->visit_result(in);', file=f) + for o in operands: if o.quantifier == '*': - print(f"""for (auto const& op : in.{o.name}()) {{ + print(f"""for (auto& op : in.{o.name}()) {{ {format_call('op')} }} """, @@ -274,6 +301,7 @@ def generate_visitor(f, grammar): file=f) else: print(format_call(f'in.{o.name}()'), file=f) + print(f'static_cast(this)->post_visit(in);', file=f) print('}', file=f) print('};', file=f) @@ -334,7 +362,8 @@ def patch_grammar(grammar): grammar = filter_grammar(grammar, filt) grammar = patch_grammar(grammar) - generate_header(args, spv_enums, grammar, generate_enums) + generate_header(args, spv_enums, grammar, generate_enums, + spv_enums_includes) generate_header(args, spv_names, grammar, generate_names, spv_names_includes) generate_cpp(args, spv_names_cpp, grammar, generate_names_cpp, @@ -343,6 +372,7 @@ def patch_grammar(grammar): spv_defs_includes) generate_header(args, spv_ops, grammar, generate_op_classes, spv_ops_includes) - generate_header(args, spv_visitor, grammar, generate_visitor) + generate_header(args, spv_visitor, grammar, generate_visitor, + spv_visitor_includes) else: print(f'Could not find clang-format: {args.c}') From 61980b553053df2793e93a010983362ccf515d59 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 09:36:48 +0100 Subject: [PATCH 101/297] SPIR-V: Downgrade module version Signed-off-by: Carsten Uphoff --- src/spv/module.hpp | 2 +- src/spv/uniquifier.cpp | 4 ++-- tools/spirvgen/spirvgen.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/spv/module.hpp b/src/spv/module.hpp index 7955ef56..7bf88933 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -47,7 +47,7 @@ struct tinytc_spv_mod final : tinytc::reference_counted { using const_iterator = tinytc::ilist::const_iterator; tinytc_spv_mod(tinytc::compiler_context ctx, tinytc_core_feature_flags_t core_features, - std::int32_t major_version = 1, std::int32_t minor_version = 6); + std::int32_t major_version = 1, std::int32_t minor_version = 2); ~tinytc_spv_mod(); inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 397f57b8..7bfc59e6 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -232,8 +232,8 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { case scalar_type::f32: case scalar_type::f64: capability(Capability::Float64); - return mod_->add_to(section::type_const_var, size(ty.ty()) * 8, - std::nullopt); + return mod_->add_to(section::type_const_var, + size(ty.ty()) * 8); case scalar_type::c32: { auto float_ty = spv_ty(scalar_data_type::get(mod_->context(), scalar_type::f32)); diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 50ceaf71..4f5c4921 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -334,6 +334,21 @@ def patch_grammar(grammar): 'kind': 'MemoryAccessAttr', 'quantifier': '?' }) + + version = grammar["major_version"] * 256 + grammar["minor_version"] + # Old grammar files have duplicate enumerants that need to be filtered + if version < 1 * 256 + 6: + for opkind in grammar['operand_kinds']: + category = opkind['category'] + if category != 'BitEnum' and category != 'ValueEnum': + continue + available_values = set() + new_enumerants = list() + for enumerant in opkind['enumerants']: + if enumerant['value'] not in available_values: + new_enumerants.append(enumerant) + available_values.add(enumerant['value']) + opkind['enumerants'] = new_enumerants return grammar From 7cb9f83bad29b4ca410ea5fb86899858bad66a51 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 12:07:12 +0100 Subject: [PATCH 102/297] SPIR-V: alloca Signed-off-by: Carsten Uphoff --- include/tinytc/types.h | 1 + include/tinytc/types.hpp | 1 + src/CMakeLists.txt | 1 + src/analysis/stack.cpp | 30 +++++++++++++++ src/analysis/stack.hpp | 18 +++++++++ src/error.cpp | 2 + src/node/data_type_node.cpp | 15 ++++++++ src/node/data_type_node.hpp | 6 +-- src/pass/stack.cpp | 3 +- src/spv/converter.cpp | 73 ++++++++++++++++++++++++++++++++++--- src/spv/converter.hpp | 6 ++- src/spv/module.hpp | 2 +- src/spv/uniquifier.cpp | 20 ++++++---- src/spv/uniquifier.hpp | 9 +++++ test/CMakeLists.txt | 1 + test/spv/alloca.ir | 45 +++++++++++++++++++++++ 16 files changed, 214 insertions(+), 19 deletions(-) create mode 100644 src/analysis/stack.cpp create mode 100644 src/analysis/stack.hpp create mode 100644 test/spv/alloca.ir diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 18592549..fb1ffe4a 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -86,6 +86,7 @@ typedef enum { tinytc_status_ir_unsupported_coopmatrix_shape = 0x125, ///< Unsupported coopmatrix shape tinytc_status_ir_incompatible_scalar_types = 0x126, ///< Incompatible scalar types tinytc_status_ir_constant_mismatch = 0x127, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x128, ///< Insufficient alignment // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 49a55508..31f20d85 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -94,6 +94,7 @@ enum class status { ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, ir_constant_mismatch = tinytc_status_ir_constant_mismatch, + ir_insufficient_alignment = tinytc_status_ir_insufficient_alignment, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 20bdf755..ad33dadb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,6 +21,7 @@ set(SOURCES analysis/alias.cpp analysis/cfg.cpp analysis/equal.cpp + analysis/stack.cpp binary.cpp codegen_tools.cpp compiler.cpp diff --git a/src/analysis/stack.cpp b/src/analysis/stack.cpp new file mode 100644 index 00000000..4ffa31fa --- /dev/null +++ b/src/analysis/stack.cpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/stack.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "support/casting.hpp" +#include "support/walk.hpp" + +#include + +namespace tinytc { + +auto stack_high_water_mark::run_on_function(function_node const &fn) -> std::int64_t { + std::int64_t high_water_mark = 0; + + walk(fn, [&high_water_mark](inst_node const &i) { + if (auto *a = dyn_cast(&i); a) { + auto t = dyn_cast(a->result(0).ty()); + if (t == nullptr) { + throw compilation_error(a->loc(), status::ir_expected_memref); + } + high_water_mark = std::max(high_water_mark, a->stack_ptr() + t->size_in_bytes()); + } + }); + + return high_water_mark; +} + +} // namespace tinytc diff --git a/src/analysis/stack.hpp b/src/analysis/stack.hpp new file mode 100644 index 00000000..1bf52c53 --- /dev/null +++ b/src/analysis/stack.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef STACK_20241112_HPP +#define STACK_20241112_HPP + +#include "node/function_node.hpp" + +namespace tinytc { + +class stack_high_water_mark { + public: + auto run_on_function(function_node const &fn) -> std::int64_t; +}; + +} // namespace tinytc + +#endif // STACK_20241112_HPP diff --git a/src/error.cpp b/src/error.cpp index 9df6d52f..9c4e0571 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -202,6 +202,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Scalar types violate compatibility rules"; case tinytc_status_ir_constant_mismatch: return "Type of constant does not match type of returned value"; + case tinytc_status_ir_insufficient_alignment: + return "Pointer does not satisfy minimum alignment requirements"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index bdc8d28e..6dff7151 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -4,6 +4,7 @@ #include "node/data_type_node.hpp" #include "compiler_context_cache.hpp" #include "error.hpp" +#include "scalar_type.hpp" #include "support/casting.hpp" #include "support/fnv1a.hpp" #include "support/fnv1a_array_view.hpp" @@ -121,6 +122,20 @@ scalar_type memref_data_type::element_ty() const { return dyn_cast(element_ty_)->ty(); } +auto memref_data_type::alignment() const -> std::int32_t { + return ::tinytc::alignment(element_ty()); +} +auto memref_data_type::size_in_bytes() const -> std::int64_t { + if (is_dynamic()) { + return dynamic; + } + std::size_t s = size(element_ty()); + if (dim() > 0) { + s *= stride_.back() * shape_.back(); + } + return s; +} + auto memref_data_type::get(tinytc_data_type_t element_ty, array_view shape, array_view stride, address_space addrspace, location const &lc) -> tinytc_data_type_t { diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 024aadb9..94d12cd9 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -124,9 +124,6 @@ class memref_data_type : public data_type_node { inline std::int64_t shape(std::int64_t i) const { return shape_[i]; } inline auto const &stride() const { return stride_; } inline std::int64_t stride(std::int64_t i) const { return stride_[i]; } - inline std::int64_t size_in_bytes() const { - return is_dynamic() ? dynamic : size(element_ty()) * stride_.back() * shape_.back(); - } inline auto addrspace() const -> address_space { return addrspace_; } inline void addrspace(address_space space) { addrspace_ = space; } @@ -139,6 +136,9 @@ class memref_data_type : public data_type_node { inline bool is_dynamic() const { return is_dynamic_shape() || is_dynamic_stride(); } inline bool is_canonical_stride() const { return stride_ == canonical_stride(shape_); } + auto alignment() const -> std::int32_t; + auto size_in_bytes() const -> std::int64_t; + protected: memref_data_type(tinytc_compiler_context_t ctx, tinytc_data_type_t element_ty, std::vector shape, std::vector stride, diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp index a8bfdff8..3c134562 100644 --- a/src/pass/stack.cpp +++ b/src/pass/stack.cpp @@ -31,6 +31,7 @@ void set_stack_ptr_pass::run_on_function(function_node &fn) { if (t == nullptr) { throw compilation_error(a.loc(), status::ir_expected_memref); } + const auto align = t->alignment(); auto size = t->size_in_bytes(); std::int64_t stack_ptr = 0; auto it = allocs.begin(); @@ -38,7 +39,7 @@ void set_stack_ptr_pass::run_on_function(function_node &fn) { if (it->start - stack_ptr >= size) { break; } - stack_ptr = it->stop; + stack_ptr = (1 + (it->stop - 1) / align) * align; } allocs.insert(it, allocation{a.result(), stack_ptr, stack_ptr + size}); a.stack_ptr(stack_ptr); diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 75fb58b8..2756d92f 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "spv/converter.hpp" +#include "analysis/stack.hpp" #include "error.hpp" #include "node/data_type_node.hpp" #include "node/function_node.hpp" @@ -132,6 +133,14 @@ auto inst_converter::get_dope_vector(tinytc_value const &v) -> dope_vector * { return nullptr; } +auto inst_converter::get_memref_type(tinytc_value const &v) const -> memref_data_type const * { + auto mt = dyn_cast(v.ty()); + if (!mt) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return mt; +} + auto inst_converter::get_scalar_type(value_node const &v) const -> scalar_type { auto st = dyn_cast(v.ty()); if (!st) { @@ -149,10 +158,9 @@ auto inst_converter::get_coopmatrix_type(value_node const &v) const -> scalar_ty auto inst_converter::load_builtin(BuiltIn b) -> spv_inst * { auto builtin = unique_.builtin_var(b); - if (auto it = std::find(builtins_used_by_function_.begin(), builtins_used_by_function_.end(), - builtin); - it == builtins_used_by_function_.end()) { - builtins_used_by_function_.push_back(builtin); + if (auto it = std::find(vars_used_by_function_.begin(), vars_used_by_function_.end(), builtin); + it == vars_used_by_function_.end()) { + vars_used_by_function_.push_back(builtin); } return mod_->add(unique_.builtin_pointee_ty(b), builtin, MemoryAccess::Aligned, unique_.builtin_alignment(b)); @@ -344,6 +352,40 @@ void inst_converter::operator()(inst_node const &in) { throw compilation_error(in.loc(), status::not_implemented); } +void inst_converter::operator()(alloca_inst const &in) { + if (in.stack_ptr() < 0) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); + } + if (!stack_) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Stack required but not allocated"); + } + + auto mt = get_memref_type(in.result(0)); + if (in.stack_ptr() % mt->alignment() != 0) { + throw compilation_error(in.loc(), status::ir_insufficient_alignment); + } + + auto stack_element_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i8)); + auto stack_ptr_ty = unique_.spv_pointer_ty(StorageClass::Workgroup, stack_element_ty, + alignment(scalar_type::i8)); + auto stack_ptr = mod_->add( + stack_ptr_ty, stack_, std::vector{unique_.constant(in.stack_ptr())}); + + auto memref_ptr_ty = unique_.spv_ty(mt); + declare(in.result(0), mod_->add(memref_ptr_ty, stack_ptr)); + + // alloca only accepts fixed-size memrefs => dope vector is constant + auto rdv = make_dope_vector(in.result(0)); + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->shape(i, unique_.constant(mt->shape(i))); + } + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->stride(i, unique_.constant(mt->stride(i))); + } +} + void inst_converter::operator()(arith_inst const &in) { auto const make_boolean = [&](arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { @@ -968,6 +1010,8 @@ void inst_converter::operator()(if_inst const &in) { } } +void inst_converter::operator()(lifetime_stop_inst const &) {} + void inst_converter::operator()(load_inst const &in) { auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); auto spv_pointer_index_ty = unique_.spv_pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, @@ -1148,6 +1192,25 @@ auto inst_converter::run_on_region_with_yield(region_node const ®, void inst_converter::run_on_function(function_node const &fn, core_config const &core_cfg) { core_cfg_ = core_cfg; + // Stack + auto const make_stack = [&] { + const auto high_water_mark = stack_high_water_mark{}.run_on_function(fn); + if (high_water_mark > 0) { + auto stack_element_ty = + unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i8)); + auto stack_array_ty = unique_.spv_array_ty(stack_element_ty, high_water_mark); + auto stack_ptr_ty = + unique_.spv_pointer_ty(StorageClass::Workgroup, stack_array_ty, + alignment(scalar_type::f64, component_count::v2)); + stack_ = mod_->add_to(section::type_const_var, stack_ptr_ty, + StorageClass::Workgroup); + vars_used_by_function_.emplace_back(stack_); + } else { + stack_ = nullptr; + } + }; + make_stack(); + // Function type auto fun_ty = unique_.spv_function_ty([&] { auto params = std::vector{}; @@ -1196,7 +1259,7 @@ void inst_converter::run_on_function(function_node const &fn, core_config const // Entry point mod_->add_to(section::entry_point, ExecutionModel::Kernel, fun, - std::string{fn.name()}, std::move(builtins_used_by_function_)); + std::string{fn.name()}, std::move(vars_used_by_function_)); // Execution mode auto const work_group_size = fn.work_group_size(); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 7710baca..e3d60500 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -65,6 +65,7 @@ class inst_converter { // Instruction nodes void operator()(inst_node const &in); + void operator()(alloca_inst const &in); void operator()(arith_inst const &in); void operator()(arith_unary_inst const &in); void operator()(barrier_inst const &in); @@ -75,6 +76,7 @@ class inst_converter { void operator()(group_id_inst const &in); void operator()(group_size_inst const &in); void operator()(if_inst const &in); + void operator()(lifetime_stop_inst const &in); void operator()(load_inst const &in); void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); @@ -108,6 +110,7 @@ class inst_converter { } auto get_last_label() -> spv_inst *; auto get_dope_vector(tinytc_value const &v) -> dope_vector *; + auto get_memref_type(tinytc_value const &v) const -> memref_data_type const *; auto get_scalar_type(tinytc_value const &v) const -> scalar_type; auto get_coopmatrix_type(tinytc_value const &v) const -> scalar_type; auto load_builtin(BuiltIn b) -> spv_inst *; @@ -127,7 +130,8 @@ class inst_converter { std::unordered_map vals_; std::unordered_map> multi_vals_; std::stack> yielded_vals_; - std::vector builtins_used_by_function_; + std::vector vars_used_by_function_; + spv_inst *stack_ = nullptr; core_config core_cfg_ = {}; }; diff --git a/src/spv/module.hpp b/src/spv/module.hpp index 7bf88933..16a452a7 100644 --- a/src/spv/module.hpp +++ b/src/spv/module.hpp @@ -47,7 +47,7 @@ struct tinytc_spv_mod final : tinytc::reference_counted { using const_iterator = tinytc::ilist::const_iterator; tinytc_spv_mod(tinytc::compiler_context ctx, tinytc_core_feature_flags_t core_features, - std::int32_t major_version = 1, std::int32_t minor_version = 2); + std::int32_t major_version = 1, std::int32_t minor_version = 4); ~tinytc_spv_mod(); inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 7bfc59e6..48a98d8b 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -154,6 +154,13 @@ auto uniquifier::opencl_ext() -> spv_inst * { [&] { return mod_->add_to(section::ext_inst, OpenCLExt); }); } +auto uniquifier::spv_array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst * { + auto key = std::make_pair(element_ty, length); + return lookup(spv_array_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, constant(key.second)); + }); +} + auto uniquifier::spv_function_ty(array_view params) -> spv_inst * { const auto map_key = fnv1a_step(fnv1a0(), params); auto range = spv_function_tys_.equal_range(map_key); @@ -177,8 +184,10 @@ auto uniquifier::spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty, spv_pointer_tys_, key, [&](std::tuple const &key) { auto pointer_ty = mod_->add_to(section::type_const_var, std::get<0>(key), std::get<1>(key)); - mod_->add_to(section::decoration, pointer_ty, Decoration::Alignment, - DecorationAttr{std::get<2>(key)}); + if (std::get<2>(key) > 0) { + mod_->add_to(section::decoration, pointer_ty, Decoration::Alignment, + DecorationAttr{std::get<2>(key)}); + } return pointer_ty; }); } @@ -200,12 +209,7 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { [&](memref_data_type const &mr) -> spv_inst * { const auto storage_cls = address_space_to_storage_class(mr.addrspace()); auto spv_element_ty = spv_ty(mr.element_data_ty()); - const std::int32_t align = [&](scalar_type sty) -> std::int32_t { - if (is_complex_type(sty)) { - return alignment(element_type(sty), component_count::v2); - } - return alignment(sty); - }(mr.element_ty()); + const auto align = mr.alignment(); return spv_pointer_ty(storage_cls, spv_element_ty, align); }, [&](scalar_data_type const &ty) -> spv_inst * { diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index bc1521a9..2cc72725 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -40,6 +40,7 @@ class uniquifier { auto index3_ty() -> spv_inst *; auto null_constant(spv_inst *spv_ty) -> spv_inst *; auto opencl_ext() -> spv_inst *; + auto spv_array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst *; auto spv_function_ty(array_view params) -> spv_inst *; auto spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) -> spv_inst *; @@ -62,6 +63,12 @@ class uniquifier { return var; } + struct array_key_hash { + inline auto + operator()(std::pair const &key) const -> std::size_t { + return fnv1a_combine(key.first, key.second); + } + }; struct pointer_key_hash { inline auto operator()(std::tuple const &key) const -> std::size_t { @@ -79,6 +86,8 @@ class uniquifier { std::unordered_map cst_map_; std::unordered_set extensions_; std::unordered_map null_cst_; + std::unordered_map, spv_inst *, array_key_hash> + spv_array_tys_; std::unordered_multimap spv_function_tys_; std::unordered_map, spv_inst *, pointer_key_hash> diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d18065ac..4eb12010 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -51,6 +51,7 @@ add_dependencies(lit-check tinytc-oc tinytc-opt) if(SPIRVTools_FOUND) set(SPIRV_VAL_SOURCES + spv/alloca.ir spv/arith.ir spv/arith_unary.ir spv/barrier.ir diff --git a/test/spv/alloca.ir b/test/spv/alloca.ir new file mode 100644 index 00000000..d48f54a5 --- /dev/null +++ b/test/spv/alloca.ir @@ -0,0 +1,45 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s + +; CHECK: OpEntryPoint Kernel %[[#]] "alloca" %[[#STACK_VAR:]] +; CHECK: OpDecorate %[[#STACK_PTR_TY:]] Alignment 16 +; CHECK: OpDecorate %[[#I8_PTR:]] Alignment 1 +; CHECK: OpDecorate %[[#F32_PTR:]] Alignment 4 +; CHECK: OpDecorate %[[#I16_PTR:]] Alignment 2 +; CHECK: %[[#I8:]] = OpTypeInt 8 0 +; CHECK: %[[#STACK_SIZE:]] = OpConstant %[[#]] 264 +; CHECK: %[[#STACK_ARRAY_TY:]] = OpTypeArray %[[#I8]] %[[#STACK_SIZE]] +; CHECK: %[[#STACK_PTR_TY]] = OpTypePointer Workgroup %[[#STACK_ARRAY_TY]] +; CHECK: %[[#STACK_VAR]] = OpVariable %[[#STACK_PTR_TY]] Workgroup +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#I8_PTR]] = OpTypePointer Workgroup %[[I8]] +; CHECK: %[[#I64_C8:]] = OpConstant %[[#I64]] 8 +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F32_PTR]] = OpTypePointer Workgroup %[[F32]] +; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 +; CHECK: %[[#I64_C6:]] = OpConstant %[[#I64]] 6 +; CHECK: %[[#I16:]] = OpTypeInt 16 0 +; CHECK: %[[#I16_PTR]] = OpTypePointer Workgroup %[[I16]] + +func @alloca() { + %c0 = constant 0 -> index + %0 = alloca -> memref + %1 = alloca -> memref + %2 = alloca -> memref + %3 = load %0[%c0] : memref + %4 = load %1[%c0,%c0] : memref + %5 = load %2[] : memref + %6 = size %1[1] : memref + %7 = arith.not %6 : index +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#STACK_PTR_I8:]] = OpInBoundsAccessChain %[[#I8_PTR]] %[[#STACK_VAR]] %[[#I64_C0]] +; CHECK-NEXT: %[[#]] = OpBitcast %[[#I8_PTR]] %[[#STACK_PTR_I8]] +; CHECK-NEXT: %[[#STACK_PTR_F32:]] = OpInBoundsAccessChain %[[#I8_PTR]] %[[#STACK_VAR]] %[[#I64_C8]] +; CHECK-NEXT: %[[#]] = OpBitcast %[[#F32_PTR]] %[[#STACK_PTR_F32]] +; CHECK-NEXT: %[[#STACK_PTR_I16:]] = OpInBoundsAccessChain %[[#I8_PTR]] %[[#STACK_VAR]] %[[#I64_C6]] +; CHECK-NEXT: %[[#]] = OpBitcast %[[#I16_PTR]] %[[#STACK_PTR_I16]] +; CHECK: %[[#]] = OpNot %[[#I64]] %[[I64_C4]] +} From 6fa6a6ef4fee63e8fef049b6a0ef41bf6ead3067 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 14:12:07 +0100 Subject: [PATCH 103/297] SPIR-V: subview Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 67 +++++++++++++++++++++++++++++++++++++++++++ src/spv/converter.hpp | 1 + test/CMakeLists.txt | 1 + test/spv/subview.ir | 39 +++++++++++++++++++++++++ 4 files changed, 108 insertions(+) create mode 100644 test/spv/subview.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 2756d92f..6a67fa65 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -1118,6 +1118,73 @@ void inst_converter::operator()(subgroup_size_inst const &in) { declare(in.result(0), load_builtin(BuiltIn::SubgroupSize)); } +void inst_converter::operator()(subview_inst const &in) { + + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); + auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); + + auto shape_out = std::vector{}; + auto stride_out = std::vector{}; + auto const make_offset_and_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + int j = 0; + shape_out.reserve(mt->dim()); + stride_out.reserve(mt->dim()); + auto dyn_offsets = in.offsets(); + auto dyn_sizes = in.sizes(); + auto offset_acc = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < mt->dim(); ++i) { + const std::int64_t offset = in.static_offsets()[i]; + + auto const offset_inst = [&]() -> spv_inst * { + if (is_dynamic_value(offset)) { + return val(dyn_offsets[joffset++]); + } + return unique_.constant(offset); + }; + auto tmp = mod_->add(spv_index_ty, offset_inst(), dv->stride(j)); + offset_acc = mod_->add(spv_index_ty, offset_acc, tmp); + + const std::int64_t size = in.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto const size_inst = [&]() -> spv_inst * { + if (is_dynamic_value(size)) { + return val(dyn_sizes[jsize++]); + } + return unique_.constant(size); + }; + shape_out.emplace_back(size_inst()); + stride_out.emplace_back(dv->stride(j)); + } + ++j; + } + return offset_acc; + }; + + auto offset = make_offset_and_shape_stride(); + declare(in.result(0), mod_->add(spv_result_ty, val(in.operand()), + offset, std::vector{})); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape_out.size() != static_cast(rdv->dim()) || + stride_out.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape_out[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride_out[i]); + } +} + void inst_converter::operator()(work_group_inst const &in) { auto const make = [&](scalar_type sty, work_group_operation operation, spv_inst *spv_ty, spv_inst *operand) -> spv_inst * { diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index e3d60500..cc915688 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -85,6 +85,7 @@ class inst_converter { void operator()(subgroup_id_inst const &in); void operator()(subgroup_local_id_inst const &in); void operator()(subgroup_size_inst const &in); + void operator()(subview_inst const &in); void operator()(work_group_inst const &in); void operator()(yield_inst const &in); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4eb12010..0b24af29 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -64,6 +64,7 @@ if(SPIRVTools_FOUND) spv/load.ir spv/size.ir spv/store.ir + spv/subview.ir spv/work_group.ir spv/unique_function_type.ir ) diff --git a/test/spv/subview.ir b/test/spv/subview.ir new file mode 100644 index 00000000..53bf80d7 --- /dev/null +++ b/test/spv/subview.ir @@ -0,0 +1,39 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 < %s | filecheck %s + +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F32_PTR:]] = OpTypePointer CrossWorkgroup %[[#F32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#INDEX_NULL:]] = OpConstantNull %[[#I64]] +; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 + +func @sv1(%K0: memref, %offset: index, %size: index) { + %0 = subview %K0[4:%size, %offset] : memref + %1 = subview %K0[%offset, 4:%size] : memref + %2 = size %0[0] : memref + %3 = arith.not %2 : index + %4 = size %1[0] : memref> + %5 = arith.not %4 : index +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#P_MR:]] = OpFunctionParameter %[[#F32_PTR]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#P_STRIDE1:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#P_OFFSET:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#P_SIZE:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#TMP1:]] = OpIMul %[[#I64]] %[[#I64_C4]] %[[I64_C1]] +; CHECK-NEXT: %[[#OFFSET1:]] = OpIAdd %[[#I64]] %[[#INDEX_NULL]] %[[#TMP1]] +; CHECK-NEXT: %[[#TMP2:]] = OpIMul %[[#I64]] %[[#P_OFFSET]] %[[#P_STRIDE1]] +; CHECK-NEXT: %[[#OFFSET2:]] = OpIAdd %[[#I64]] %[[#OFFSET1]] %[[#TMP2]] +; CHECK-NEXT: %[[#]] = OpInBoundsPtrAccessChain %[[#F32_PTR]] %[[#P_MR]] %[[#OFFSET2]] +; CHECK-NEXT: %[[#TMP3:]] = OpIMul %[[#I64]] %[[#P_OFFSET]] %[[I64_C1]] +; CHECK-NEXT: %[[#OFFSET3:]] = OpIAdd %[[#I64]] %[[#INDEX_NULL]] %[[#TMP3]] +; CHECK-NEXT: %[[#TMP4:]] = OpIMul %[[#I64]] %[[#I64_C4]] %[[#P_STRIDE1]] +; CHECK-NEXT: %[[#OFFSET4:]] = OpIAdd %[[#I64]] %[[#OFFSET3]] %[[#TMP4]] +; CHECK-NEXT: %[[#]] = OpInBoundsPtrAccessChain %[[#F32_PTR]] %[[#P_MR]] %[[#OFFSET4]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#P_SIZE]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#P_SIZE]] +} From dd3888bc8148477b98aa1694d46874cfe3d3fd2f Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 14:14:42 +0100 Subject: [PATCH 104/297] Fix test Signed-off-by: Carsten Uphoff --- test/spv/subview.ir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spv/subview.ir b/test/spv/subview.ir index 53bf80d7..6284e0c5 100644 --- a/test/spv/subview.ir +++ b/test/spv/subview.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -O0 < %s | filecheck %s +; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F32_PTR:]] = OpTypePointer CrossWorkgroup %[[#F32]] From 07b86c01a9bae7b1212957b2a15746dd79b83048 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 14:44:13 +0100 Subject: [PATCH 105/297] SPIR-V: fuse Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 48 +++++++++++++++++++++++++++++++++++++++++-- src/spv/converter.hpp | 1 + test/CMakeLists.txt | 1 + test/spv/fuse.ir | 34 ++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 test/spv/fuse.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 6a67fa65..95356b5d 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -944,6 +944,52 @@ void inst_converter::operator()(for_inst const &in) { set_results(); } +void inst_converter::operator()(fuse_inst const &in) { + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + shape.reserve(mt->dim()); + stride.reserve(mt->dim()); + std::int64_t i = 0; + for (; i < in.from(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + spv_inst *prod = dv->shape(i++); + for (; i <= in.to(); ++i) { + prod = mod_->add(spv_index_ty, prod, dv->shape(i)); + } + shape.push_back(prod); + stride.push_back(dv->stride(in.from())); + for (i = in.to() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(0), val(in.operand())); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + void inst_converter::operator()(group_id_inst const &in) { auto gid = load_builtin(BuiltIn::GlobalInvocationId); auto index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); @@ -1119,9 +1165,7 @@ void inst_converter::operator()(subgroup_size_inst const &in) { } void inst_converter::operator()(subview_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); - auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); auto shape_out = std::vector{}; diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index cc915688..848d0267 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -73,6 +73,7 @@ class inst_converter { void operator()(compare_inst const &in); void operator()(constant_inst const &in); void operator()(for_inst const &in); + void operator()(fuse_inst const &in); void operator()(group_id_inst const &in); void operator()(group_size_inst const &in); void operator()(if_inst const &in); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0b24af29..04ab8230 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,6 +60,7 @@ if(SPIRVTools_FOUND) spv/calling_convention.ir spv/compare.ir spv/for.ir + spv/fuse.ir spv/if.ir spv/load.ir spv/size.ir diff --git a/test/spv/fuse.ir b/test/spv/fuse.ir new file mode 100644 index 00000000..a5118136 --- /dev/null +++ b/test/spv/fuse.ir @@ -0,0 +1,34 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s + +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F32_PTR:]] = OpTypePointer CrossWorkgroup %[[#F32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C32:]] = OpConstant %[[#I64]] 32 +; CHECK: %[[#I64_C16:]] = OpConstant %[[#I64]] 16 +; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 +; CHECK: %[[#I64_C42:]] = OpConstant %[[#I64]] 42 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 + +func @f1(%0: memref) { + %z = constant 0 -> index + %1 = fuse %0[1,3] : memref + %2 = size %1[0] : memref> + %3 = size %1[1] : memref> + %4 = size %1[2] : memref> + %5 = arith.not %2 : index + %6 = arith.not %3 : index + %7 = arith.not %4 : index +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#P_MR:]] = OpFunctionParameter %[[#F32_PTR]] +; CHECK-NEXT: %[[#P_SHAPE2:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#PROD1:]] = OpIMul %[[#I64]] %[[#I64_C16]] %[[#P_SHAPE2]] +; CHECK-NEXT: %[[#PROD2:]] = OpIMul %[[#I64]] %[[#PROD1]] %[[#I64_C4]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C32]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#PROD2]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C42]] +} From 6dbbad0239eee1ec4f842f5b434987a5b9dca32a Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 15:27:14 +0100 Subject: [PATCH 106/297] SPIR-V: expand Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 57 ++++++++++++++ src/spv/converter.hpp | 1 + test/CMakeLists.txt | 1 + test/opt/check-ir/expand.ir | 74 +++++++++++++++++++ .../check-ir/subview.ir} | 15 ++-- test/spv/expand.ir | 45 +++++++++++ 6 files changed, 184 insertions(+), 9 deletions(-) create mode 100644 test/opt/check-ir/expand.ir rename test/{codegen/subview_return_type.ir => opt/check-ir/subview.ir} (86%) create mode 100644 test/spv/expand.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 95356b5d..51963ea3 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -825,6 +825,63 @@ void inst_converter::operator()(constant_inst const &in) { } } +void inst_converter::operator()(expand_inst const &in) { + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + auto static_shape = in.static_expand_shape(); + auto dyn_shape = in.expand_shape(); + + shape.reserve(mt->dim() + static_shape.size() - 1); + stride.reserve(mt->dim() + static_shape.size() - 1); + + for (std::int64_t i = 0; i < in.expanded_mode(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + + auto get_shape = [&, j = std::size_t{0}](std::int64_t s) mutable { + if (is_dynamic_value(s)) { + return val(dyn_shape[j++]); + } + return unique_.constant(s); + }; + stride.push_back(dv->stride(in.expanded_mode())); + shape.push_back(get_shape(static_shape[0])); + for (std::size_t j = 1; j < static_shape.size(); ++j) { + stride.push_back(mod_->add(spv_index_ty, stride.back(), shape.back())); + shape.push_back(get_shape(static_shape[j])); + } + + for (std::int64_t i = in.expanded_mode() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(0), val(in.operand())); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + void inst_converter::operator()(for_inst const &in) { const std::int64_t num_results = num_yielded_vals(in.result_begin(), in.result_end()); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 848d0267..93a14318 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -72,6 +72,7 @@ class inst_converter { void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); + void operator()(expand_inst const &in); void operator()(for_inst const &in); void operator()(fuse_inst const &in); void operator()(group_id_inst const &in); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 04ab8230..a0b944d8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -59,6 +59,7 @@ if(SPIRVTools_FOUND) spv/cast.ir spv/calling_convention.ir spv/compare.ir + spv/expand.ir spv/for.ir spv/fuse.ir spv/if.ir diff --git a/test/opt/check-ir/expand.ir b/test/opt/check-ir/expand.ir new file mode 100644 index 00000000..1bc54c69 --- /dev/null +++ b/test/opt/check-ir/expand.ir @@ -0,0 +1,74 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt -pcheck-ir -O0 < %s | filecheck %s + +; No real checks needed, just check that it does not crash, that is, +; the types put in load match those returned by expand +; CHECK: func @t1({{.*}} + +func @t1(%0: memref) { + %z = constant 0 -> index + %1 = expand %0[1->2x8] : memref + %2 = load %1[%z,%z,%z,%z] : memref +} +func @t2(%0: memref) { + %z = constant 0 -> index + %1 = expand %0[1->2x2x2x2] : memref + %2 = load %1[%z,%z,%z,%z,%z,%z] : memref +} +func @t3(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +} +func @t4(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->2 x %1] : memref + %3 = load %2[%z,%z,%z] : memref +} +func @t5(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +} +func @t6(%0: memref, %1: index) { + %z = constant 0 -> index + %2 = expand %0[1->%1 x 2] : memref + %3 = load %2[%z,%z,%z] : memref +} +func @t7(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2 x 2] : memref + %4 = load %3[%z,%z,%z,%z] : memref +} +func @t8(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%2 x 2 x %1] : memref + %4 = load %3[%z,%z,%z,%z] : memref +} +func @t9(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2] : memref + %4 = load %3[%z,%z,%z] : memref +} +func @t10(%0: memref, %1: index, %2: index) { + %z = constant 0 -> index + %3 = expand %0[1->%1 x %2] : memref + %4 = load %3[%z,%z,%z] : memref +} +func @t11(%0: memref>) { + %z = constant 0 -> index + %1 = expand %0[0->4 x 8] : memref> + %2 = load %1[%z,%z,%z] : memref> +} +func @t12(%0: memref>, %1: index) { + %z = constant 0 -> index + %2 = expand %0[0->%1 x 4] : memref> + %3 = load %2[%z,%z,%z] : memref> +} +func @t13(%0: memref>, %1: index) { + %z = constant 0 -> index + %2 = expand %0[0->4 x %1] : memref> + %3 = load %2[%z,%z,%z] : memref> +} diff --git a/test/codegen/subview_return_type.ir b/test/opt/check-ir/subview.ir similarity index 86% rename from test/codegen/subview_return_type.ir rename to test/opt/check-ir/subview.ir index a91ac1b4..709d8d02 100644 --- a/test/codegen/subview_return_type.ir +++ b/test/opt/check-ir/subview.ir @@ -1,51 +1,48 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc < %s | filecheck %s +; RUN: %tinytc-opt -pcheck-ir -O0 < %s | filecheck %s + +; No real checks needed, just check that it does not crash, that is, +; the types put in load match those returned by expand +; CHECK: func @t1({{.*}} + func @t1(%0: memref) { -; CHECK-LABEL: void t1( %z = constant 0 -> index %1 = subview %0[4:8,8:4] : memref %2 = load %1[%z,%z] : memref> } func @t2(%0: memref, %1: index) { -; CHECK-LABEL: void t2( %z = constant 0 -> index %2 = subview %0[2:4,%1] : memref %3 = load %2[%z] : memref } func @t3(%0: memref, %1: index) { -; CHECK-LABEL: void t3( %z = constant 0 -> index %2 = subview %0[2:4,%1:0] : memref %3 = load %2[%z] : memref } func @t4(%0: memref, %1: index) { -; CHECK-LABEL: void t4( %z = constant 0 -> index %2 = subview %0[2:4,%1:1] : memref %3 = load %2[%z,%z] : memref> } func @t5(%0: memref, %1: index) { -; CHECK-LABEL: void t5( %z = constant 0 -> index %2 = subview %0[%1:4] : memref %3 = load %2[%z] : memref } func @t6(%0: memref, %1: index) { -; CHECK-LABEL: void t6( %z = constant 0 -> index %2 = subview %0[%1:%1] : memref %3 = load %2[%z] : memref } func @t7(%0: memref, %1: index) { -; CHECK-LABEL: void t7( %z = constant 0 -> index %2 = subview %0[2:4, %1:%1, 6:7] : memref %3 = load %2[%z,%z,%z] : memref> } func @t8(%0: memref>, %1: index) { -; CHECK-LABEL: void t8( %z = constant 0 -> index %2 = subview %0[2:4, %1:%1, 6:7] : memref> %3 = load %2[%z,%z,%z] : memref> diff --git a/test/spv/expand.ir b/test/spv/expand.ir new file mode 100644 index 00000000..e799d2fa --- /dev/null +++ b/test/spv/expand.ir @@ -0,0 +1,45 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s + +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C32:]] = OpConstant %[[#I64]] 32 +; CHECK: %[[#I64_C7:]] = OpConstant %[[#I64]] 7 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 +; CHECK: %[[#I64_C5:]] = OpConstant %[[#I64]] 5 + +func @f1(%0: memref, %1: index) { + %c0 = constant 0 -> index + %2 = expand %0[1->4x%1x5] : memref + %3 = size %2[0] : memref + %4 = size %2[1] : memref + %5 = size %2[2] : memref + %6 = size %2[3] : memref + %7 = size %2[4] : memref + %8 = arith.not %3 : index + %9 = arith.not %4 : index + %10 = arith.not %5 : index + %11 = arith.not %6 : index + %12 = arith.not %7 : index + %13 = load %2[%c0,%c0,%c0,%c0,%c0] : memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#P_MR:]] = OpFunctionParameter %[[#]] +; CHECK-NEXT: %[[#P_SHAPE1:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#P_STRIDE2:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#P_INDEX:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#ESTRIDE2:]] = OpIMul %[[#I64]] %[[#I64_C32]] %[[#I64_C4]] +; CHECK: %[[#ESTRIDE3:]] = OpIMul %[[#I64]] %[[#ESTRIDE2]] %[[#P_INDEX]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C32]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C4]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#P_INDEX]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C5]] +; CHECK-NEXT: %[[#]] = OpNot %[[#I64]] %[[#I64_C7]] +; CHECK: %[[#]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#I64_C1]] +; CHECK: %[[#]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#I64_C32]] +; CHECK: %[[#]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#ESTRIDE2]] +; CHECK: %[[#]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#ESTRIDE3]] +; CHECK: %[[#]] = OpIMul %[[#I64]] %[[#I64_C0]] %[[#P_STRIDE2]] +} From 48b0e2de7615d7ba7c7bce9f8ee12c418d3da467 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 12 Nov 2024 15:55:24 +0100 Subject: [PATCH 107/297] SPIR-V: cooperative matrix scale Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 144 +++++++++++++++------------ src/spv/converter.hpp | 6 ++ src/spv/pass/dump_asm.cpp | 6 +- test/CMakeLists.txt | 1 + test/spv/cooperative_matrix_scale.ir | 18 ++++ test/spv/if.ir | 4 +- test/spv/store.ir | 6 +- 7 files changed, 117 insertions(+), 68 deletions(-) create mode 100644 test/spv/cooperative_matrix_scale.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 51963ea3..f0359d36 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -183,6 +183,67 @@ auto inst_converter::multi_val(value_node const &v) -> std::vector & throw compilation_error(v.loc(), status::spirv_undefined_value); } +auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b, location const &loc) -> spv_inst * { + auto const make_int = [&](arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod_->add(ty, a, b); + case arithmetic::sub: + return mod_->add(ty, a, b); + case arithmetic::mul: + return mod_->add(ty, a, b); + case arithmetic::div: + return mod_->add(ty, a, b); + case arithmetic::rem: + return mod_->add(ty, a, b); + case arithmetic::shl: + return mod_->add(ty, a, b); + case arithmetic::shr: + return mod_->add(ty, a, b); + case arithmetic::and_: + return mod_->add(ty, a, b); + case arithmetic::or_: + return mod_->add(ty, a, b); + case arithmetic::xor_: + return mod_->add(ty, a, b); + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_float_complex = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod_->add(ty, a, b); + case arithmetic::sub: + return mod_->add(ty, a, b); + case arithmetic::mul: + return mod_->add(ty, a, b); + case arithmetic::div: + return mod_->add(ty, a, b); + case arithmetic::rem: + return mod_->add(ty, a, b); + default: + break; + } + throw compilation_error(loc, status::ir_fp_unsupported); + }; + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + case scalar_type::c32: + case scalar_type::c64: + return make_float_complex(op, ty, a, b); + } + throw compilation_error(loc, status::internal_compiler_error); +} + auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst * { auto const add_constant_complex = [this, &spv_ty](auto cst) -> spv_inst * { @@ -401,66 +462,6 @@ void inst_converter::operator()(arith_inst const &in) { } throw compilation_error(in.loc(), status::ir_boolean_unsupported); }; - auto const make_int = [&](arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { - switch (op) { - case arithmetic::add: - return mod_->add(ty, a, b); - case arithmetic::sub: - return mod_->add(ty, a, b); - case arithmetic::mul: - return mod_->add(ty, a, b); - case arithmetic::div: - return mod_->add(ty, a, b); - case arithmetic::rem: - return mod_->add(ty, a, b); - case arithmetic::shl: - return mod_->add(ty, a, b); - case arithmetic::shr: - return mod_->add(ty, a, b); - case arithmetic::and_: - return mod_->add(ty, a, b); - case arithmetic::or_: - return mod_->add(ty, a, b); - case arithmetic::xor_: - return mod_->add(ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto const make_float_complex = [&](arithmetic op, spv_inst *ty, spv_inst *a, - spv_inst *b) -> spv_inst * { - switch (op) { - case arithmetic::add: - return mod_->add(ty, a, b); - case arithmetic::sub: - return mod_->add(ty, a, b); - case arithmetic::mul: - return mod_->add(ty, a, b); - case arithmetic::div: - return mod_->add(ty, a, b); - case arithmetic::rem: - return mod_->add(ty, a, b); - default: - break; - } - throw compilation_error(in.loc(), status::ir_fp_unsupported); - }; - auto const make = [&](scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, - spv_inst *b) -> spv_inst * { - switch (sty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return make_int(op, ty, a, b); - case scalar_type::f32: - case scalar_type::f64: - case scalar_type::c32: - case scalar_type::c64: - return make_float_complex(op, ty, a, b); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; auto ty = unique_.spv_ty(in.result(0).ty()); @@ -471,7 +472,7 @@ void inst_converter::operator()(arith_inst const &in) { } else if (auto st = dyn_cast(in.result(0).ty()); st) { auto av = val(in.a()); auto bv = val(in.b()); - declare(in.result(0), make(st->ty(), in.operation(), ty, av, bv)); + declare(in.result(0), make_binary_op(st->ty(), in.operation(), ty, av, bv, in.loc())); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { auto const length = ct->length(core_cfg_.subgroup_size); auto insts = std::vector{}; @@ -480,7 +481,8 @@ void inst_converter::operator()(arith_inst const &in) { auto &av = multi_val(in.a()); auto &bv = multi_val(in.b()); for (std::int64_t i = 0; i < length; ++i) { - insts.emplace_back(make(ct->component_ty(), in.operation(), ty, av[i], bv[i])); + insts.emplace_back( + make_binary_op(ct->component_ty(), in.operation(), ty, av[i], bv[i], in.loc())); } multi_declare(in.result(0), std::move(insts)); @@ -825,6 +827,24 @@ void inst_converter::operator()(constant_inst const &in) { } } +void inst_converter::operator()(cooperative_matrix_load_inst const &in) {} +void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) {} +void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { + auto av = val(in.a()); + auto &bv = multi_val(in.b()); + + auto insts = std::vector{}; + insts.reserve(bv.size()); + + for (std::size_t i = 0; i < bv.size(); ++i) { + insts.emplace_back(make_binary_op(get_coopmatrix_type(in.result(0)), arithmetic::mul, + unique_.spv_ty(in.a().ty()), av, bv[i], in.loc())); + } + + multi_declare(in.result(0), std::move(insts)); +} +void inst_converter::operator()(cooperative_matrix_store_inst const &in) {} + void inst_converter::operator()(expand_inst const &in) { auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 93a14318..68dbdf18 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -72,6 +72,10 @@ class inst_converter { void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); + void operator()(cooperative_matrix_load_inst const &in); + void operator()(cooperative_matrix_mul_add_inst const &in); + void operator()(cooperative_matrix_scale_inst const &in); + void operator()(cooperative_matrix_store_inst const &in); void operator()(expand_inst const &in); void operator()(for_inst const &in); void operator()(fuse_inst const &in); @@ -121,6 +125,8 @@ class inst_converter { auto val(tinytc_value const &v) -> spv_inst *; auto multi_declare(tinytc_value const &v, std::vector insts); auto multi_val(tinytc_value const &v) -> std::vector &; + auto make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst *; auto make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst *; auto make_dope_vector(tinytc_value const &v) -> dope_vector *; diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 10b80cd5..70d0e0be 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -74,7 +74,11 @@ void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { using unsigned_t = std::make_unsigned_t>; *os_ << " " << static_cast(l); }, - [&](auto const &l) { *os_ << " " << l; }}, + [&](std::floating_point auto const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); + }}, l); } void dump_asm_pass::operator()(LiteralInteger const &l) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0b944d8..9891f3ca 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -59,6 +59,7 @@ if(SPIRVTools_FOUND) spv/cast.ir spv/calling_convention.ir spv/compare.ir + spv/cooperative_matrix_scale.ir spv/expand.ir spv/for.ir spv/fuse.ir diff --git a/test/spv/cooperative_matrix_scale.ir b/test/spv/cooperative_matrix_scale.ir new file mode 100644 index 00000000..f837cc7e --- /dev/null +++ b/test/spv/cooperative_matrix_scale.ir @@ -0,0 +1,18 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s + +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F32_CPI:]] = OpConstant %[[#F32]] 0x1.921fb6p+1 +; CHECK: %[[#F32_C13_369:]] = OpConstant %[[#F32]] 0x1.abcefap+3 + +func @scale() subgroup_size(16) { + %0 = constant 3.14159265358979323846 -> f32 + %1 = constant 13.36901521971920820459 -> coopmatrix + %2 = cooperative_matrix_scale %0, %1 : f32, coopmatrix +; CHECK: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] +} diff --git a/test/spv/if.ir b/test/spv/if.ir index 222004d6..43a3be15 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -8,8 +8,8 @@ ; CHECK: %[[#BOOL_TRUE:]] = OpConstantTrue ; CHECK: %[[#I32_0:]] = OpConstant %[[#I32]] 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 -; CHECK: %[[#CST1:]] = OpConstant %[[#F32]] 1 -; CHECK: %[[#CST0:]] = OpConstant %[[#F32]] 0 +; CHECK: %[[#CST1:]] = OpConstant %[[#F32]] 0x1p+0 +; CHECK: %[[#CST0:]] = OpConstant %[[#F32]] 0x0p+0 func @if0(%0: i32) { %c42 = constant 42 -> i32 diff --git a/test/spv/store.ir b/test/spv/store.ir index 10107cea..79d8f374 100644 --- a/test/spv/store.ir +++ b/test/spv/store.ir @@ -19,12 +19,12 @@ ; CHECK: %[[#I32_C0:]] = OpConstant %[[#I32]] 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#F32]] -; CHECK: %[[#F32_C42:]] = OpConstant %[[#F32]] 42 +; CHECK: %[[#F32_C42:]] = OpConstant %[[#F32]] 0x1.5p+5 ; CHECK: %[[#F64:]] = OpTypeFloat 64 ; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 ; CHECK: %[[#PTR_C64:]] = OpTypePointer CrossWorkgroup %[[#C64]] -; CHECK: %[[#F64_C42:]] = OpConstant %[[#F64]] 42 -; CHECK: %[[#F64_C1:]] = OpConstant %[[#F64]] 1 +; CHECK: %[[#F64_C42:]] = OpConstant %[[#F64]] 0x1.5p+5 +; CHECK: %[[#F64_C1:]] = OpConstant %[[#F64]] 0x1p+0 ; CHECK: %[[#C64_C42_1:]] = OpConstantComposite %[[#C64]] %[[#F64_C42]] %[[#F64_C1]] ; CHECK: %[[#PTR_F64:]] = OpTypePointer CrossWorkgroup %[[#F64]] ; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 From 2b5fabf580021704d8bf33eca2a4074f5070f430 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 13 Nov 2024 11:13:48 +0100 Subject: [PATCH 108/297] SPIR-V: Add rules for adding capabilities and extensions Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 2 + src/spv/capex_util.cpp | 677 +++++++++++++++++++++++++++++++++++++ src/spv/capex_util.hpp | 26 ++ src/spv/converter.cpp | 58 +--- src/spv/converter.hpp | 7 +- src/spv/defs.hpp | 6 +- src/spv/enums.hpp | 8 +- src/spv/instructions.hpp | 44 ++- src/spv/names.cpp | 4 + src/spv/names.hpp | 6 +- src/spv/pass/capex.cpp | 111 ++++++ src/spv/pass/capex.hpp | 39 +++ src/spv/uniquifier.cpp | 5 - src/spv/visit.hpp | 27 +- tools/spirvgen/filter.json | 1 + tools/spirvgen/spirvgen.py | 55 +++ 16 files changed, 1002 insertions(+), 74 deletions(-) create mode 100644 src/spv/capex_util.cpp create mode 100644 src/spv/capex_util.hpp create mode 100644 src/spv/pass/capex.cpp create mode 100644 src/spv/pass/capex.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ad33dadb..e49b6479 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -64,6 +64,7 @@ set(SOURCES region.cpp required_extensions.cpp scalar_type.cpp + spv/capex_util.cpp spv/converter.cpp spv/inst_assembler.cpp spv/module.cpp @@ -72,6 +73,7 @@ set(SOURCES spv/pass/assemble.cpp spv/pass/assign_ids.cpp spv/pass/dump_asm.cpp + spv/pass/capex.cpp spv/uniquifier.cpp source.cpp tiling.cpp diff --git a/src/spv/capex_util.cpp b/src/spv/capex_util.cpp new file mode 100644 index 00000000..776ad4cf --- /dev/null +++ b/src/spv/capex_util.cpp @@ -0,0 +1,677 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "capex_util.hpp" + +namespace tinytc::spv { + +auto capabilities(ExecutionModel e) -> array_view { + switch (e) { + case ExecutionModel::Vertex: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::TessellationControl: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::TessellationEvaluation: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::Geometry: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionModel::Fragment: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::GLCompute: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::Kernel: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionModel::TaskNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::MeshNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::RayGenerationKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::IntersectionKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::AnyHitKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::ClosestHitKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::MissKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::CallableKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::TaskEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + case ExecutionModel::MeshEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::Physical32: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::Physical64: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static Capability values[] = {Capability::PhysicalStorageBufferAddresses}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Simple: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::GLSL450: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::OpenCL: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case MemoryModel::Vulkan: { + constexpr static Capability values[] = {Capability::VulkanMemoryModel}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::Invocations: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::SpacingEqual: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalEven: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalOdd: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCcw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::PixelCenterInteger: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginUpperLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginLowerLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::EarlyFragmentTests: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PointMode: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Xfb: { + constexpr static Capability values[] = {Capability::TransformFeedback}; + return {values, 1}; + } + case ExecutionMode::DepthReplacing: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthGreater: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthLess: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthUnchanged: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHint: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::InputPoints: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLines: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLinesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Triangles: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation}; + return {values, 2}; + } + case ExecutionMode::InputTrianglesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Quads: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Isolines: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::OutputVertices: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation, + Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 4}; + } + case ExecutionMode::OutputPoints: { + constexpr static Capability values[] = {Capability::Geometry, Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 3}; + } + case ExecutionMode::OutputLineStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::OutputTriangleStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::VecTypeHint: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::ContractionOff: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Initializer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Finalizer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::SubgroupSize: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroup: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroupId: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHintId: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::NonCoherentColorAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageColorReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageDepthReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageStencilReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static Capability values[] = {Capability::SampleMaskPostDepthCoverage}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static Capability values[] = {Capability::DenormPreserve}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static Capability values[] = {Capability::DenormFlushToZero}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static Capability values[] = {Capability::SignedZeroInfNanPreserve}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static Capability values[] = {Capability::RoundingModeRTE}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static Capability values[] = {Capability::RoundingModeRTZ}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::CoalescingAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::IsApiEntryAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNodeRecursionAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::StaticNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::ShaderIndexAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::QuadDerivativesKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::RequireFullQuadsKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::SharesInputWithAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupQuadsKHR}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupLinearKHR}; + return {values, 2}; + } + case ExecutionMode::OutputTrianglesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SharedLocalMemorySizeINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTPINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTNINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeALTINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeIEEEINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::FPFastMathDefault: { + constexpr static Capability values[] = {Capability::FloatControls2}; + return {values, 1}; + } + case ExecutionMode::StreamingInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::RegisterMapInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesv2INTEL}; + return {values, 1}; + } + case ExecutionMode::NamedBarrierCountINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersIdINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::NamedMaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionModel e) -> array_view { + switch (e) { + default: + return {}; + } +} +auto extensions(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static char const *values[] = {"SPV_EXT_physical_storage_buffer", + "SPV_KHR_physical_storage_buffer"}; + return {values, 2}; + } + default: + return {}; + } +} +auto extensions(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Vulkan: { + constexpr static char const *values[] = {"SPV_KHR_vulkan_memory_model"}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static char const *values[] = {"SPV_KHR_subgroup_uniform_control_flow"}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static char const *values[] = {"SPV_KHR_post_depth_coverage"}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests"}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static char const *values[] = {"SPV_EXT_shader_stencil_export"}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::OutputTrianglesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static char const *values[] = {"SPV_KHR_maximal_reconvergence"}; + return {values, 1}; + } + default: + return {}; + } +} + +} // namespace tinytc::spv diff --git a/src/spv/capex_util.hpp b/src/spv/capex_util.hpp new file mode 100644 index 00000000..9396f285 --- /dev/null +++ b/src/spv/capex_util.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_CAPEX_UTIL_20241113_HPP +#define GENERATED_CAPEX_UTIL_20241113_HPP + +#include "enums.hpp" +#include "tinytc/tinytc.hpp" + +namespace tinytc::spv { + +auto capabilities(ExecutionModel op) -> array_view; +auto capabilities(AddressingModel op) -> array_view; +auto capabilities(MemoryModel op) -> array_view; +auto capabilities(ExecutionMode op) -> array_view; +auto extensions(ExecutionModel op) -> array_view; +auto extensions(AddressingModel op) -> array_view; +auto extensions(MemoryModel op) -> array_view; +auto extensions(ExecutionMode op) -> array_view; + +} // namespace tinytc::spv + +#endif // GENERATED_CAPEX_UTIL_20241113_HPP diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index f0359d36..bc8ca063 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -14,6 +14,7 @@ #include "spv/enums.hpp" #include "spv/instructions.hpp" #include "spv/opencl.std.hpp" +#include "spv/pass/capex.hpp" #include "spv/uniquifier.hpp" #include "spv/visit.hpp" #include "support/ilist.hpp" @@ -45,10 +46,6 @@ auto convert_prog_to_spirv(tinytc_prog const &p, auto conv = inst_converter{*m}; - conv.unique().capability(Capability::Addresses); - conv.unique().capability(Capability::Kernel); - conv.unique().capability(Capability::SubgroupDispatch); - m->add_to(section::memory_model, AddressingModel::Physical64, MemoryModel::OpenCL); @@ -61,27 +58,10 @@ auto convert_prog_to_spirv(tinytc_prog const &p, } // Add missing capabilites and extensions + auto cx = capex{conv.unique()}; for (std::int32_t s = 0; s < num_module_sections; ++s) { for (auto const &i : m->insts(enum_cast
(s))) { - visit(overloaded{[&](I const &in) { - if (isa(static_cast(in))) { - // We manage OpAtomicFAddExt manually as the required - // capabilitites depend on the data type - return; - } - for (auto const &cap : I::required_capabilities) { - conv.unique().capability(cap); - } - }, - [&](auto const &) {}}, - i); - visit(overloaded{[&](I const &) { - for (auto const &ext_name : I::required_extensions) { - conv.unique().extension(ext_name); - } - }, - [&](auto const &) {}}, - i); + visit(cx, i); } } @@ -148,12 +128,13 @@ auto inst_converter::get_scalar_type(value_node const &v) const -> scalar_type { } return st->ty(); } -auto inst_converter::get_coopmatrix_type(value_node const &v) const -> scalar_type { +auto inst_converter::get_coopmatrix_type(value_node const &v) const + -> coopmatrix_data_type const * { auto ct = dyn_cast(v.ty()); if (!ct) { throw compilation_error(v.loc(), status::ir_expected_coopmatrix); } - return ct->component_ty(); + return ct; } auto inst_converter::load_builtin(BuiltIn b) -> spv_inst * { @@ -321,24 +302,6 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { void inst_converter::make_store(store_flag flag, scalar_type sty, address_space as, spv_inst *pointer, spv_inst *value) { - auto const add_fadd_caps = [&] { - switch (sty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - break; - case scalar_type::f32: - case scalar_type::c32: - unique_.capability(Capability::AtomicFloat32AddEXT); - break; - case scalar_type::f64: - case scalar_type::c64: - unique_.capability(Capability::AtomicFloat64AddEXT); - break; - } - }; auto const split_re_im = [&]() -> std::array, 2u> { auto component_sty = element_type(sty); auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); @@ -389,12 +352,10 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space break; case scalar_type::f32: case scalar_type::f64: - add_fadd_caps(); mod_->add(result_ty, pointer, scope, semantics, value); break; case scalar_type::c32: case scalar_type::c64: { - add_fadd_caps(); auto re_im = split_re_im(); auto component_sty = element_type(sty); auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); @@ -711,7 +672,7 @@ void inst_converter::operator()(cast_inst const &in) { insts.reserve(length); auto &av = multi_val(in.a()); - auto a_ty = get_coopmatrix_type(in.a()); + auto a_ty = get_coopmatrix_type(in.a())->component_ty(); for (std::int64_t i = 0; i < length; ++i) { insts.emplace_back(make(ct->component_ty(), a_ty, spv_to_ty, av[i])); } @@ -837,8 +798,9 @@ void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { insts.reserve(bv.size()); for (std::size_t i = 0; i < bv.size(); ++i) { - insts.emplace_back(make_binary_op(get_coopmatrix_type(in.result(0)), arithmetic::mul, - unique_.spv_ty(in.a().ty()), av, bv[i], in.loc())); + insts.emplace_back(make_binary_op(get_coopmatrix_type(in.result(0))->component_ty(), + arithmetic::mul, unique_.spv_ty(in.a().ty()), av, bv[i], + in.loc())); } multi_declare(in.result(0), std::move(insts)); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 68dbdf18..e9ab863c 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -119,7 +119,7 @@ class inst_converter { auto get_dope_vector(tinytc_value const &v) -> dope_vector *; auto get_memref_type(tinytc_value const &v) const -> memref_data_type const *; auto get_scalar_type(tinytc_value const &v) const -> scalar_type; - auto get_coopmatrix_type(tinytc_value const &v) const -> scalar_type; + auto get_coopmatrix_type(tinytc_value const &v) const -> coopmatrix_data_type const *; auto load_builtin(BuiltIn b) -> spv_inst *; auto declare(tinytc_value const &v, spv_inst *in); auto val(tinytc_value const &v) -> spv_inst *; @@ -144,11 +144,6 @@ class inst_converter { core_config core_cfg_ = {}; }; -template -concept spv_inst_with_required_capabilities = requires() { T::required_capabilities; }; -template -concept spv_inst_with_required_extensions = requires() { T::required_extensions; }; - } // namespace tinytc::spv #endif // CONVERTER_20241111_HPP diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp index 24e7a271..bc6f8f5d 100644 --- a/src/spv/defs.hpp +++ b/src/spv/defs.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_DEFS_20241111_HPP -#define GENERATED_DEFS_20241111_HPP +#ifndef GENERATED_DEFS_20241113_HPP +#define GENERATED_DEFS_20241113_HPP #include "enums.hpp" #include "support/ilist_base.hpp" @@ -62,4 +62,4 @@ using PairIdRefLiteralInteger = std::pair; } // namespace tinytc::spv -#endif // GENERATED_DEFS_20241111_HPP +#endif // GENERATED_DEFS_20241113_HPP diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 1557607b..90ca6a81 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_20241111_HPP -#define GENERATED_ENUMS_20241111_HPP +#ifndef GENERATED_ENUMS_20241113_HPP +#define GENERATED_ENUMS_20241113_HPP #include @@ -358,6 +358,8 @@ enum class Op { CooperativeMatrixStoreKHR = 4458, CooperativeMatrixMulAddKHR = 4459, CooperativeMatrixLengthKHR = 4460, + SubgroupBlockReadINTEL = 5575, + SubgroupBlockWriteINTEL = 5576, AtomicFMinEXT = 5614, AtomicFMaxEXT = 5615, AtomicFAddEXT = 6035, @@ -1429,4 +1431,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_20241111_HPP +#endif // GENERATED_ENUMS_20241113_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index c9f84ec9..59276d0e 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_20241111_HPP -#define GENERATED_INSTRUCTIONS_20241111_HPP +#ifndef GENERATED_INSTRUCTIONS_20241113_HPP +#define GENERATED_INSTRUCTIONS_20241113_HPP #include "defs.hpp" #include "enums.hpp" @@ -6647,6 +6647,44 @@ class OpCooperativeMatrixLengthKHR : public spv_inst { IdResultType type_; IdRef op0_; }; +class OpSubgroupBlockReadINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockReadINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockReadINTEL(IdResultType type, IdRef op0) + : spv_inst{Op::SubgroupBlockReadINTEL, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSubgroupBlockWriteINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockWriteINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockWriteINTEL(IdRef op0, IdRef op1) + : spv_inst{Op::SubgroupBlockWriteINTEL, false}, op0_(std::move(op0)), op1_(std::move(op1)) { + } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; class OpAtomicFMinEXT : public spv_inst { public: inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMinEXT; } @@ -6733,4 +6771,4 @@ class OpAtomicFAddEXT : public spv_inst { } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_20241111_HPP +#endif // GENERATED_INSTRUCTIONS_20241113_HPP diff --git a/src/spv/names.cpp b/src/spv/names.cpp index 77e06035..6c0a159f 100644 --- a/src/spv/names.cpp +++ b/src/spv/names.cpp @@ -699,6 +699,10 @@ auto to_string(Op op) -> char const * { return "CooperativeMatrixMulAddKHR"; case Op::CooperativeMatrixLengthKHR: return "CooperativeMatrixLengthKHR"; + case Op::SubgroupBlockReadINTEL: + return "SubgroupBlockReadINTEL"; + case Op::SubgroupBlockWriteINTEL: + return "SubgroupBlockWriteINTEL"; case Op::AtomicFMinEXT: return "AtomicFMinEXT"; case Op::AtomicFMaxEXT: diff --git a/src/spv/names.hpp b/src/spv/names.hpp index 3fc09b91..c84df1e8 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_20241111_HPP -#define GENERATED_NAMES_20241111_HPP +#ifndef GENERATED_NAMES_20241113_HPP +#define GENERATED_NAMES_20241113_HPP #include "enums.hpp" @@ -68,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_20241111_HPP +#endif // GENERATED_NAMES_20241113_HPP diff --git a/src/spv/pass/capex.cpp b/src/spv/pass/capex.cpp new file mode 100644 index 00000000..71d92bde --- /dev/null +++ b/src/spv/pass/capex.cpp @@ -0,0 +1,111 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/capex.hpp" +#include "error.hpp" +#include "spv/capex_util.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/visit.hpp" +#include "support/casting.hpp" +#include "support/util.hpp" + +#include + +namespace tinytc::spv { + +capex::capex(uniquifier &unique) : unique_{&unique} {} + +void capex::operator()(spv_inst const &) {} +void capex::operator()(OpAtomicFAddEXT const &in) { + auto ty = dyn_cast(in.type()); + if (!ty) { + throw status::internal_compiler_error; + } + switch (ty->op0()) { + case 16: + unique_->capability(Capability::AtomicFloat16AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float16_add"); + break; + case 32: + unique_->capability(Capability::AtomicFloat32AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + break; + case 64: + unique_->capability(Capability::AtomicFloat64AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + break; + default: + break; + } +} +void capex::operator()(OpEntryPoint const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } +} +void capex::operator()(OpExecutionMode const &in) { + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + } +} +void capex::operator()(OpGroupFAdd const &) { unique_->capability(Capability::Groups); } +void capex::operator()(OpGroupIAdd const &) { unique_->capability(Capability::Groups); } +void capex::operator()(OpInBoundsPtrAccessChain const &) { + unique_->capability(Capability::Addresses); +} +void capex::operator()(OpMemoryModel const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + } +} +void capex::operator()(OpSubgroupBlockReadINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); +} +void capex::operator()(OpSubgroupBlockWriteINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); +} +void capex::operator()(OpTypeFloat const &in) { + switch (in.op0()) { + case 16: + unique_->capability(Capability::Float16); + break; + case 64: + unique_->capability(Capability::Float64); + break; + default: + break; + } +} +void capex::operator()(OpTypeInt const &in) { + switch (in.op0()) { + case 8: + unique_->capability(Capability::Int8); + break; + case 16: + unique_->capability(Capability::Int16); + break; + case 64: + unique_->capability(Capability::Int64); + break; + default: + break; + } +} + +void capex::run_on_module(tinytc_spv_mod const &mod) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/capex.hpp b/src/spv/pass/capex.hpp new file mode 100644 index 00000000..5a3e5ad1 --- /dev/null +++ b/src/spv/pass/capex.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CAPEX_20241113_HPP +#define CAPEX_20241113_HPP + +#include "spv/defs.hpp" +#include "spv/instructions.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/types.h" + +namespace tinytc::spv { + +class capex { + public: + capex(uniquifier &unique); + + void operator()(spv_inst const &in); + void operator()(OpAtomicFAddEXT const &in); + void operator()(OpEntryPoint const &in); + void operator()(OpExecutionMode const &in); + void operator()(OpGroupFAdd const &in); + void operator()(OpGroupIAdd const &in); + void operator()(OpInBoundsPtrAccessChain const &in); + void operator()(OpMemoryModel const &in); + void operator()(OpSubgroupBlockReadINTEL const &in); + void operator()(OpSubgroupBlockWriteINTEL const &in); + void operator()(OpTypeFloat const &in); + void operator()(OpTypeInt const &in); + + void run_on_module(tinytc_spv_mod const &mod); + + private: + uniquifier *unique_; +}; + +} // namespace tinytc::spv + +#endif // CAPEX_20241113_HPP diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 48a98d8b..235cb9a7 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -215,27 +215,22 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { [&](scalar_data_type const &ty) -> spv_inst * { switch (ty.ty()) { case scalar_type::i8: - capability(Capability::Int8); return mod_->add_to(section::type_const_var, 8, 0); case scalar_type::i16: - capability(Capability::Int16); return mod_->add_to(section::type_const_var, 16, 0); case scalar_type::i32: return mod_->add_to(section::type_const_var, 32, 0); case scalar_type::i64: - capability(Capability::Int64); return mod_->add_to(section::type_const_var, 64, 0); case scalar_type::index: { const auto sz = size(ty.ty()); if (sz == 8) { - capability(Capability::Int64); return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i64)); } return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i32)); } case scalar_type::f32: case scalar_type::f64: - capability(Capability::Float64); return mod_->add_to(section::type_const_var, size(ty.ty()) * 8); case scalar_type::c32: { diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index dc09c379..4a1268b4 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_20241111_HPP -#define GENERATED_VISIT_20241111_HPP +#ifndef GENERATED_VISIT_20241113_HPP +#define GENERATED_VISIT_20241113_HPP #include "defs.hpp" #include "enums.hpp" @@ -707,6 +707,10 @@ template auto visit(Visitor &&visitor, spv_inst &inst) { return visitor(static_cast(inst)); case Op::CooperativeMatrixLengthKHR: return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); case Op::AtomicFMinEXT: return visitor(static_cast(inst)); case Op::AtomicFMaxEXT: @@ -1407,6 +1411,10 @@ template auto visit(Visitor &&visitor, spv_inst const &inst) return visitor(static_cast(inst)); case Op::CooperativeMatrixLengthKHR: return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); case Op::AtomicFMinEXT: return visitor(static_cast(inst)); case Op::AtomicFMaxEXT: @@ -4382,6 +4390,19 @@ template class default_visitor { static_cast(this)->operator()(in.op0()); static_cast(this)->post_visit(in); } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } auto operator()(const_t &in) { static_cast(this)->pre_visit(in); static_cast(this)->operator()(in.type()); @@ -4416,4 +4437,4 @@ template class default_visitor { } // namespace tinytc::spv -#endif // GENERATED_VISIT_20241111_HPP +#endif // GENERATED_VISIT_20241113_HPP diff --git a/tools/spirvgen/filter.json b/tools/spirvgen/filter.json index 38b632ae..a77517db 100644 --- a/tools/spirvgen/filter.json +++ b/tools/spirvgen/filter.json @@ -7,6 +7,7 @@ [0, 47], [53, 999], [4456, 4460], + [5575, 5576], [5614, 5615], [6035, 6035] ] diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 4f5c4921..34605fed 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -15,6 +15,13 @@ spv_enums = 'enums.hpp' spv_enums_includes = [''] +spv_capex = 'capex_util.hpp' +spv_capex_includes = [spv_enums, "tinytc/tinytc.hpp"] +spv_capex_cpp = 'capex_util.cpp' +spv_capex_cpp_includes = [spv_capex] +spv_capex_required_enums = [ + 'AddressingModel', 'ExecutionMode', 'ExecutionModel', 'MemoryModel' +] spv_names = 'names.hpp' spv_names_includes = [spv_enums] spv_names_cpp = 'names.cpp' @@ -49,6 +56,50 @@ def get_class_name(instruction): return instruction['opname'] +def generate_capex(f, grammar): + for name, ty in zip(['capabilities', 'extensions'], + ['Capability', 'char const*']): + for opkind in grammar['operand_kinds']: + if opkind['kind'] in spv_capex_required_enums: + print(f'auto {name}({opkind["kind"]} op) -> array_view<{ty}>;', + file=f) + + +def generate_capex_cpp(f, grammar): + # Need to go through capabilities to filter aliases later on + capabilities = set() + for opkind in grammar['operand_kinds']: + if opkind['kind'] == 'Capability': + for enumerant in opkind['enumerants']: + capabilities.add(enumerant['enumerant']) + + def print_function(name, ty, trans, filt): + for opkind in grammar['operand_kinds']: + if opkind['kind'] in spv_capex_required_enums: + print( + f'auto {name}({opkind["kind"]} e) -> array_view<{ty}> {{ switch(e) {{', + file=f) + for enumerant in opkind['enumerants']: + if name in enumerant: + ename = enumerant["enumerant"] + values = enumerant[name] + values_str = ','.join( + [trans(v) for v in values if filt(v)]) + print( + f'case {opkind["kind"]}::{enumerant_subs.get(ename, ename)}: {{', + file=f) + print( + f'constexpr static {ty} values[] = {{{values_str}}};', + file=f) + print(f'return {{values, {len(values)}}}; }}', file=f) + print('default: return {}; }}', file=f) + + print_function('capabilities', 'Capability', lambda x: f'Capability::{x}', + lambda x: x in capabilities) + print_function('extensions', 'char const*', lambda x: f'"{x}"', + lambda x: True) + + def generate_enums(f, grammar): print(f'constexpr std::int32_t magic_number = {grammar["magic_number"]};', file=f) @@ -377,6 +428,10 @@ def patch_grammar(grammar): grammar = filter_grammar(grammar, filt) grammar = patch_grammar(grammar) + generate_header(args, spv_capex, grammar, generate_capex, + spv_capex_includes) + generate_cpp(args, spv_capex_cpp, grammar, generate_capex_cpp, + spv_capex_cpp_includes) generate_header(args, spv_enums, grammar, generate_enums, spv_enums_includes) generate_header(args, spv_names, grammar, generate_names, From cd48069a49869832a7147a79ff5684f1e69eba96 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 13 Nov 2024 16:21:06 +0100 Subject: [PATCH 109/297] SPIR-V: Cooperative matrix load / store Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 282 ++++++++++++++++++++++++++- src/spv/converter.hpp | 4 + test/CMakeLists.txt | 2 + test/spv/cooperative_matrix_load.ir | 98 ++++++++++ test/spv/cooperative_matrix_store.ir | 37 ++++ 5 files changed, 421 insertions(+), 2 deletions(-) create mode 100644 test/spv/cooperative_matrix_load.ir create mode 100644 test/spv/cooperative_matrix_store.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index bc8ca063..19f60ae2 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -225,6 +225,41 @@ auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty throw compilation_error(loc, status::internal_compiler_error); } +auto inst_converter::make_conditional_execution( + spv_inst *returned_element_ty, spv_inst *condition, + std::function()> conditional_code, + location const &loc) -> std::vector { + auto then_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto init_last_label = get_last_label(); + if (!init_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + mod_->add(merge_label.get(), SelectionControl::None); + mod_->add(condition, then_label.get(), merge_label.get(), + std::vector{}); + mod_->insts(section::function).push_back(then_label.release()); + std::vector loaded_values = conditional_code(); + mod_->add(merge_label.get()); + auto then_last_label = get_last_label(); + if (!then_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod_->insts(section::function).push_back(merge_label.release()); + + auto yielded_values = std::vector{}; + yielded_values.reserve(loaded_values.size()); + auto alternative = PairIdRefIdRef{unique_.null_constant(returned_element_ty), init_last_label}; + for (auto &value : loaded_values) { + yielded_values.emplace_back(mod_->add( + returned_element_ty, + std::vector{PairIdRefIdRef{value, then_last_label}, alternative})); + } + return yielded_values; +} + auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst * { auto const add_constant_complex = [this, &spv_ty](auto cst) -> spv_inst * { @@ -788,7 +823,141 @@ void inst_converter::operator()(constant_inst const &in) { } } -void inst_converter::operator()(cooperative_matrix_load_inst const &in) {} +void inst_converter::operator()(cooperative_matrix_load_inst const &in) { + auto spv_boolean_ty = unique_.spv_ty(boolean_data_type::get(mod_->context())); + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_operand_ty = unique_.spv_ty(in.operand().ty()); + auto spv_ty = unique_.spv_ty(in.result(0).ty()); + auto ot = get_memref_type(in.operand()); + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + auto rt = get_coopmatrix_type(in.result(0)); + + const int rmode = rt->distributed_mode(); + const int omode = in.t() == transpose::T ? 1 - rmode : rmode; + const bool check_m = in.checked() == checked_flag::both || + (rmode == 0 && in.checked() == checked_flag::rows) || + (rmode == 1 && in.checked() == checked_flag::cols); + const bool check_k = in.checked() == checked_flag::both || + (rmode == 1 && in.checked() == checked_flag::rows) || + (rmode == 0 && in.checked() == checked_flag::cols); + const bool enable_sub_group_reads = core_cfg_.block_read_write_supported && + !is_complex_type(ot->element_ty()) && + in.t() == transpose::N && ot->stride(omode) == 1; + + spv_inst *pv[] = {val(in.pos0()), val(in.pos1())}; + auto get_rem = [&, rem = std::array{}](std::int64_t index) mutable { + if (rem[index] == nullptr) { + rem[index] = mod_->add(spv_index_ty, odv->shape(index), pv[index]); + } + return rem[index]; + }; + + auto pv0_stride0 = mod_->add(spv_index_ty, pv[0], odv->stride(0)); + auto pv1_stride1 = mod_->add(spv_index_ty, pv[1], odv->stride(1)); + auto offset = mod_->add(spv_index_ty, pv0_stride0, pv1_stride1); + auto pointer = mod_->add(spv_operand_ty, val(in.operand()), offset, + std::vector{}); + + const auto block_load = [&](spv_inst *offset) -> spv_inst * { + auto sub_pointer = mod_->add(spv_operand_ty, pointer, offset, + std::vector{}); + auto const cast_load_cast = [&](scalar_type int_sty) { + auto int_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), int_sty)); + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + auto int_ptr_ty = unique_.spv_pointer_ty(storage_cls, int_ty, ot->alignment()); + auto int_ptr = mod_->add(int_ptr_ty, sub_pointer); + auto value = mod_->add(int_ty, int_ptr); + return mod_->add(spv_ty, value); + }; + switch (ot->element_ty()) { + case scalar_type::f32: + return cast_load_cast(scalar_type::i32); + case scalar_type::f64: + return cast_load_cast(scalar_type::i64); + default: + return mod_->add(spv_ty, sub_pointer); + } + }; + const auto gather_load = [&](spv_inst *offset) -> spv_inst * { + auto sub_pointer = mod_->add(spv_operand_ty, pointer, offset, + std::vector{}); + return mod_->add(spv_ty, sub_pointer); + }; + + spv_inst *m = load_builtin(BuiltIn::SubgroupLocalInvocationId); + m = mod_->add(spv_index_ty, m); + const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); + const std::int64_t K = rt->shape(1 - rmode); + auto loaded_values = std::vector{}; + loaded_values.reserve(K * num_blocks); + for (std::int64_t block = 0; block < num_blocks; ++block) { + auto const remainder = rt->shape(rmode) - core_cfg_.subgroup_size * block; + const bool needs_mask = remainder < core_cfg_.subgroup_size; + + const auto load_block_impl = [&](auto load_impl, + spv_inst *offset) -> std::vector { + auto result = std::vector{}; + result.reserve(K); + for (std::int64_t k = 0; k < K; ++k) { + if (check_k) { + auto check1 = mod_->add(spv_boolean_ty, unique_.constant(-k), + pv[1 - omode]); + auto check2 = mod_->add(spv_boolean_ty, unique_.constant(k), + get_rem(1 - omode)); + auto cond = mod_->add(spv_boolean_ty, check1, check2); + auto vals = make_conditional_execution( + spv_ty, cond, [&] { return std::vector{load_impl(offset)}; }, + in.loc()); + result.insert(result.end(), vals.begin(), vals.end()); + } else { + result.emplace_back(load_impl(offset)); + } + if (k + 1 < K) { + offset = mod_->add(spv_index_ty, offset, odv->stride(1 - omode)); + } + } + return result; + }; + + auto const load_block = [&]() -> std::vector { + if (enable_sub_group_reads && !needs_mask && !check_m) { + return load_block_impl(block_load, + unique_.constant(block * core_cfg_.subgroup_size)); + } else { + spv_inst *offset = m; + if (block > 0) { + offset = mod_->add(spv_index_ty, offset, + unique_.constant(block * core_cfg_.subgroup_size)); + } + offset = mod_->add(spv_index_ty, offset, odv->stride(omode)); + return load_block_impl(gather_load, offset); + } + }; + spv_inst *cond = nullptr; + if (check_m) { + spv_inst *m_offset = m; + if (block > 0) { + m_offset = mod_->add(spv_index_ty, m_offset, + unique_.constant(block * core_cfg_.subgroup_size)); + } + auto neg = mod_->add(spv_index_ty, pv[omode]); + auto check1 = mod_->add(spv_boolean_ty, m_offset, neg); + auto check2 = mod_->add(spv_boolean_ty, m_offset, get_rem(omode)); + cond = mod_->add(spv_boolean_ty, check1, check2); + } + if (needs_mask) { + spv_inst *mask = mod_->add(spv_boolean_ty, m, unique_.constant(remainder)); + cond = cond ? mod_->add(spv_boolean_ty, cond, mask) : mask; + } + auto block_vals = + cond ? make_conditional_execution(spv_ty, cond, load_block, in.loc()) : load_block(); + loaded_values.insert(loaded_values.end(), block_vals.begin(), block_vals.end()); + } + multi_declare(in.result(0), std::move(loaded_values)); +} void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) {} void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { auto av = val(in.a()); @@ -805,7 +974,116 @@ void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { multi_declare(in.result(0), std::move(insts)); } -void inst_converter::operator()(cooperative_matrix_store_inst const &in) {} +void inst_converter::operator()(cooperative_matrix_store_inst const &in) { + auto spv_boolean_ty = unique_.spv_ty(boolean_data_type::get(mod_->context())); + auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_operand_ty = unique_.spv_ty(in.operand().ty()); + auto spv_ty = unique_.spv_ty(in.val().ty()); + auto ot = get_memref_type(in.operand()); + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + auto vt = get_coopmatrix_type(in.val()); + + const int vmode = vt->distributed_mode(); + const int omode = vmode; + const bool check_m = in.checked() == checked_flag::both || + (vmode == 0 && in.checked() == checked_flag::rows) || + (vmode == 1 && in.checked() == checked_flag::cols); + const bool check_k = in.checked() == checked_flag::both || + (vmode == 1 && in.checked() == checked_flag::rows) || + (vmode == 0 && in.checked() == checked_flag::cols); + + spv_inst *pv[] = {val(in.pos0()), val(in.pos1())}; + auto get_rem = [&, rem = std::array{}](std::int64_t index) mutable { + if (rem[index] == nullptr) { + rem[index] = mod_->add(spv_index_ty, odv->shape(index), pv[index]); + } + return rem[index]; + }; + + auto pv0_stride0 = mod_->add(spv_index_ty, pv[0], odv->stride(0)); + auto pv1_stride1 = mod_->add(spv_index_ty, pv[1], odv->stride(1)); + auto offset = mod_->add(spv_index_ty, pv0_stride0, pv1_stride1); + auto pointer = mod_->add(spv_operand_ty, val(in.operand()), offset, + std::vector{}); + + const auto scatter_store = [&](spv_inst *offset, spv_inst *value) { + auto sub_pointer = mod_->add(spv_operand_ty, pointer, offset, + std::vector{}); + make_store(in.flag(), ot->element_ty(), ot->addrspace(), sub_pointer, value); + }; + + spv_inst *m = load_builtin(BuiltIn::SubgroupLocalInvocationId); + m = mod_->add(spv_index_ty, m); + const std::int64_t num_blocks = vt->num_blocks(core_cfg_.subgroup_size); + const std::int64_t K = vt->shape(1 - vmode); + auto &values = multi_val(in.val()); + for (std::int64_t block = 0; block < num_blocks; ++block) { + auto const remainder = vt->shape(vmode) - core_cfg_.subgroup_size * block; + const bool needs_mask = remainder < core_cfg_.subgroup_size; + + const auto store_block_impl = [&](auto store_impl, + spv_inst *offset) -> std::vector { + for (std::int64_t k = 0; k < K; ++k) { + auto &value = values[k + block * vt->shape(1 - vmode)]; + if (check_k) { + auto check1 = mod_->add(spv_boolean_ty, unique_.constant(-k), + pv[1 - omode]); + auto check2 = mod_->add(spv_boolean_ty, unique_.constant(k), + get_rem(1 - omode)); + auto cond = mod_->add(spv_boolean_ty, check1, check2); + auto vals = make_conditional_execution( + spv_ty, cond, + [&] { + store_impl(offset, value); + return std::vector{}; + }, + in.loc()); + } else { + store_impl(offset, value); + } + if (k + 1 < K) { + offset = mod_->add(spv_index_ty, offset, odv->stride(1 - omode)); + } + } + return {}; + }; + + auto const store_block = [&]() -> std::vector { + spv_inst *offset = m; + if (block > 0) { + offset = mod_->add(spv_index_ty, offset, + unique_.constant(block * core_cfg_.subgroup_size)); + } + offset = mod_->add(spv_index_ty, offset, odv->stride(omode)); + store_block_impl(scatter_store, offset); + return {}; + }; + spv_inst *cond = nullptr; + if (check_m) { + spv_inst *m_offset = m; + if (block > 0) { + m_offset = mod_->add(spv_index_ty, m_offset, + unique_.constant(block * core_cfg_.subgroup_size)); + } + auto neg = mod_->add(spv_index_ty, pv[omode]); + auto check1 = mod_->add(spv_boolean_ty, m_offset, neg); + auto check2 = mod_->add(spv_boolean_ty, m_offset, get_rem(omode)); + cond = mod_->add(spv_boolean_ty, check1, check2); + } + if (needs_mask) { + spv_inst *mask = mod_->add(spv_boolean_ty, m, unique_.constant(remainder)); + cond = cond ? mod_->add(spv_boolean_ty, cond, mask) : mask; + } + if (cond) { + make_conditional_execution(spv_ty, cond, store_block, in.loc()); + } else { + store_block(); + } + } +} void inst_converter::operator()(expand_inst const &in) { auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index e9ab863c..30677201 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -19,6 +19,7 @@ #include "tinytc/types.hpp" #include +#include #include #include #include @@ -127,6 +128,9 @@ class inst_converter { auto multi_val(tinytc_value const &v) -> std::vector &; auto make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b, location const &loc) -> spv_inst *; + auto make_conditional_execution(spv_inst *returned_element_ty, spv_inst *condition, + std::function()> conditional_code, + location const &loc) -> std::vector; auto make_constant(scalar_type sty, spv_inst *spv_ty, constant_inst::value_type const &val) -> spv_inst *; auto make_dope_vector(tinytc_value const &v) -> dope_vector *; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9891f3ca..6015b3a0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -59,7 +59,9 @@ if(SPIRVTools_FOUND) spv/cast.ir spv/calling_convention.ir spv/compare.ir + spv/cooperative_matrix_load.ir spv/cooperative_matrix_scale.ir + spv/cooperative_matrix_store.ir spv/expand.ir spv/for.ir spv/fuse.ir diff --git a/test/spv/cooperative_matrix_load.ir b/test/spv/cooperative_matrix_load.ir new file mode 100644 index 00000000..61d64f55 --- /dev/null +++ b/test/spv/cooperative_matrix_load.ir @@ -0,0 +1,98 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s + +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#I32_PTR:]] = OpTypePointer CrossWorkgroup %[[#I32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C64:]] = OpConstant %[[#I64]] 64 +; CHECK: %[[#I64_C48:]] = OpConstant %[[#I64]] 48 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#BOOL:]] = OpTypeBool +; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 +; CHECK: %[[#I32_NULL:]] = OpConstantNull %[[#I32]] +; CHECK: %[[#I64_C16:]] = OpConstant %[[#I64]] 16 + + +func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n %A[%x,%y] : memref -> coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#T1_MR:]] = OpFunctionParameter %[[#I32_PTR]] +; CHECK-NEXT: %[[#T1_X:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#T1_Y:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#T1_X_STRIDE:]] = OpIMul %[[#I64]] %[[#T1_X]] %[[#I64_C1]] +; CHECK-NEXT: %[[#T1_Y_STRIDE:]] = OpIMul %[[#I64]] %[[#T1_Y]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_OFFSET0:]] = OpIAdd %[[#I64]] %[[#T1_X_STRIDE]] %[[#T1_Y_STRIDE]] +; CHECK-NEXT: %[[#T1_POINTER:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_MR]] %[[#T1_OFFSET0]] +; CHECK: %[[#T1_SUBPTR1:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#I64_C0]] +; CHECK-NEXT: %[[#]] = OpSubgroupBlockReadINTEL %[[#I32]] %[[#T1_SUBPTR1]] +; CHECK-NEXT: %[[#T1_OFFSET1:]] = OpIAdd %[[#I64]] %[[#I64_C0]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_SUBPTR2:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#T1_OFFSET1]] +; CHECK-NEXT: %[[#]] = OpSubgroupBlockReadINTEL %[[#I32]] %[[#T1_SUBPTR2]] +; CHECK-NEXT: %[[#T1_OFFSET2:]] = OpIAdd %[[#I64]] %[[#T1_OFFSET1]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_SUBPTR3:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#T1_OFFSET2]] +; CHECK-NEXT: %[[#]] = OpSubgroupBlockReadINTEL %[[#I32]] %[[#T1_SUBPTR3]] +} + +func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : memref -> coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I32_PTR]] +; CHECK-NEXT: %[[#T2_X:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#T2_BEGIN:]] = OpLabel +; CHECK: %[[#T2_M_I32:]] = OpLoad %[[#I32]] %[[#]] Aligned 4 +; CHECK-NEXT: %[[#T2_M:]] = OpSConvert %[[#I64]] %[[#T2_M_I32]] +; CHECK-NEXT: %[[#T2_NEG_X:]] = OpSNegate %[[#I64]] %[[#T2_X]] +; CHECK-NEXT: %[[#T2_MCHECK1:]] = OpSGreaterThanEqual %[[#BOOL]] %[[#T2_M]] %[[#T2_NEG_X]] +; CHECK-NEXT: %[[#T2_MREM:]] = OpISub %[[#I64]] %[[#I64_C64]] %[[#T2_X]] +; CHECK-NEXT: %[[#T2_MCHECK2:]] = OpSLessThan %[[#BOOL]] %[[#T2_M]] %[[#T2_MREM]] +; CHECK-NEXT: %[[#T2_MCOND1:]] = OpLogicalAnd %[[#BOOL]] %[[#T2_MCHECK1]] %[[#T2_MCHECK2]] +; CHECK: OpBranchConditional %[[#T2_MCOND1]] %[[#T2_THEN1:]] %[[#T2_MERGE1:]] +; CHECK-NEXT: %[[#T2_THEN1]] = OpLabel +; CHECK: %[[#T2_V1:]] = OpLoad %[[#I32]] %[[#]] +; CHECK: %[[#T2_V2:]] = OpLoad %[[#I32]] %[[#]] +; CHECK-NEXT: OpBranch %[[#T2_MERGE1]] +; CHECK-NEXT: %[[#T2_MERGE1]] = OpLabel +; CHECK-NEXT: %[[#]] = OpPhi %[[#I32]] %[[#T2_V1]] %[[#T2_THEN1]] %[[#I32_NULL]] %[[#T2_BEGIN]] +; CHECK-NEXT: %[[#]] = OpPhi %[[#I32]] %[[#T2_V2]] %[[#T2_THEN1]] %[[#I32_NULL]] %[[#T2_BEGIN]] +; CHECK-NEXT: %[[#T2_M_B2:]] = OpIAdd %[[#I64]] %[[#T2_M]] %[[#I64_C16]] +; CHECK-NEXT: %[[#T2_NEG_M_B2:]] = OpSNegate %[[#I64]] %[[#T2_X]] +; CHECK-NEXT: %[[#T2_MCHECK3:]] = OpSGreaterThanEqual %[[#BOOL]] %[[#T2_M_B2]] %[[#T2_NEG_M_B2]] +; CHECK-NEXT: %[[#T2_MCHECK4:]] = OpSLessThan %[[#BOOL]] %[[#T2_M_B2]] %[[#T2_MREM]] +; CHECK-NEXT: %[[#]] = OpLogicalAnd %[[#BOOL]] %[[#T2_MCHECK3]] %[[#T2_MCHECK4]] +; CHECK: %[[#T2_MTHEN2:]] = OpLabel +; CHECK: %[[#T2_V3:]] = OpLoad %[[#I32]] %[[#]] +; CHECK: %[[#T2_V4:]] = OpLoad %[[#I32]] %[[#]] +; CHECK: %[[#]] = OpPhi %[[#I32]] %[[#T2_V3]] %[[#T2_MTHEN2]] %[[#I32_NULL]] %[[#T2_MERGE1]] +; CHECK-NEXT: %[[#]] = OpPhi %[[#I32]] %[[#T2_V4]] %[[#T2_MTHEN2]] %[[#I32_NULL]] %[[#T2_MERGE1]] +} + +func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : memref -> coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I32_PTR]] +; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#T3_Y:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#T3_BEGIN:]] = OpLabel +; CHECK: %[[#T3_KCHECK1:]] = OpSLessThanEqual %[[#BOOL]] %[[#I64_C0]] %[[#T3_Y]] +; CHECK-NEXT: %[[#T3_KREM:]] = OpISub %[[#I64]] %[[#I64_C48]] %[[#T3_Y]] +; CHECK-NEXT: %[[#T3_KCHECK2:]] = OpSLessThan %[[#BOOL]] %[[#I64_C0]] %[[#T3_KREM]] +; CHECK-NEXT: %[[#T3_KCOND:]] = OpLogicalAnd %[[#BOOL]] %[[#T3_KCHECK1]] %[[#T3_KCHECK2]] +; CHECK: OpBranchConditional %[[#T3_KCOND]] %[[#T3_THEN1:]] %[[#T3_MERGE1:]] +; CHECK-NEXT: %[[#T3_THEN1]] = OpLabel +; CHECK-NEXT: %[[#T3_SUBPTR1:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#]] %[[#I64_C0]] +; CHECK-NEXT: %[[#T3_V1:]] = OpSubgroupBlockReadINTEL %[[#I32]] %[[#T3_SUBPTR1]] +; CHECK-NEXT: OpBranch %[[#T3_MERGE1]] +; CHECK-NEXT: %[[#T3_MERGE1]] = OpLabel +; CHECK-NEXT: %[[#]] = OpPhi %[[#I32]] %[[#T3_V1]] %[[#T3_THEN1]] %[[#I32_NULL]] %[[#T3_BEGIN]] +} + +func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = cooperative_matrix_load.t %A[%x,%y] : memref -> coopmatrix +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#T4_M:]] = OpSConvert %[[#I64]] %[[#]] +; CHECK: %[[#T4_OFFSET:]] = OpIMul %[[#I64]] %[[#T4_M]] %[[#I64_C64]] +; CHECK: %[[#]] = OpIAdd %[[#I64]] %[[#T4_OFFSET]] %[[#I64_C1]] +} diff --git a/test/spv/cooperative_matrix_store.ir b/test/spv/cooperative_matrix_store.ir new file mode 100644 index 00000000..0958028c --- /dev/null +++ b/test/spv/cooperative_matrix_store.ir @@ -0,0 +1,37 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s + +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#I32_PTR:]] = OpTypePointer CrossWorkgroup %[[#I32]] +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C64:]] = OpConstant %[[#I64]] 64 +; CHECK: %[[#I64_C48:]] = OpConstant %[[#I64]] 48 +; CHECK: %[[#I64_C1:]] = OpConstant %[[#I64]] 1 +; CHECK: %[[#BOOL:]] = OpTypeBool + + +func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup_size(16) { + %0 = constant 1 -> coopmatrix + cooperative_matrix_store %0, %A[%x,%y] : coopmatrix, memref +; CHECK: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#T1_MR:]] = OpFunctionParameter %[[#I32_PTR]] +; CHECK-NEXT: %[[#T1_X:]] = OpFunctionParameter %[[#I64]] +; CHECK-NEXT: %[[#T1_Y:]] = OpFunctionParameter %[[#I64]] +; CHECK: %[[#T1_X_STRIDE:]] = OpIMul %[[#I64]] %[[#T1_X]] %[[#I64_C1]] +; CHECK-NEXT: %[[#T1_Y_STRIDE:]] = OpIMul %[[#I64]] %[[#T1_Y]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_OFFSET0:]] = OpIAdd %[[#I64]] %[[#T1_X_STRIDE]] %[[#T1_Y_STRIDE]] +; CHECK-NEXT: %[[#T1_POINTER:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_MR]] %[[#T1_OFFSET0]] +; CHECK: %[[#T1_M:]] = OpSConvert %[[#I64]] %[[#]] +; CHECK: %[[#T1_OFFSET0:]] = OpIMul %[[#I64]] %[[#T1_M]] %[[#I64_C1]] +; CHECK: %[[#T1_SUBPTR1:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#T1_OFFSET0]] +; CHECK-NEXT: OpStore %[[#T1_SUBPTR1]] %[[#]] +; CHECK-NEXT: %[[#T1_OFFSET1:]] = OpIAdd %[[#I64]] %[[#T1_OFFSET0]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_SUBPTR2:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#T1_OFFSET1]] +; CHECK-NEXT: OpStore %[[#T1_SUBPTR2]] %[[#]] +; CHECK-NEXT: %[[#T1_OFFSET2:]] = OpIAdd %[[#I64]] %[[#T1_OFFSET1]] %[[#I64_C64]] +; CHECK-NEXT: %[[#T1_SUBPTR3:]] = OpInBoundsPtrAccessChain %[[#I32_PTR]] %[[#T1_POINTER]] %[[#T1_OFFSET2]] +; CHECK-NEXT: OpStore %[[#T1_SUBPTR3]] %[[#]] +} + From 066aa74f0c71c95b4a0fe0eb71111ee1bb219647 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 14 Nov 2024 12:20:21 +0100 Subject: [PATCH 110/297] SPIR-V: Fix complex mul and implement mul_add Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 399 ++++++++++++++++++++++++++++++----------- src/spv/converter.hpp | 6 + src/spv/pass/capex.cpp | 1 + src/spv/pass/capex.hpp | 1 + src/spv/uniquifier.cpp | 4 + src/spv/uniquifier.hpp | 1 + test/spv/arith.ir | 33 +++- 7 files changed, 336 insertions(+), 109 deletions(-) diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 19f60ae2..6ee1ba23 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -191,8 +191,8 @@ auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty } throw compilation_error(loc, status::internal_compiler_error); }; - auto const make_float_complex = [&](arithmetic op, spv_inst *ty, spv_inst *a, - spv_inst *b) -> spv_inst * { + auto const make_float = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { switch (op) { case arithmetic::add: return mod_->add(ty, a, b); @@ -209,6 +209,35 @@ auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty } throw compilation_error(loc, status::ir_fp_unsupported); }; + auto const make_complex = [&](arithmetic op, spv_inst *ty, spv_inst *float_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod_->add(ty, a, b); + case arithmetic::sub: + return mod_->add(ty, a, b); + case arithmetic::mul: { + return make_complex_mul(ty, a, b); + } + case arithmetic::div: { + auto a_times_conj_b = make_complex_mul(ty, a, b, true); + + auto b_squared = mod_->add(ty, b, b); + auto b_squared_0 = + mod_->add(float_ty, b_squared, std::vector{0}); + auto b_squared_1 = + mod_->add(float_ty, b_squared, std::vector{1}); + spv_inst *b_abs = mod_->add(float_ty, b_squared_0, b_squared_1); + auto dummy = mod_->add(ty); + b_abs = mod_->add(ty, b_abs, dummy, std::vector{0}); + b_abs = mod_->add(ty, b_abs, dummy, std::vector{0, 0}); + return mod_->add(ty, a_times_conj_b, b_abs); + } + default: + break; + } + throw compilation_error(loc, status::ir_complex_unsupported); + }; switch (sty) { case scalar_type::i8: case scalar_type::i16: @@ -218,9 +247,101 @@ auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty return make_int(op, ty, a, b); case scalar_type::f32: case scalar_type::f64: + return make_float(op, ty, a, b); case scalar_type::c32: case scalar_type::c64: - return make_float_complex(op, ty, a, b); + return make_complex(op, ty, unique_.spv_ty(element_type(sty)), a, b); + } + throw compilation_error(loc, status::internal_compiler_error); +} + +auto inst_converter::make_complex_mul(spv_inst *ty, spv_inst *a, spv_inst *b, + bool conj_b) -> spv_inst * { + auto neg_a = mod_->add(ty, a); + auto a_times_i = + conj_b ? mod_->add(ty, a, neg_a, std::vector{1, 2}) + : mod_->add(ty, neg_a, a, std::vector{1, 2}); + auto b_1 = mod_->add(ty, b, b, std::vector{1, 1}); + auto b_1_a_times_i = mod_->add(ty, b_1, a_times_i); + auto b_0 = mod_->add(ty, b, b, std::vector{0, 0}); + return mod_->add(ty, unique_.opencl_ext(), + static_cast(OpenCLEntrypoint::fma), + std::vector{a, b_0, b_1_a_times_i}); +} + +auto inst_converter::make_cast(scalar_type to_ty, scalar_type a_ty, spv_inst *spv_to_ty, + spv_inst *a, location const &loc) -> spv_inst * { + auto const cast_from_int = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return mod_->add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = unique_.spv_ty(element_type(to_ty)); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(loc, status::ir_forbidden_cast); + }; + auto const cast_from_float = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return mod_->add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = unique_.spv_ty(element_type(to_ty)); + auto re = mod_->add(spv_float_ty, a); + return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(loc, status::ir_forbidden_cast); + }; + auto const cast_from_complex = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::c32: + case scalar_type::c64: + return mod_->add(spv_to_ty, a); + default: + throw compilation_error(loc, status::ir_forbidden_cast); + } + }; + if (a_ty == to_ty) { + return mod_->add(spv_to_ty, a); + } + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return cast_from_int(to_ty, spv_to_ty, a); + case scalar_type::f32: + case scalar_type::f64: + return cast_from_float(to_ty, spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return cast_from_complex(to_ty, spv_to_ty, a); + } } throw compilation_error(loc, status::internal_compiler_error); } @@ -314,7 +435,7 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { throw compilation_error(v.loc(), status::internal_compiler_error); } - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); return ::tinytc::visit( overloaded{[&](memref_data_type const &mr) -> dope_vector * { return &(dope_vec_[&v] = dope_vector{spv_index_ty, mr.shape(), mr.stride()}); @@ -335,11 +456,82 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { *v.ty()); } +auto inst_converter::make_mixed_precision_fma(scalar_type a_ty, scalar_type b_ty, scalar_type c_ty, + spv_inst *a, spv_inst *b, spv_inst *c, + location const &loc) -> spv_inst * { + auto const make_mul_same_type = [this, &loc](scalar_type mul_ty, spv_inst *spv_mul_ty, + spv_inst *a, spv_inst *b) -> spv_inst * { + switch (mul_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod_->add(spv_mul_ty, a, b); + case scalar_type::f32: + case scalar_type::f64: + return mod_->add(spv_mul_ty, a, b); + case scalar_type::c32: + case scalar_type::c64: + return make_complex_mul(spv_mul_ty, a, b); + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_mul_mixed_type = [this, &loc, &make_mul_same_type]( + scalar_type mul_ty, scalar_type a_ty, scalar_type b_ty, + spv_inst *a, spv_inst *b) { + const bool a_non_complex_b_complex = !is_complex_type(a_ty) && is_complex_type(b_ty); + const bool a_complex_b_non_complex = is_complex_type(a_ty) && !is_complex_type(b_ty); + if (a_non_complex_b_complex || a_complex_b_non_complex) { + if (a_complex_b_non_complex) { + std::swap(a, b); + std::swap(a_ty, b_ty); + } + mul_ty = element_type(mul_ty); + auto spv_mul_ty = unique_.spv_ty(mul_ty); + if (a_ty != mul_ty) { + a = make_cast(mul_ty, a_ty, spv_mul_ty, a, loc); + } + auto spv_b_ty = unique_.spv_ty(b_ty); + auto dummy = mod_->add(spv_b_ty); + a = mod_->add(spv_b_ty, a, dummy, std::vector{0}); + a = mod_->add(spv_b_ty, a, dummy, std::vector{0, 0}); + return make_mul_same_type(mul_ty, spv_mul_ty, a, b); + } else { + auto spv_mul_ty = unique_.spv_ty(mul_ty); + if (a_ty != mul_ty) { + a = make_cast(mul_ty, a_ty, spv_mul_ty, a, loc); + } + if (b_ty != mul_ty) { + b = make_cast(mul_ty, b_ty, spv_mul_ty, b, loc); + } + return make_mul_same_type(mul_ty, spv_mul_ty, a, b); + } + }; + + if (a_ty == b_ty && b_ty == c_ty && !is_complex_type(a_ty)) { + auto spv_c_ty = unique_.spv_ty(c_ty); + return mod_->add(spv_c_ty, unique_.opencl_ext(), + static_cast(OpenCLEntrypoint::fma), + std::vector{a, b, c}); + } + + auto mul_ty = compatible_type(a_ty, b_ty); + auto product = make_mul_mixed_type(mul_ty, a_ty, b_ty, a, b); + + auto add_ty = compatible_type(mul_ty, c_ty); + auto spv_add_ty = unique_.spv_ty(add_ty); + if (mul_ty != add_ty) { + product = make_cast(add_ty, mul_ty, spv_add_ty, product, loc); + } + return make_binary_op(add_ty, arithmetic::add, spv_add_ty, c, product, loc); +} + void inst_converter::make_store(store_flag flag, scalar_type sty, address_space as, spv_inst *pointer, spv_inst *value) { auto const split_re_im = [&]() -> std::array, 2u> { auto component_sty = element_type(sty); - auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); + auto float_ty = unique_.spv_ty(component_sty); const auto storage_cls = address_space_to_storage_class(as); auto pointer_ty = unique_.spv_pointer_ty(storage_cls, float_ty, alignment(component_sty)); auto c0 = unique_.constant(std::int32_t{0}); @@ -374,7 +566,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space break; } case store_flag::atomic_add: { - auto result_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), sty)); + auto result_ty = unique_.spv_ty(sty); auto scope = unique_.constant(static_cast(Scope::Workgroup)); auto semantics = unique_.constant(static_cast(MemorySemantics::Relaxed)); switch (sty) { @@ -393,7 +585,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space case scalar_type::c64: { auto re_im = split_re_im(); auto component_sty = element_type(sty); - auto float_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), component_sty)); + auto float_ty = unique_.spv_ty(component_sty); mod_->add(float_ty, re_im[0][0], scope, semantics, re_im[0][1]); mod_->add(float_ty, re_im[1][0], scope, semantics, re_im[1][1]); break; @@ -424,7 +616,7 @@ void inst_converter::operator()(alloca_inst const &in) { throw compilation_error(in.loc(), status::ir_insufficient_alignment); } - auto stack_element_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i8)); + auto stack_element_ty = unique_.spv_ty(scalar_type::i8); auto stack_ptr_ty = unique_.spv_pointer_ty(StorageClass::Workgroup, stack_element_ty, alignment(scalar_type::i8)); auto stack_ptr = mod_->add( @@ -529,7 +721,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { spv_inst *a) -> spv_inst * { switch (op) { case arithmetic_unary::abs: { - auto spv_a_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), sty)); + auto spv_a_ty = unique_.spv_ty(sty); auto a2 = mod_->add(spv_a_ty, a, a); auto a2_0 = mod_->add(ty, a2, std::vector{0}); auto a2_1 = mod_->add(ty, a2, std::vector{1}); @@ -541,8 +733,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { case arithmetic_unary::neg: return mod_->add(ty, a); case arithmetic_unary::conj: { - auto spv_float_ty = - unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(sty))); + auto spv_float_ty = unique_.spv_ty(element_type(sty)); auto a_im = mod_->add(spv_float_ty, a, std::vector{1}); auto neg_a_im = mod_->add(spv_float_ty, a_im); @@ -616,91 +807,12 @@ void inst_converter::operator()(barrier_inst const &in) { } void inst_converter::operator()(cast_inst const &in) { - auto const cast_from_int = [&](scalar_type to_ty, spv_inst *spv_to_ty, - spv_inst *a) -> spv_inst * { - switch (to_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return mod_->add(spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return mod_->add(spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - auto spv_float_ty = - unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(to_ty))); - auto re = mod_->add(spv_float_ty, a); - return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), - std::vector{0}); - } - } - throw compilation_error(in.loc(), status::ir_forbidden_cast); - }; - auto const cast_from_float = [&](scalar_type to_ty, spv_inst *spv_to_ty, - spv_inst *a) -> spv_inst * { - switch (to_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return mod_->add(spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return mod_->add(spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - auto spv_float_ty = - unique_.spv_ty(scalar_data_type::get(mod_->context(), element_type(to_ty))); - auto re = mod_->add(spv_float_ty, a); - return mod_->add(spv_to_ty, re, unique_.null_constant(spv_to_ty), - std::vector{0}); - } - } - throw compilation_error(in.loc(), status::ir_forbidden_cast); - }; - auto const cast_from_complex = [&](scalar_type to_ty, spv_inst *spv_to_ty, - spv_inst *a) -> spv_inst * { - switch (to_ty) { - case scalar_type::c32: - case scalar_type::c64: - return mod_->add(spv_to_ty, a); - default: - throw compilation_error(in.loc(), status::ir_forbidden_cast); - } - }; - auto const make = [&](scalar_type to_ty, scalar_type a_ty, spv_inst *spv_to_ty, - spv_inst *a) -> spv_inst * { - if (a_ty == to_ty) { - return mod_->add(spv_to_ty, a); - } - switch (a_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return cast_from_int(to_ty, spv_to_ty, a); - case scalar_type::f32: - case scalar_type::f64: - return cast_from_float(to_ty, spv_to_ty, a); - case scalar_type::c32: - case scalar_type::c64: { - return cast_from_complex(to_ty, spv_to_ty, a); - } - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - auto spv_to_ty = unique_.spv_ty(in.result(0).ty()); if (auto st = dyn_cast(in.result(0).ty()); st) { auto av = val(in.a()); auto a_ty = get_scalar_type(in.a()); - declare(in.result(0), make(st->ty(), a_ty, spv_to_ty, av)); + declare(in.result(0), make_cast(st->ty(), a_ty, spv_to_ty, av, in.loc())); } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { auto const length = ct->length(core_cfg_.subgroup_size); auto insts = std::vector{}; @@ -709,7 +821,7 @@ void inst_converter::operator()(cast_inst const &in) { auto &av = multi_val(in.a()); auto a_ty = get_coopmatrix_type(in.a())->component_ty(); for (std::int64_t i = 0; i < length; ++i) { - insts.emplace_back(make(ct->component_ty(), a_ty, spv_to_ty, av[i])); + insts.emplace_back(make_cast(ct->component_ty(), a_ty, spv_to_ty, av[i], in.loc())); } multi_declare(in.result(0), std::move(insts)); @@ -825,7 +937,7 @@ void inst_converter::operator()(constant_inst const &in) { void inst_converter::operator()(cooperative_matrix_load_inst const &in) { auto spv_boolean_ty = unique_.spv_ty(boolean_data_type::get(mod_->context())); - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_operand_ty = unique_.spv_ty(in.operand().ty()); auto spv_ty = unique_.spv_ty(in.result(0).ty()); auto ot = get_memref_type(in.operand()); @@ -865,7 +977,7 @@ void inst_converter::operator()(cooperative_matrix_load_inst const &in) { auto sub_pointer = mod_->add(spv_operand_ty, pointer, offset, std::vector{}); auto const cast_load_cast = [&](scalar_type int_sty) { - auto int_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), int_sty)); + auto int_ty = unique_.spv_ty(int_sty); const auto storage_cls = address_space_to_storage_class(ot->addrspace()); auto int_ptr_ty = unique_.spv_pointer_ty(storage_cls, int_ty, ot->alignment()); auto int_ptr = mod_->add(int_ptr_ty, sub_pointer); @@ -958,7 +1070,85 @@ void inst_converter::operator()(cooperative_matrix_load_inst const &in) { } multi_declare(in.result(0), std::move(loaded_values)); } -void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) {} +void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) { + auto at = get_coopmatrix_type(in.a()); + auto bt = get_coopmatrix_type(in.b()); + auto ct = get_coopmatrix_type(in.c()); + auto rt = get_coopmatrix_type(in.result(0)); + auto &av = multi_val(in.a()); + auto &bv = multi_val(in.b()); + auto &cv = multi_val(in.c()); + + const auto a_ty = at->component_ty(); + const auto b_ty = bt->component_ty(); + const auto b_component_ty = bt->component_ty(); + const auto c_ty = ct->component_ty(); + const auto r_ty = rt->component_ty(); + const auto spv_b_ty = unique_.spv_ty(bt->ty()); + const auto spv_b_component_ty = unique_.spv_ty(b_component_ty); + const auto spv_c_ty = unique_.spv_ty(ct->ty()); + const auto spv_r_ty = unique_.spv_ty(rt->ty()); + const bool a_and_b_complex = is_complex_type(a_ty) && is_complex_type(b_ty); + + const std::int64_t N = rt->cols(), K = at->cols(); + + const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); + constexpr std::int64_t nbb = 4; + auto broadcast_scope = unique_.constant(static_cast(Scope::Subgroup)); + + auto result = std::vector(cv.begin(), cv.end()); + auto result_im = a_and_b_complex + ? std::vector(cv.size(), unique_.null_constant(spv_c_ty)) + : std::vector{}; + for (std::int64_t m_block = 0; m_block < num_blocks; ++m_block) { + for (std::int64_t nb = 0; nb < N; nb += nbb) { + for (std::int64_t k = 0; k < K; ++k) { + for (std::int64_t n = 0; n < nbb; ++n) { + if (nb + n < N) { + auto const n_block = (nb + n) / core_cfg_.subgroup_size; + auto n_offset = unique_.constant((nb + n) % core_cfg_.subgroup_size); + + auto a = av[k + m_block * K]; + auto b = bv[k + n_block * K]; + auto b_bc = + mod_->add(spv_b_ty, broadcast_scope, b, n_offset); + auto &c = result[nb + n + m_block * N]; + + if (a_and_b_complex) { + auto &c_im = result_im[nb + n + m_block * N]; + auto b_bc_re = mod_->add( + spv_b_component_ty, b_bc, std::vector{0}); + auto b_bc_im = mod_->add( + spv_b_component_ty, b_bc, std::vector{1}); + c = make_mixed_precision_fma(a_ty, b_component_ty, c_ty, a, b_bc_re, c, + in.loc()); + c_im = make_mixed_precision_fma(a_ty, b_component_ty, c_ty, a, b_bc_im, + c, in.loc()); + } else { + c = make_mixed_precision_fma(a_ty, b_ty, c_ty, a, b_bc, c, in.loc()); + } + } + } + } + } + } + if (a_and_b_complex) { + for (std::size_t i = 0; i < result.size(); ++i) { + auto &c = result[i]; + auto c_im = result_im[i]; + auto neg_c_im = mod_->add(spv_c_ty, c_im); + auto c_im_times_i = mod_->add(spv_c_ty, neg_c_im, c_im, + std::vector{1, 2}); + c = mod_->add(spv_c_ty, c, c_im_times_i); + } + } + for (auto &r : result) { + if (c_ty != r_ty) { + r = make_cast(r_ty, c_ty, spv_r_ty, r, in.loc()); + } + } + multi_declare(in.result(0), std::move(result)); +} void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { auto av = val(in.a()); auto &bv = multi_val(in.b()); @@ -976,7 +1166,7 @@ void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { } void inst_converter::operator()(cooperative_matrix_store_inst const &in) { auto spv_boolean_ty = unique_.spv_ty(boolean_data_type::get(mod_->context())); - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_operand_ty = unique_.spv_ty(in.operand().ty()); auto spv_ty = unique_.spv_ty(in.val().ty()); auto ot = get_memref_type(in.operand()); @@ -1086,7 +1276,7 @@ void inst_converter::operator()(cooperative_matrix_store_inst const &in) { } void inst_converter::operator()(expand_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto shape = std::vector{}; auto stride = std::vector{}; @@ -1262,7 +1452,7 @@ void inst_converter::operator()(for_inst const &in) { } void inst_converter::operator()(fuse_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto shape = std::vector{}; auto stride = std::vector{}; @@ -1309,13 +1499,13 @@ void inst_converter::operator()(fuse_inst const &in) { void inst_converter::operator()(group_id_inst const &in) { auto gid = load_builtin(BuiltIn::GlobalInvocationId); - auto index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto index_ty = unique_.spv_ty(scalar_type::index); declare(in.result(0), mod_->add(index_ty, gid, std::vector{2})); } void inst_converter::operator()(group_size_inst const &in) { auto gs = load_builtin(BuiltIn::GlobalSize); - auto index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto index_ty = unique_.spv_ty(scalar_type::index); declare(in.result(0), mod_->add(index_ty, gs, std::vector{2})); } @@ -1376,7 +1566,7 @@ void inst_converter::operator()(if_inst const &in) { void inst_converter::operator()(lifetime_stop_inst const &) {} void inst_converter::operator()(load_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_pointer_index_ty = unique_.spv_pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, alignment(scalar_type::i64)); auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); @@ -1442,7 +1632,7 @@ void inst_converter::operator()(size_inst const &in) { } void inst_converter::operator()(store_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_pointer_ty = unique_.spv_ty(in.operand().ty()); auto dv = get_dope_vector(in.operand()); if (!dv) { @@ -1482,7 +1672,7 @@ void inst_converter::operator()(subgroup_size_inst const &in) { } void inst_converter::operator()(subview_inst const &in) { - auto spv_index_ty = unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::index)); + auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); auto shape_out = std::vector{}; @@ -1624,8 +1814,7 @@ void inst_converter::run_on_function(function_node const &fn, core_config const auto const make_stack = [&] { const auto high_water_mark = stack_high_water_mark{}.run_on_function(fn); if (high_water_mark > 0) { - auto stack_element_ty = - unique_.spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i8)); + auto stack_element_ty = unique_.spv_ty(scalar_type::i8); auto stack_array_ty = unique_.spv_array_ty(stack_element_ty, high_water_mark); auto stack_ptr_ty = unique_.spv_pointer_ty(StorageClass::Workgroup, stack_array_ty, diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 30677201..4aeefca3 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -128,6 +128,12 @@ class inst_converter { auto multi_val(tinytc_value const &v) -> std::vector &; auto make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b, location const &loc) -> spv_inst *; + auto make_complex_mul(spv_inst *ty, spv_inst *a, spv_inst *b, + bool conj_b = false) -> spv_inst *; + auto make_mixed_precision_fma(scalar_type a_ty, scalar_type b_ty, scalar_type c_ty, spv_inst *a, + spv_inst *b, spv_inst *c, location const &loc) -> spv_inst *; + auto make_cast(scalar_type to_ty, scalar_type a_ty, spv_inst *spv_to_ty, spv_inst *a, + location const &loc) -> spv_inst *; auto make_conditional_execution(spv_inst *returned_element_ty, spv_inst *condition, std::function()> conditional_code, location const &loc) -> std::vector; diff --git a/src/spv/pass/capex.cpp b/src/spv/pass/capex.cpp index 71d92bde..c355330b 100644 --- a/src/spv/pass/capex.cpp +++ b/src/spv/pass/capex.cpp @@ -50,6 +50,7 @@ void capex::operator()(OpExecutionMode const &in) { unique_->capability(cap); } } +void capex::operator()(OpGroupBroadcast const &) { unique_->capability(Capability::Groups); } void capex::operator()(OpGroupFAdd const &) { unique_->capability(Capability::Groups); } void capex::operator()(OpGroupIAdd const &) { unique_->capability(Capability::Groups); } void capex::operator()(OpInBoundsPtrAccessChain const &) { diff --git a/src/spv/pass/capex.hpp b/src/spv/pass/capex.hpp index 5a3e5ad1..f6cf1a99 100644 --- a/src/spv/pass/capex.hpp +++ b/src/spv/pass/capex.hpp @@ -19,6 +19,7 @@ class capex { void operator()(OpAtomicFAddEXT const &in); void operator()(OpEntryPoint const &in); void operator()(OpExecutionMode const &in); + void operator()(OpGroupBroadcast const &in); void operator()(OpGroupFAdd const &in); void operator()(OpGroupIAdd const &in); void operator()(OpInBoundsPtrAccessChain const &in); diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index 235cb9a7..cc818b25 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -255,5 +255,9 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { }); } +auto uniquifier::spv_ty(scalar_type sty) -> spv_inst * { + return spv_ty(scalar_data_type::get(mod_->context(), sty)); +} + } // namespace tinytc::spv diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp index 2cc72725..d805798f 100644 --- a/src/spv/uniquifier.hpp +++ b/src/spv/uniquifier.hpp @@ -45,6 +45,7 @@ class uniquifier { auto spv_pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) -> spv_inst *; auto spv_ty(const_tinytc_data_type_t ty) -> spv_inst *; + auto spv_ty(scalar_type sty) -> spv_inst *; private: template diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 62cf63cc..7a7113d7 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -13,6 +13,7 @@ func @tbool(%a: bool, %b: bool) { %0 = arith.and %a, %b : bool %1 = arith.or %a, %b : bool %2 = arith.xor %a, %b : bool +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#]] = OpLogicalAnd %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpLogicalOr %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpLogicalNotEqual %[[#BOOL]] %[[#]] %[[#]] @@ -29,6 +30,7 @@ func @tint(%a: i64, %b: i64) { %7 = arith.and %a, %b : i64 %8 = arith.or %a, %b : i64 %9 = arith.xor %a, %b : i64 +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#]] = OpIAdd %[[#I64]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpISub %[[#I64]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpIMul %[[#I64]] %[[#]] %[[#]] @@ -47,6 +49,7 @@ func @tfloat(%a: f32, %b: f32) { %2 = arith.mul %a, %b : f32 %3 = arith.div %a, %b : f32 %4 = arith.rem %a, %b : f32 +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFSub %[[#F32]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#]] %[[#]] @@ -59,10 +62,31 @@ func @tcomplex(%a: c32, %b: c32) { %1 = arith.sub %a, %b : c32 %2 = arith.mul %a, %b : c32 %3 = arith.div %a, %b : c32 -; CHECK: %[[#]] = OpFAdd %[[#C32]] %[[#]] %[[#]] -; CHECK-NEXT: %[[#]] = OpFSub %[[#C32]] %[[#]] %[[#]] -; CHECK-NEXT: %[[#]] = OpFMul %[[#C32]] %[[#]] %[[#]] -; CHECK-NEXT: %[[#]] = OpFDiv %[[#C32]] %[[#]] %[[#]] +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK-NEXT: %[[#TC_A:]] = OpFunctionParameter %[[#C32]] +; CHECK-NEXT: %[[#TC_B:]] = OpFunctionParameter %[[#C32]] +; CHECK: %[[#]] = OpFAdd %[[#C32]] %[[#TC_A]] %[[#TC_B]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#C32]] %[[#TC_A]] %[[#TC_B]] +; CHECK-NEXT: %[[#TC_NEG_A:]] = OpFNegate %[[#C32]] %[[#TC_A]] +; CHECK-NEXT: %[[#TC_A_TIMES_I:]] = OpVectorShuffle %[[#C32]] %[[#TC_NEG_A]] %[[#TC_A]] 1 2 +; CHECK-NEXT: %[[#TC_B_1:]] = OpVectorShuffle %[[#C32]] %[[#TC_B]] %[[#TC_B]] 1 1 +; CHECK-NEXT: %[[#TC_B_1_A_TIMES_I:]] = OpFMul %[[#C32]] %[[#TC_B_1]] %[[#TC_A_TIMES_I]] +; CHECK-NEXT: %[[#TC_B_0:]] = OpVectorShuffle %[[#C32]] %[[#TC_B]] %[[#TC_B]] 0 0 +; CHECK-NEXT: %[[#]] = OpExtInst %[[#C32]] %[[#]] fma %[[#TC_A]] %[[#TC_B_0]] %[[#TC_B_1_A_TIMES_I]] +; CHECK-NEXT: %[[#TC_NEG_A_2:]] = OpFNegate %[[#C32]] %[[#TC_A]] +; CHECK-NEXT: %[[#TC_A_TIMES_CONJ_I:]] = OpVectorShuffle %[[#C32]] %[[#TC_A]] %[[#TC_NEG_A_2]] 1 2 +; CHECK-NEXT: %[[#TC_B_1_2:]] = OpVectorShuffle %[[#C32]] %[[#TC_B]] %[[#TC_B]] 1 1 +; CHECK-NEXT: %[[#TC_B_1_2_A_TIMES_CONJ_I:]] = OpFMul %[[#C32]] %[[#TC_B_1_2]] %[[#TC_A_TIMES_CONJ_I]] +; CHECK-NEXT: %[[#TC_B_0_2:]] = OpVectorShuffle %[[#C32]] %[[#TC_B]] %[[#TC_B]] 0 0 +; CHECK-NEXT: %[[#TC_A_TIMES_CONJ_B:]] = OpExtInst %[[#C32]] %1 fma %[[#TC_A]] %[[#TC_B_0_2]] %[[#TC_B_1_2_A_TIMES_CONJ_I]] +; CHECK-NEXT: %[[#TC_B_SQUARED:]] = OpFMul %[[#C32]] %[[#TC_B]] %[[#TC_B]] +; CHECK-NEXT: %[[#TC_B_SQUARED_0:]] = OpCompositeExtract %[[#F32]] %[[#TC_B_SQUARED]] 0 +; CHECK-NEXT: %[[#TC_B_SQUARED_1:]] = OpCompositeExtract %[[#F32]] %[[#TC_B_SQUARED]] 1 +; CHECK-NEXT: %[[#TC_B_ABS:]] = OpFAdd %[[#F32]] %[[#TC_B_SQUARED_0]] %[[#TC_B_SQUARED_1]] +; CHECK-NEXT: %[[#TC_DUMMY:]] = OpUndef %[[#C32]] +; CHECK-NEXT: %[[#TC_B_ABS_1:]] = OpCompositeInsert %[[#C32]] %[[#TC_B_ABS]] %[[#TC_DUMMY]] 0 +; CHECK-NEXT: %[[#TC_B_ABS_2:]] = OpVectorShuffle %[[#C32]] %[[#TC_B_ABS_1]] %[[#TC_DUMMY]] 0 0 +; CHECK-NEXT: %[[#]] = OpFDiv %[[#C32]] %[[#TC_A_TIMES_CONJ_B]] %[[#TC_B_ABS_2]] } func @tfloatcoopmatrix() subgroup_size(16) { @@ -72,6 +96,7 @@ func @tfloatcoopmatrix() subgroup_size(16) { %3 = arith.sub %0, %1 : coopmatrix %4 = arith.mul %0, %1 : coopmatrix %5 = arith.div %0, %1 : coopmatrix +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFAdd %[[#F32]] %[[#]] %[[#]] From c74ab53270ca328e7d1c35ea319afda80b1385bf Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 14 Nov 2024 18:30:06 +0100 Subject: [PATCH 111/297] SPIR-V: coopmatrix mul_add test Signed-off-by: Carsten Uphoff --- src/spv/converter.cpp | 111 ++++++++++--------------- test/CMakeLists.txt | 1 + test/spv/cooperative_matrix_mul_add.ir | 98 ++++++++++++++++++++++ 3 files changed, 145 insertions(+), 65 deletions(-) create mode 100644 test/spv/cooperative_matrix_mul_add.ir diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 6ee1ba23..e41ea04a 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -459,68 +459,49 @@ auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { auto inst_converter::make_mixed_precision_fma(scalar_type a_ty, scalar_type b_ty, scalar_type c_ty, spv_inst *a, spv_inst *b, spv_inst *c, location const &loc) -> spv_inst * { - auto const make_mul_same_type = [this, &loc](scalar_type mul_ty, spv_inst *spv_mul_ty, - spv_inst *a, spv_inst *b) -> spv_inst * { - switch (mul_ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - return mod_->add(spv_mul_ty, a, b); - case scalar_type::f32: - case scalar_type::f64: - return mod_->add(spv_mul_ty, a, b); - case scalar_type::c32: - case scalar_type::c64: - return make_complex_mul(spv_mul_ty, a, b); - } - throw compilation_error(loc, status::internal_compiler_error); - }; - auto const make_mul_mixed_type = [this, &loc, &make_mul_same_type]( - scalar_type mul_ty, scalar_type a_ty, scalar_type b_ty, - spv_inst *a, spv_inst *b) { - const bool a_non_complex_b_complex = !is_complex_type(a_ty) && is_complex_type(b_ty); - const bool a_complex_b_non_complex = is_complex_type(a_ty) && !is_complex_type(b_ty); - if (a_non_complex_b_complex || a_complex_b_non_complex) { - if (a_complex_b_non_complex) { - std::swap(a, b); - std::swap(a_ty, b_ty); - } - mul_ty = element_type(mul_ty); - auto spv_mul_ty = unique_.spv_ty(mul_ty); - if (a_ty != mul_ty) { - a = make_cast(mul_ty, a_ty, spv_mul_ty, a, loc); - } - auto spv_b_ty = unique_.spv_ty(b_ty); - auto dummy = mod_->add(spv_b_ty); - a = mod_->add(spv_b_ty, a, dummy, std::vector{0}); - a = mod_->add(spv_b_ty, a, dummy, std::vector{0, 0}); - return make_mul_same_type(mul_ty, spv_mul_ty, a, b); - } else { - auto spv_mul_ty = unique_.spv_ty(mul_ty); - if (a_ty != mul_ty) { - a = make_cast(mul_ty, a_ty, spv_mul_ty, a, loc); - } - if (b_ty != mul_ty) { - b = make_cast(mul_ty, b_ty, spv_mul_ty, b, loc); - } - return make_mul_same_type(mul_ty, spv_mul_ty, a, b); - } - }; + const auto mul_ty = compatible_type(a_ty, b_ty); + const auto add_ty = compatible_type(mul_ty, c_ty); + auto spv_mul_ty = unique_.spv_ty(mul_ty); + auto spv_add_ty = unique_.spv_ty(add_ty); - if (a_ty == b_ty && b_ty == c_ty && !is_complex_type(a_ty)) { - auto spv_c_ty = unique_.spv_ty(c_ty); - return mod_->add(spv_c_ty, unique_.opencl_ext(), + // Normalize such that a is compatible with b and we cast to the type of b + if (a_ty != mul_ty && b_ty != mul_ty) { + throw compilation_error(loc, status::internal_compiler_error, + "compatible type must be either type of a or type of b"); + } + if (b_ty != mul_ty) { + std::swap(a, b); + std::swap(a_ty, b_ty); + } + + const bool a_non_complex_b_complex = !is_complex_type(a_ty) && is_complex_type(b_ty); + const bool a_complex_b_complex = is_complex_type(a_ty) && is_complex_type(b_ty); + const auto a_cast_ty = a_non_complex_b_complex ? element_type(mul_ty) : mul_ty; + auto spv_a_cast_ty = unique_.spv_ty(a_cast_ty); + + if (a_ty != a_cast_ty) { + a = make_cast(a_cast_ty, a_ty, spv_a_cast_ty, a, loc); + } + if (a_cast_ty != mul_ty) { + auto dummy = mod_->add(spv_mul_ty); + a = mod_->add(spv_mul_ty, a, dummy, std::vector{0}); + a = mod_->add(spv_mul_ty, a, dummy, std::vector{0, 0}); + } + if (!a_complex_b_complex && mul_ty == add_ty) { + return mod_->add(spv_mul_ty, unique_.opencl_ext(), static_cast(OpenCLEntrypoint::fma), std::vector{a, b, c}); } - auto mul_ty = compatible_type(a_ty, b_ty); - auto product = make_mul_mixed_type(mul_ty, a_ty, b_ty, a, b); + auto product = [&]() -> spv_inst * { + if (a_complex_b_complex) { + return make_complex_mul(spv_mul_ty, a, b); + } else if (a_non_complex_b_complex || is_floating_type(mul_ty)) { + return mod_->add(spv_mul_ty, a, b); + } + return mod_->add(spv_mul_ty, a, b); + }(); - auto add_ty = compatible_type(mul_ty, c_ty); - auto spv_add_ty = unique_.spv_ty(add_ty); if (mul_ty != add_ty) { product = make_cast(add_ty, mul_ty, spv_add_ty, product, loc); } @@ -1081,7 +1062,7 @@ void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) { const auto a_ty = at->component_ty(); const auto b_ty = bt->component_ty(); - const auto b_component_ty = bt->component_ty(); + const auto b_component_ty = element_type(b_ty); const auto c_ty = ct->component_ty(); const auto r_ty = rt->component_ty(); const auto spv_b_ty = unique_.spv_ty(bt->ty()); @@ -1090,20 +1071,20 @@ void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) { const auto spv_r_ty = unique_.spv_ty(rt->ty()); const bool a_and_b_complex = is_complex_type(a_ty) && is_complex_type(b_ty); - const std::int64_t N = rt->cols(), K = at->cols(); + const std::int32_t N = rt->cols(), K = at->cols(); - const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); - constexpr std::int64_t nbb = 4; + const std::int32_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); + constexpr std::int32_t nbb = 4; auto broadcast_scope = unique_.constant(static_cast(Scope::Subgroup)); auto result = std::vector(cv.begin(), cv.end()); auto result_im = a_and_b_complex ? std::vector(cv.size(), unique_.null_constant(spv_c_ty)) : std::vector{}; - for (std::int64_t m_block = 0; m_block < num_blocks; ++m_block) { - for (std::int64_t nb = 0; nb < N; nb += nbb) { - for (std::int64_t k = 0; k < K; ++k) { - for (std::int64_t n = 0; n < nbb; ++n) { + for (std::int32_t m_block = 0; m_block < num_blocks; ++m_block) { + for (std::int32_t nb = 0; nb < N; nb += nbb) { + for (std::int32_t k = 0; k < K; ++k) { + for (std::int32_t n = 0; n < nbb; ++n) { if (nb + n < N) { auto const n_block = (nb + n) / core_cfg_.subgroup_size; auto n_offset = unique_.constant((nb + n) % core_cfg_.subgroup_size); @@ -1123,7 +1104,7 @@ void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) { c = make_mixed_precision_fma(a_ty, b_component_ty, c_ty, a, b_bc_re, c, in.loc()); c_im = make_mixed_precision_fma(a_ty, b_component_ty, c_ty, a, b_bc_im, - c, in.loc()); + c_im, in.loc()); } else { c = make_mixed_precision_fma(a_ty, b_ty, c_ty, a, b_bc, c, in.loc()); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6015b3a0..a0774c5e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,6 +60,7 @@ if(SPIRVTools_FOUND) spv/calling_convention.ir spv/compare.ir spv/cooperative_matrix_load.ir + spv/cooperative_matrix_mul_add.ir spv/cooperative_matrix_scale.ir spv/cooperative_matrix_store.ir spv/expand.ir diff --git a/test/spv/cooperative_matrix_mul_add.ir b/test/spv/cooperative_matrix_mul_add.ir new file mode 100644 index 00000000..0eb745bb --- /dev/null +++ b/test/spv/cooperative_matrix_mul_add.ir @@ -0,0 +1,98 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -O0 -gspirv -S < %s | filecheck %s + +; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F32_C1:]] = OpConstant %[[#F32]] 0x1p+0 +; CHECK: %[[#F32_C2:]] = OpConstant %[[#F32]] 0x1p+1 +; CHECK: %[[#F32_C3:]] = OpConstant %[[#F32]] 0x1.8p+1 +; CHECK: %[[#I32:]] = OpTypeInt 32 0 +; CHECK: %[[#I32_C3:]] = OpConstant %[[#I32]] 3 +; CHECK: %[[#I32_C0:]] = OpConstant %[[#I32]] 0 +; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 +; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 +; CHECK: %[[#F32_C0:]] = OpConstant %[[#F32]] 0x0p+0 +; CHECK: %[[#C32_C1_0:]] = OpConstantComposite %[[#C32]] %[[#F32_C1]] %[[#F32_C0]] +; CHECK: %[[#C32_C3_0:]] = OpConstantComposite %[[#C32]] %[[#F32_C3]] %[[#F32_C0]] +; CHECK: %[[#I32_C2:]] = OpConstant %[[#I32]] 2 +; CHECK: %[[#C32_C1_0_2:]] = OpConstantComposite %[[#C32]] %[[#F32_C1]] %[[#F32_C0]] +; CHECK: %[[#C32_C2_0:]] = OpConstantComposite %[[#C32]] %[[#F32_C2]] %[[#F32_C0]] +; CHECK: %[[#C32_C3_0_2:]] = OpConstantComposite %[[#C32]] %[[#F32_C3]] %[[#F32_C0]] +; CHECK: %[[#C32_NULL:]] = OpConstantNull %[[#C32]] +; CHECK: %[[#I16:]] = OpTypeInt 16 0 +; CHECK: %[[#I16_C1:]] = OpConstant %[[#I16]] 1 +; CHECK: %[[#I64:]] = OpTypeInt 64 0 +; CHECK: %[[#I64_C3:]] = OpConstant %[[#I64]] 3 + +func @coopmatrix_mul_add_ff() subgroup_size(16) { + %a = constant 1.0 -> coopmatrix + %b = constant 2.0 -> coopmatrix + %c = constant 3.0 -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#FF_B0:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C0]] +; CHECK-NEXT: %[[#FF_C0:]] = OpExtInst %[[#F32]] %[[#]] fma %[[#F32_C1]] %[[#FF_B0]] %[[#F32_C3]] +; CHECK-NEXT: %[[#FF_B1:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C1]] +; CHECK-NEXT: %[[#FF_C1:]] = OpExtInst %[[#F32]] %[[#]] fma %[[#F32_C1]] %[[#FF_B1]] %[[#F32_C3]] +; CHECK-NEXT: %[[#FF_B2:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C0]] +; CHECK-NEXT: %[[#]] = OpExtInst %[[#F32]] %[[#]] fma %[[#F32_C1]] %[[#FF_B2]] %[[#FF_C0]] +; CHECK-NEXT: %[[#FF_B3:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C1]] +; CHECK-NEXT: %[[#]] = OpExtInst %[[#F32]] %[[#]] fma %[[#F32_C1]] %[[#FF_B3]] %[[#FF_C1]] +} + +func @coopmatrix_mul_add_cf() subgroup_size(16) { + %a = constant [1.0, 0.0] -> coopmatrix + %b = constant 2.0 -> coopmatrix + %c = constant [3.0, 0.0] -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#CF_B0:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C0]] +; CHECK-NEXT: %[[#CF_DUMMY:]] = OpUndef %[[#C32]] +; CHECK-NEXT: %[[#CF_B0_SPLAT1:]] = OpCompositeInsert %[[#C32]] %[[#CF_B0]] %[[#CF_DUMMY]] 0 +; CHECK-NEXT: %[[#CF_B0_SPLAT2:]] = OpVectorShuffle %[[#C32]] %[[#CF_B0_SPLAT1]] %[[#CF_DUMMY]] 0 0 +; CHECK-NEXT: %[[#]] = OpExtInst %[[#C32]] %[[#]] fma %[[#CF_B0_SPLAT2]] %[[#C32_C1_0]] %[[#C32_C3_0]] +} + +func @coopmatrix_mul_add_cc() subgroup_size(16) { + %a = constant [1.0, 0.0] -> coopmatrix + %b = constant [2.0, 0.0] -> coopmatrix + %c = constant [3.0, 0.0] -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#CC_B0:]] = OpGroupBroadcast %[[#C32]] %[[#I32_C3]] %[[#C32_C2_0]] %[[#I32_C0]] +; CHECK-NEXT: %[[#CC_B0_RE:]] = OpCompositeExtract %[[#F32]] %[[#CC_B0]] 0 +; CHECK-NEXT: %[[#CC_B0_IM:]] = OpCompositeExtract %[[#F32]] %[[#CC_B0]] 1 +; CHECK-NEXT: %[[#CC_DUMMY:]] = OpUndef %[[#C32]] +; CHECK-NEXT: %[[#CC_B0_RE_SPLAT1:]] = OpCompositeInsert %[[#C32]] %[[#CC_B0_RE]] %[[#CC_DUMMY]] 0 +; CHECK-NEXT: %[[#CC_B0_RE_SPLAT2:]] = OpVectorShuffle %[[#C32]] %[[#CC_B0_RE_SPLAT1]] %[[#CC_DUMMY]] 0 0 +; CHECK-NEXT: %[[#]] = OpExtInst %[[#C32]] %[[#]] fma %[[#CC_B0_RE_SPLAT2]] %[[#C32_C1_0_2]] %[[#C32_C3_0_2]] +; CHECK-NEXT: %[[#CC_DUMMY2:]] = OpUndef %[[#C32]] +; CHECK-NEXT: %[[#CC_B0_IM_SPLAT1:]] = OpCompositeInsert %[[#C32]] %[[#CC_B0_IM]] %[[#CC_DUMMY2]] 0 +; CHECK-NEXT: %[[#CC_B0_IM_SPLAT2:]] = OpVectorShuffle %[[#C32]] %[[#CC_B0_IM_SPLAT1]] %[[#CC_DUMMY2]] 0 0 +; CHECK-NEXT: %[[#]] = OpExtInst %[[#C32]] %[[#]] fma %[[#CC_B0_IM_SPLAT2]] %[[#C32_C1_0_2]] %[[#C32_NULL]] +; CHECK: %[[#CC_NEG_IM:]] = OpFNegate %[[#C32]] %[[#CC_IM:]] +; CHECK-NEXT: %[[#CC_IM_CONJ:]] = OpVectorShuffle %[[#C32]] %[[#CC_NEG_IM]] %[[#CC_IM]] 1 2 +; CHECK-NEXT: %[[#]] = OpFAdd %[[#C32]] %[[#]] %[[#CC_IM_CONJ]] +} + +func @coopmatrix_mul_add_ii_mixed() subgroup_size(16) { + %a = constant 1 -> coopmatrix + %b = constant 2 -> coopmatrix + %c = constant 3 -> coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c + : coopmatrix, coopmatrix, + coopmatrix -> coopmatrix +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#II_BC:]] = OpGroupBroadcast %[[#I32]] %[[#I32_C3]] %[[#I32_C2]] %[[#I32_C0]] +; CHECK-NEXT: %[[#II_A_I32:]] = OpSConvert %[[#I32]] %[[#I16_C1]] +; CHECK-NEXT: %[[#II_MUL:]] = OpIMul %[[#I32]] %[[#II_A_I32]] %[[#II_BC]] +; CHECK-NEXT: %[[#II_MUL_I64:]] = OpSConvert %[[#I64]] %[[#II_MUL]] +; CHECK-NEXT: %[[#]] = OpIAdd %[[#I64]] %[[#I64_C3]] %[[#II_MUL_I64]] +} From b4bd16c3f72f55824c7c83a250c70976a0facbbd Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 14 Nov 2024 19:48:29 +0100 Subject: [PATCH 112/297] Bugfix in lower linalg Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.hpp | 13 +++++------ include/tinytc/types.h | 3 +++ include/tinytc/types.hpp | 2 ++ src/compiler.cpp | 2 ++ src/error.cpp | 7 +++++- src/node/inst_node.cpp | 3 +++ src/pass/check_ir.cpp | 39 +++++++++++++++++++++++++++++-- src/pass/check_ir.hpp | 9 +++++++ src/pass/dump_ir.cpp | 8 ++++++- src/pass/lower_linalg.cpp | 22 ++++++++--------- test/opt/dead-code-elimination.ir | 2 +- 11 files changed, 87 insertions(+), 23 deletions(-) diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 988db08e..4709b472 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1914,27 +1914,26 @@ class region_builder { /** * @brief Build if with functor then(region_builder&) -> void * + * Note: If the if instruction returns values then we must have a "yield" instruction in both + * the "then" and the "else" branch. So to return values use the "ifelse" function. + * * @tparam F Functor type * @param condition Condition value * @param then Then region functor - * @param return_type_list Types of returned values * @param loc Source code location * * @return Returned values */ - template - auto if_condition(value condition, F &&then, array_view return_type_list = {}, - location const &loc = {}) -> std::vector { - auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); + template void if_condition(value condition, F &&then, location const &loc = {}) { + auto ii = ::tinytc::make_if(std::move(condition), {}, loc); auto reg = region{}; ii.get_regions(reg); if (!reg) { throw status::internal_compiler_error; } - auto results = add_multivalued(std::move(ii)); + reg_.add_instruction(std::move(ii)); auto bb = region_builder{reg}; then(bb); - return results; } /** * @brief Build if/else with functors then(region_builder&) -> void and diff --git a/include/tinytc/types.h b/include/tinytc/types.h index fb1ffe4a..7b16ccc2 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -87,6 +87,9 @@ typedef enum { tinytc_status_ir_incompatible_scalar_types = 0x126, ///< Incompatible scalar types tinytc_status_ir_constant_mismatch = 0x127, ///< Constant mismatch tinytc_status_ir_insufficient_alignment = 0x128, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x129, ///< Must have yield instruction + tinytc_status_ir_yield_in_else_branch_missing = + 0x130, ///< Must have yield instruction in else branch // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 31f20d85..7237b191 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -95,6 +95,8 @@ enum class status { ir_incompatible_scalar_types = tinytc_status_ir_incompatible_scalar_types, ir_constant_mismatch = tinytc_status_ir_constant_mismatch, ir_insufficient_alignment = tinytc_status_ir_insufficient_alignment, + ir_must_have_yield = tinytc_status_ir_must_have_yield, + ir_yield_in_else_branch_missing = tinytc_status_ir_yield_in_else_branch_missing, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/compiler.cpp b/src/compiler.cpp index 49db9efc..4a93ff1d 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -77,6 +77,8 @@ void apply_default_optimization_pipeline(tinytc_prog_t prg, const_tinytc_core_in run_function_pass(cpp, *prg); run_function_pass(dead_code_elimination_pass{}, *prg); } + + run_function_pass(check_ir_pass{}, *prg); } } // namespace tinytc diff --git a/src/error.cpp b/src/error.cpp index 9c4e0571..5fe7bef0 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -157,7 +157,8 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_ir_unexpected_yield: return "Yield encountered in non-yielding region"; case tinytc_status_ir_yield_mismatch: - return "Number of yielded values does not match number of values yielded by region"; + return "Number of yielded values does not match number of values yielded by region or the " + "types are different"; case tinytc_status_ir_subview_mismatch: return "Number of dynamic offsets and sizes must match number of dynamic operands"; case tinytc_status_ir_invalid_slice: @@ -204,6 +205,10 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Type of constant does not match type of returned value"; case tinytc_status_ir_insufficient_alignment: return "Pointer does not satisfy minimum alignment requirements"; + case tinytc_status_ir_must_have_yield: + return "Last instruction of region that returns values must be \"yield\""; + case tinytc_status_ir_yield_in_else_branch_missing: + return "Else-branch must have yield instruction if then-branch has yield instruction"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 639d12c0..be1756ab 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -161,6 +161,7 @@ loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_va body().set_num_params(1 + init_values.size()); body().set_param(0, loop_var_type, lc); + body().loc(lc); for (std::size_t i = 0; i < init_values.size(); ++i) { body().set_param(1 + i, init_values[i]->ty(), lc); result(i) = value_node{init_values[i]->ty(), this, lc}; @@ -838,6 +839,8 @@ if_inst::if_inst(tinytc_value_t condition, array_view return : standard_inst{IK::if_, 1, static_cast(return_types.size())} { op(0, condition); loc(lc); + then().loc(lc); + otherwise().loc(lc); if (!isa(*condition->ty())) { throw compilation_error(loc(), status::ir_expected_boolean); } diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index e8100376..b2938f7f 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -3,13 +3,46 @@ #include "pass/check_ir.hpp" #include "error.hpp" -#include "node/inst_node.hpp" -#include "node/region_node.hpp" +#include "support/casting.hpp" +#include "support/visit.hpp" #include "support/walk.hpp" #include "tinytc/types.hpp" namespace tinytc { +void check_ir_pass::check_yield(region_node const ®, inst_node const &in, + status yield_missing_status) { + auto last_inst = --reg.end(); + if (last_inst == reg.end()) { + throw compilation_error(reg.loc(), yield_missing_status); + } + auto yield = dyn_cast(last_inst.get()); + if (!yield) { + throw compilation_error(reg.loc(), yield_missing_status); + } + if (yield->num_operands() != in.num_results()) { + throw compilation_error(reg.loc(), status::ir_yield_mismatch); + } + for (std::int64_t i = 0; i < in.num_results(); ++i) { + if (yield->op(i).ty() != in.result(i).ty()) { + throw compilation_error(reg.loc(), status::ir_yield_mismatch); + } + } +} + +void check_ir_pass::operator()(inst_node const &) {} +void check_ir_pass::operator()(for_inst const &in) { + if (in.num_results() > 0) { + check_yield(in.body(), in); + } +} +void check_ir_pass::operator()(if_inst const &in) { + if (in.num_results() > 0) { + check_yield(in.then(), in); + check_yield(in.otherwise(), in, status::ir_yield_in_else_branch_missing); + } +} + void check_ir_pass::run_on_function(function_node &fn) { walk(fn, [this](inst_node const &i, walk_stage const &stage) { const bool child_region_is_spmd_region = @@ -30,6 +63,8 @@ void check_ir_pass::run_on_function(function_node &fn) { if (child_region_is_spmd_region && stage.is_after_all_regions()) { inside_spmd_region_ = false; } + + visit(*this, i); }); } diff --git a/src/pass/check_ir.hpp b/src/pass/check_ir.hpp index c58b4a90..9b9f4c49 100644 --- a/src/pass/check_ir.hpp +++ b/src/pass/check_ir.hpp @@ -5,14 +5,23 @@ #define CHECK_IR_20240222_HPP #include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" namespace tinytc { class check_ir_pass { public: + void operator()(inst_node const &in); + void operator()(for_inst const &in); + void operator()(if_inst const &in); + void run_on_function(function_node &fn); private: + void check_yield(region_node const ®, inst_node const &in, + status yield_missing_status = status::ir_must_have_yield); + bool inside_spmd_region_ = false; }; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 7a92b58d..911a98d3 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -236,7 +236,7 @@ void dump_ir_pass::operator()(cooperative_matrix_mul_add_inst const &c) { *os_ << ", "; visit(*this, *c.b().ty()); *os_ << ", "; - visit(*this, *c.b().ty()); + visit(*this, *c.c().ty()); *os_ << " -> "; visit(*this, *c.result(0).ty()); } @@ -412,6 +412,12 @@ void dump_ir_pass::operator()(if_inst const &in) { *os_ << "if "; dump_val(in.condition()); *os_ << " "; + if (in.num_results() > 0) { + *os_ << "-> ("; + do_with_infix(in.result_begin(), in.result_end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ") "; + } dump_region(in.then()); if (!in.is_otherwise_empty()) { *os_ << " else "; diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 4ed4eb5f..c6c4acc0 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -52,10 +52,10 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi }(); auto coopmatrix_c_ty = get_coopmatrix(c_ty, m_block_size, n_block_size, matrix_use::acc, loc); auto const compute_c = [&](region_builder &bb, std::int32_t k_block_size, value K0, value K1, - value c_init) -> value { + value c_acc) -> value { auto c_step = bb.add(make_constant(k_block_size, index_ty, loc)); auto return_values = bb.for_loop( - K0, K1, c_step, {c_init}, index_ty, [&](region_builder &bb, array_view p) { + K0, K1, c_step, {c_acc}, index_ty, [&](region_builder &bb, array_view p) { const auto k = p[0]; value pos_a[2] = {m_block, k}; @@ -82,7 +82,7 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi return return_values[0]; }; - auto c_init = bb.add(make_constant_zero(coopmatrix_c_ty, loc)); + auto c_acc = bb.add(make_constant_zero(coopmatrix_c_ty, loc)); auto k_block_size = max_K_unrolling; @@ -95,25 +95,25 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto c_k_block_size = bb.add(make_constant(k_block_size, index_ty, loc)); auto tmp = instant_constant_fold_add(bb, make_arith(arithmetic::div, K, c_k_block_size, loc)); auto K0 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); - c_init = compute_c(bb, k_block_size, c_zero, K0, c_init); + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc); auto needs_remainder = instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, loc)); auto r = get_bool_constant(needs_remainder); if (r) { if (*r != 0) { - c_init = compute_c(bb, 1, K0, K, c_init); + c_acc = compute_c(bb, 1, K0, K, c_acc); } } else { - auto remainder = bb.if_condition( + auto remainder = bb.ifelse( needs_remainder, [&](region_builder &bb) { - auto c_next = compute_c(bb, 1, K0, K, c_init); + auto c_next = compute_c(bb, 1, K0, K, c_acc); bb.add(make_yield(c_next, loc)); }, - {coopmatrix_c_ty}, loc); - c_init = remainder[0]; + [&](region_builder &bb) { bb.add(make_yield(c_acc, loc)); }, {coopmatrix_c_ty}, loc); + c_acc = remainder[0]; } - auto alpha_ab = mixed_precision_coopmatrix_scale(bb, alpha, c_init, loc); + auto alpha_ab = mixed_precision_coopmatrix_scale(bb, alpha, c_acc, loc); if (atomic) { auto flag = get_atomic_store_flag(beta); if (!flag) { @@ -437,7 +437,7 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { [&](region_builder &bb) { blas_update(bb, in.atomic(), &in.alpha(), sum, &in.beta(), &in.B(), {}, in.loc()); }, - {}, in.loc()); + in.loc()); } else if (bt->dim() == 1) { auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); auto c_trip_count = instant_constant_fold_add( diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index 7ce5d69d..cca54f1b 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -33,7 +33,7 @@ func @dead_if_with_yield(%a: memref) { store %0, %a[] : memref ; Cannot eliminate if that returns results currently ; CHECK-LABEL: func @dead_if_with_yield({{.*}} -; CHECK: %0 = if %c0 { +; CHECK: %0 = if %c0 -> (f64) { ; CHECK-NEXT: %c42 = constant 0x1.5p+5 -> f64 ; CHECK-NEXT: yield %c42 : f64 ; CHECK-NEXT: } else { From 07fe01285d13118aa1e8bc51aec54063d18cbb85 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 19 Nov 2024 13:26:26 +0100 Subject: [PATCH 113/297] Add instruction cloner; rework foreach instruction; add lower foreach pass Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 27 +++- include/tinytc/tinytc.h | 16 ++- include/tinytc/tinytc.hpp | 51 ++++++-- include/tinytc/types.h | 3 +- include/tinytc/types.hpp | 1 + src/CMakeLists.txt | 2 + src/compiler.cpp | 2 + src/error.cpp | 2 + src/inst.cpp | 13 +- src/node/inst_node.cpp | 128 +++++++++++-------- src/node/inst_node.hpp | 116 ++++++++++++----- src/parser/parser_impl.yy | 11 +- src/pass/clone.cpp | 199 +++++++++++++++++++++++++++++ src/pass/clone.hpp | 77 +++++++++++ src/pass/convert_to_opencl.cpp | 5 +- src/pass/dead_code_elimination.cpp | 4 +- src/pass/dump_ir.cpp | 17 +-- src/pass/lower_foreach.cpp | 149 +++++++++++++++++++++ src/pass/lower_foreach.hpp | 23 ++++ src/passes.def | 1 + 20 files changed, 717 insertions(+), 130 deletions(-) create mode 100644 src/pass/clone.cpp create mode 100644 src/pass/clone.hpp create mode 100644 src/pass/lower_foreach.cpp create mode 100644 src/pass/lower_foreach.hpp diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 1717264f..df833c53 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -393,20 +393,39 @@ Foreach .. code:: abnf - instruction =/ "foreach" local-identifier "=" local-identifier "," local-identifier + instruction =/ "foreach" "(" local-identifier-list ")" "=" + "(" local-identifier-list ")" "," "(" local-identifier-list ")" [":" integer-type] region Overview ~~~~~~~~ -A foreach loop that executes the loop's range [from; to) without any sequence guarantee. +A foreach loop that executes the loop's range without any sequence guarantee. The region of a foreach is a *spmd region*. -The trip count is stored in the first local identifier and is accessible within the loop body. -The loop's range [from; to) is given by the first and the second local identifier after the equals sign. +The three local identifier lists define the loop range and the local identifiers that +make the trip count available within the loop body. +All three lists must have the same length and have the following format: + +.. math:: + + (\text{var}_1, \dots, \text{var}_N) = (\text{from}_1, \dots, \text{from}_N), + (\text{to}_1, \dots, \text{to}_N), + +where :math:`N` is the common length of each of the three lists. +The loop range is defined as the cartesian product of the half-open intervals +:math:`[\text{from}_i; \text{to}_i)` such that the trip count take the values + +.. math:: + + (\text{var}_1, \dots, \text{var}_N) \in [\text{from}_1; \text{to}_1) \times \dots \times + [\text{from}_N; \text{to}_N) + The integer type of the loop variable and the loop bounds is given after the colon and the default integer type is ``index``. +The mapping of trip count to work-item is implementation-defined. + GEMM .... diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 027129cb..1f1bc5ca 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -848,21 +848,23 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt * @brief Create foreach loop * * @code - * foreach %loop_var = %from, %to : loop_var_type { } - * ; loop_var_type == type(%from) - * ; loop_var_type == type(%to) + * foreach (loop_var_list) = (from_list), (to_list) : loop_var_type { } + * ; loop_var_type == type(%f) forall %f in from_list + * ; loop_var_type == type(%t) forall %t in to_list * @endcode * * @param instr [out] pointer to the inst object created - * @param from [in] loop begion - * @param to [in] loop bound + * @param dim [in] length of from and to array; must be > 0 + * @param from_list [in][range(1, dim)] loop begion + * @param to_list [in][range(1, dim)] loop bound * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t from, - tinytc_value_t to, +TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, uint32_t dim, + const tinytc_value_t *from_list, + const tinytc_value_t *to_list, tinytc_data_type_t loop_var_type, const tinytc_location_t *loc); diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 4709b472..98ccc2b8 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1528,16 +1528,32 @@ inline inst make_for(value from, value to, value step, array_view initial /** * @brief Make foreach loop instruction * - * @param from Loop variable start - * @param to Loop variable bound + * @param from_list List of loop variable start + * @param to_list List of loop variable bound * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ -inline inst make_foreach(value from, value to, data_type loop_var_type, location const &loc = {}) { +inline inst make_foreach(array_view from_list, array_view to_list, + data_type loop_var_type, location const &loc = {}) { + tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from, to, loop_var_type, &loc), loc); + if (from_list.size() != to_list.size()) { + throw std::invalid_argument("from list must have the same length as the to list"); + } + const auto from_len = from_list.size(); + if (from_len > std::numeric_limits::max()) { + throw std::out_of_range("from list too long"); + } + const auto to_len = to_list.size(); + if (to_len > std::numeric_limits::max()) { + throw std::out_of_range("to list too long"); + } + const tinytc_value_t *fl = reinterpret_cast(from_list.data()); + const tinytc_value_t *tl = reinterpret_cast(to_list.data()); + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from_len, fl, tl, loop_var_type, &loc), + loc); return inst(instr); } @@ -1887,28 +1903,30 @@ class region_builder { return results; } /** - * @brief Build foreach-loop with functor f(region_builder&, value) -> void + * @brief Build foreach-loop with functor f(region_builder&, array_view) -> void * * @tparam F Functor type - * @param from Loop variable start - * @param to Loop variable bound + * @param from Loop variable start list + * @param to Loop variable bound list * @param loop_var_ty Type of loop variable * @param f functor * @param loc Source code location */ template - void foreach (value from, value to, data_type loop_var_ty, F && f, location const &loc = {}) { - auto fi = ::tinytc::make_foreach(from, to, loop_var_ty, loc); + void foreach (array_view from, array_view to, data_type loop_var_ty, F && f, + location const &loc = {}) { + auto fi = ::tinytc::make_foreach(std::move(from), std::move(to), loop_var_ty, loc); auto reg = region{}; fi.get_regions(reg); - auto loop_var = value{}; - reg.get_parameters(loop_var); - if (!reg || !loop_var) { + auto num_params = reg.get_parameters({}); + auto params = std::vector(num_params); + reg.get_parameters(params); + if (!reg || num_params != from.size() || num_params != to.size()) { throw status::internal_compiler_error; } reg_.add_instruction(std::move(fi)); auto bb = region_builder{reg}; - f(bb, loop_var); + f(bb, array_view(params)); } /** @@ -1967,6 +1985,13 @@ class region_builder { return results; } + /** + * @brief Get region + * + * @return Region + */ + inline auto get_region() -> region { return reg_; } + private: region reg_; }; diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 7b16ccc2..270e839e 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -89,7 +89,8 @@ typedef enum { tinytc_status_ir_insufficient_alignment = 0x128, ///< Insufficient alignment tinytc_status_ir_must_have_yield = 0x129, ///< Must have yield instruction tinytc_status_ir_yield_in_else_branch_missing = - 0x130, ///< Must have yield instruction in else branch + 0x130, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x131, ///< size(from) != size(to) in foreach // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 7237b191..47ba1677 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -97,6 +97,7 @@ enum class status { ir_insufficient_alignment = tinytc_status_ir_insufficient_alignment, ir_must_have_yield = tinytc_status_ir_must_have_yield, ir_yield_in_else_branch_missing = tinytc_status_ir_yield_in_else_branch_missing, + ir_from_to_mismatch = tinytc_status_ir_from_to_mismatch, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e49b6479..80c330e3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,7 @@ set(SOURCES parser/parse_context.cpp parser.cpp pass/check_ir.cpp + pass/clone.cpp pass/constant_folding.cpp pass/constant_propagation.cpp pass/convert_to_opencl.cpp @@ -53,6 +54,7 @@ set(SOURCES pass/dump_ir.cpp pass/insert_barrier.cpp pass/insert_lifetime_stop.cpp + pass/lower_foreach.cpp pass/lower_linalg.cpp pass/slot_tracker.cpp pass/stack.cpp diff --git a/src/compiler.cpp b/src/compiler.cpp index 4a93ff1d..a8ed5aa3 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -14,6 +14,7 @@ #include "pass/dump_ir.hpp" #include "pass/insert_barrier.hpp" #include "pass/insert_lifetime_stop.hpp" +#include "pass/lower_foreach.hpp" #include "pass/lower_linalg.hpp" #include "pass/stack.hpp" #include "pass/work_group_size.hpp" @@ -73,6 +74,7 @@ void apply_default_optimization_pipeline(tinytc_prog_t prg, const_tinytc_core_in run_function_pass(work_group_size_pass{info}, *prg); run_function_pass(lower_linalg_pass{info}, *prg); + run_function_pass(lower_foreach_pass{info}, *prg); if (opt_level >= 1) { run_function_pass(cpp, *prg); run_function_pass(dead_code_elimination_pass{}, *prg); diff --git a/src/error.cpp b/src/error.cpp index 5fe7bef0..2e4dfe26 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -209,6 +209,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Last instruction of region that returns values must be \"yield\""; case tinytc_status_ir_yield_in_else_branch_missing: return "Else-branch must have yield instruction if then-branch has yield instruction"; + case tinytc_status_ir_from_to_mismatch: + return "length(from) must equal length(to) and length must be greater than 0"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/inst.cpp b/src/inst.cpp index 2260c573..570574b8 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -639,15 +639,20 @@ tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from }); } -tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t from, - tinytc_value_t to, tinytc_data_type_t loop_var_type, +tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, uint32_t dim, + const tinytc_value_t *from_list, + const tinytc_value_t *to_list, + tinytc_data_type_t loop_var_type, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr) { + if (instr == nullptr || loop_var_type == nullptr || from_list == nullptr || + to_list == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { *instr = - std::make_unique(from, to, loop_var_type, get_optional(loc)).release(); + std::make_unique(array_view{from_list, dim}, array_view{to_list, dim}, + loop_var_type, get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index be1756ab..ed2e4660 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -147,50 +147,6 @@ blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinyt } } -loop_inst::loop_inst(IK tid, tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, - array_view init_values, tinytc_data_type_t loop_var_type, - location const &lc) - : standard_inst{tid, (step0 ? 3 : 2) + static_cast(init_values.size()), - static_cast(init_values.size())} { - - op(op_from, from0); - op(op_to, to0); - if (step0) { - op(op_step, step0); - } - - body().set_num_params(1 + init_values.size()); - body().set_param(0, loop_var_type, lc); - body().loc(lc); - for (std::size_t i = 0; i < init_values.size(); ++i) { - body().set_param(1 + i, init_values[i]->ty(), lc); - result(i) = value_node{init_values[i]->ty(), this, lc}; - } - for (std::size_t i = 0; i < init_values.size(); ++i) { - if (!isa(*init_values[i]->ty()) && - !isa(*init_values[i]->ty()) && - !isa(*init_values[i]->ty())) { - throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); - } - op(op_init() + i, init_values[i]); - } - loc(lc); - - auto lvt = get_scalar_type(loc(), loop_var()); - auto fromt = get_scalar_type(loc(), from()); - auto tot = get_scalar_type(loc(), to()); - bool step_ok = true; - if (has_step()) { - auto stept = get_scalar_type(loc(), step()); - step_ok = lvt->ty() == stept->ty(); - } - - if (!is_integer_type(lvt->ty()) || lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || - !step_ok) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } -} - alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); @@ -658,6 +614,84 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, result(0) = value_node{result_ty, this, lc}; } +for_inst::for_inst(tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, + array_view init_values, tinytc_data_type_t loop_var_type, + location const &lc) + : loop_inst{IK::for_loop, (step0 ? 3 : 2) + static_cast(init_values.size()), + static_cast(init_values.size())} { + op(op_from, from0); + op(op_to, to0); + if (step0) { + op(op_step, step0); + } + + body().set_num_params(1 + init_values.size()); + body().set_param(0, loop_var_type, lc); + body().loc(lc); + for (std::size_t i = 0; i < init_values.size(); ++i) { + body().set_param(1 + i, init_values[i]->ty(), lc); + result(i) = value_node{init_values[i]->ty(), this, lc}; + } + for (std::size_t i = 0; i < init_values.size(); ++i) { + if (!isa(*init_values[i]->ty()) && + !isa(*init_values[i]->ty()) && + !isa(*init_values[i]->ty())) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + } + op(op_init() + i, init_values[i]); + } + loc(lc); + + auto lvt = get_scalar_type(loc(), loop_var()); + auto fromt = get_scalar_type(loc(), from()); + auto tot = get_scalar_type(loc(), to()); + bool step_ok = true; + if (has_step()) { + auto stept = get_scalar_type(loc(), step()); + step_ok = lvt->ty() == stept->ty(); + } + + if (!is_integer_type(lvt->ty()) || lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || + !step_ok) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } +} + +foreach_inst::foreach_inst(array_view from, array_view to, + tinytc_data_type_t loop_var_type, location const &lc) + : loop_inst{IK::foreach_loop, static_cast(from.size() + to.size()), + std::int64_t{0}} { + std::int64_t op_no = 0; + for (auto &v : from) { + op(op_no++, v); + } + for (auto &v : to) { + op(op_no++, v); + } + body().set_num_params(from.size()); + for (std::int64_t i = 0; i < static_cast(from.size()); ++i) { + body().set_param(i, loop_var_type, lc); + } + body().loc(lc); + child_region(0).kind(region_kind::spmd); + loc(lc); + + if (from.size() == 0 || from.size() != to.size()) { + throw compilation_error(loc(), status::ir_from_to_mismatch); + } + + if (auto lv_ty = dyn_cast(loop_var_type); lv_ty) { + if (!is_integer_type(lv_ty->ty()) || + std::any_of(op_begin(), op_end(), [&loop_var_type](tinytc_value &val) { + return val.ty() != loop_var_type; + })) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + } else { + throw compilation_error(loc(), status::ir_expected_scalar); + } +} + fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, location const &lc) : standard_inst{IK::fuse}, from_(from), to_(to) { op(0, op0); @@ -803,12 +837,6 @@ ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, } } -foreach_inst::foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, - location const &loc) - : loop_inst{IK::foreach_loop, std::move(from), std::move(to), nullptr, {}, loop_var_type, loc} { - child_region(0).kind(region_kind::spmd); -} - hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 63040706..e089ad84 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -332,29 +332,11 @@ class loop_inst : public standard_inst { inline static bool classof(inst_node const &i) { return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; } - enum op_number { op_from = 0, op_to = 1, op_step = 2 }; - loop_inst(IK tid, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - array_view init_values, tinytc_data_type_t loop_var_type, - location const &loc = {}); - inline auto from() const -> tinytc_value const & { return op(op_from); } - inline auto to() const -> tinytc_value const & { return op(op_to); } - inline auto has_step() const -> bool { return op_init() == 3; } - inline auto step() const -> tinytc_value const & { return op(op_step); } + inline loop_inst(IK tid, std::int64_t num_operands, std::int64_t num_results) + : standard_inst{tid, num_operands, num_results} {} + inline auto body() -> tinytc_region & { return child_region(0); } inline auto body() const -> tinytc_region const & { return child_region(0); } - inline auto loop_var() -> tinytc_value & { return body().param(0); } - inline auto loop_var() const -> tinytc_value const & { return body().param(0); } - inline auto iter_arg(std::int64_t no) -> tinytc_value & { return body().param(no + 1); } - inline auto iter_arg(std::int64_t no) const -> tinytc_value const & { - return body().param(no + 1); - } - inline auto iter_init(std::int64_t no) -> tinytc_value & { return op(op_init() + no); } - inline auto iter_init(std::int64_t no) const -> tinytc_value const & { - return op(op_init() + no); - } - - private: - inline auto op_init() const -> std::int64_t { return num_operands() - num_results(); } }; class alloca_inst : public standard_inst<0, 1> { @@ -404,6 +386,7 @@ class arith_unary_inst : public standard_inst<1, 1> { arith_unary_inst(arithmetic_unary op, tinytc_value_t a, location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } private: @@ -433,6 +416,8 @@ class cast_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } enum op_number { op_a = 0 }; cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } }; @@ -443,7 +428,9 @@ class compare_inst : public standard_inst<2, 1> { compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); inline cmp_condition cond() const { return cond_; } + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } inline auto b() const -> tinytc_value const & { return op(op_b); } private: @@ -474,10 +461,14 @@ class cooperative_matrix_load_inst : public standard_inst<3, 1, 0> { cooperative_matrix_load_inst(transpose t, checked_flag flag, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, location const &lc = {}); + inline auto t() const -> transpose { return t_; } inline auto checked() const -> checked_flag { return flag_; } + inline auto operand() -> tinytc_value & { return op(op_operand); } inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() -> tinytc_value & { return op(op_pos0); } inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() -> tinytc_value & { return op(op_pos1); } inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } private: @@ -493,8 +484,12 @@ class cooperative_matrix_mul_add_inst : public standard_inst<3, 1, 0> { enum op_number { op_a = 0, op_b = 1, op_c = 2 }; cooperative_matrix_mul_add_inst(tinytc_value_t a0, tinytc_value_t b0, tinytc_value_t c0, tinytc_data_type_t to_ty, location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } inline auto b() const -> tinytc_value const & { return op(op_b); } + inline auto c() -> tinytc_value & { return op(op_c); } inline auto c() const -> tinytc_value const & { return op(op_c); } }; @@ -505,6 +500,7 @@ class cooperative_matrix_scale_inst : public standard_inst<2, 1, 0> { } enum op_number { op_a = 0, op_b = 1 }; cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, location const &lc = {}); + inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } inline auto b() -> tinytc_value & { return op(op_b); } @@ -520,11 +516,16 @@ class cooperative_matrix_store_inst : public standard_inst<4, 0, 0> { cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, tinytc_value_t val0, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, location const &lc = {}); + inline auto checked() const -> checked_flag { return cflag_; } inline auto flag() const -> store_flag { return sflag_; } + inline auto val() -> tinytc_value & { return op(op_val); } inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() -> tinytc_value & { return op(op_operand); } inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() -> tinytc_value & { return op(op_pos0); } inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() -> tinytc_value & { return op(op_pos1); } inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } private: @@ -544,6 +545,7 @@ class expand_inst : public standard_inst { return static_expand_shape_; } + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } inline auto expand_shape() { return operands() | std::views::drop(1); } inline auto expand_shape() const { return operands() | std::views::drop(1); } @@ -559,6 +561,7 @@ class fuse_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, location const &lc = {}); + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } inline std::int64_t from() const { return from_; } inline std::int64_t to() const { return to_; } @@ -572,7 +575,9 @@ class load_inst : public standard_inst { inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } load_inst(tinytc_value_t op, array_view index_list, location const &lc = {}); + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } + inline auto index_list() { return operands() | std::views::drop(1); } inline auto index_list() const { return operands() | std::views::drop(1); } }; @@ -599,7 +604,13 @@ class group_size_inst : public standard_inst<0, 1> { class lifetime_stop_inst : public standard_inst<1, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } - inline lifetime_stop_inst(tinytc_value_t obj) : standard_inst{IK::lifetime_stop} { op(0, obj); } + inline lifetime_stop_inst(tinytc_value_t obj, location const &lc = {}) + : standard_inst{IK::lifetime_stop} { + op(0, obj); + loc(lc); + } + + inline auto object() -> tinytc_value & { return op(0); } inline auto object() const -> tinytc_value const & { return op(0); } }; @@ -638,23 +649,48 @@ class ger_inst : public blas_a3_inst { class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } - inline for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - array_view init_values, tinytc_data_type_t loop_var_type, - location const &loc = {}) - : loop_inst{IK::for_loop, - std::move(from), - std::move(to), - std::move(step), - std::move(init_values), - loop_var_type, - loc} {} + enum op_number { op_from = 0, op_to = 1, op_step = 2 }; + for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + array_view init_values, tinytc_data_type_t loop_var_type, + location const &loc = {}); + + inline auto from() -> tinytc_value & { return op(op_from); } + inline auto from() const -> tinytc_value const & { return op(op_from); } + inline auto to() -> tinytc_value & { return op(op_to); } + inline auto to() const -> tinytc_value const & { return op(op_to); } + inline auto has_step() const -> bool { return op_init() == 3; } + inline auto step() -> tinytc_value & { return op(op_step); } + inline auto step() const -> tinytc_value const & { return op(op_step); } + inline auto loop_var() -> tinytc_value & { return body().param(0); } + inline auto loop_var() const -> tinytc_value const & { return body().param(0); } + inline auto iter_arg(std::int64_t no) -> tinytc_value & { return body().param(no + 1); } + inline auto iter_arg(std::int64_t no) const -> tinytc_value const & { + return body().param(no + 1); + } + inline auto iter_init(std::int64_t no) -> tinytc_value & { return op(op_init() + no); } + inline auto iter_init(std::int64_t no) const -> tinytc_value const & { + return op(op_init() + no); + } + inline auto iter_init() { return operands() | std::views::drop(op_init()); } + inline auto iter_init() const { return operands() | std::views::drop(op_init()); } + + private: + inline auto op_init() const -> std::int64_t { return num_operands() - num_results(); } }; class foreach_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } - foreach_inst(tinytc_value_t from, tinytc_value_t to, tinytc_data_type_t loop_var_type, - location const &loc = {}); + foreach_inst(array_view from, array_view to, + tinytc_data_type_t loop_var_type, location const &lc = {}); + + inline auto dim() const -> std::int64_t { return num_operands() / 2; } + inline auto loop_vars() { return body().params(); } + inline auto loop_vars() const { return body().params(); } + inline auto from() { return operands() | std::views::take(dim()); } + inline auto from() const { return operands() | std::views::take(dim()); } + inline auto to() { return operands() | std::views::drop(dim()); } + inline auto to() const { return operands() | std::views::drop(dim()); } }; class hadamard_inst : public blas_a3_inst { @@ -670,6 +706,8 @@ class if_inst : public standard_inst<1, dynamic, 2> { enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; if_inst(tinytc_value_t condition, array_view return_types = {}, location const &lc = {}); + + inline auto condition() -> tinytc_value & { return op(0); } inline auto condition() const -> tinytc_value const & { return op(0); } inline auto then() -> tinytc_region & { return child_region(child_region_then); } inline auto then() const -> tinytc_region const & { return child_region(child_region_then); } @@ -704,6 +742,7 @@ class size_inst : public standard_inst<1, 1> { inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } size_inst(tinytc_value_t op, std::int64_t mode, location const &lc = {}); + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } inline std::int64_t mode() const { return mode_; } @@ -751,10 +790,15 @@ class subview_inst : public standard_inst { inline auto static_offsets() const -> array_view { return static_offsets_; } inline auto static_sizes() const -> array_view { return static_sizes_; } + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } + inline auto offsets() { + return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); + } inline auto offsets() const { return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); } + inline auto sizes() { return operands() | std::views::drop(1 + num_dyn_offsets_); } inline auto sizes() const { return operands() | std::views::drop(1 + num_dyn_offsets_); } private: @@ -770,8 +814,11 @@ class store_inst : public standard_inst { array_view index_list, location const &lc = {}); inline auto flag() const -> store_flag { return flag_; } + inline auto val() -> tinytc_value & { return op(op_val); } inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() -> tinytc_value & { return op(op_operand); } inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto index_list() { return operands() | std::views::drop(2); } inline auto index_list() const { return operands() | std::views::drop(2); } private: @@ -797,6 +844,7 @@ class work_group_inst : public standard_inst<1, 1> { location const &lc = {}); inline auto operation() const -> work_group_operation { return operation_; } + inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } private: diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index e4186854..39fc1f95 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -671,17 +671,18 @@ init_value: ; foreach_inst: - FOREACH LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] for_loop_var_type { - check_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_type($to, $for_loop_var_type, @to, @for_loop_var_type); + FOREACH LPAREN identifier_list[loop_var] RPAREN EQUALS + LPAREN value_list[from] RPAREN COMMA LPAREN value_list[to] RPAREN for_loop_var_type { try { location loc = @FOREACH; loc.end = @for_loop_var_type.end; auto inode = std::make_unique($from, $to, $for_loop_var_type, loc); ctx.push_scope(); - auto &loop_var = inode->loop_var(); - ctx.val($loop_var, loop_var, @loop_var); + auto loop_vars = inode->loop_vars().begin(); + for (std::int64_t i = 0; i < inode->dim(); ++i) { + ctx.val($loop_var[i], loop_vars[i], @loop_var); + } ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp new file mode 100644 index 00000000..3ec3dcbf --- /dev/null +++ b/src/pass/clone.cpp @@ -0,0 +1,199 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/clone.hpp" +#include "support/visit.hpp" + +namespace tinytc { + +void inst_cloner::reset_subs() { subs_map_.clear(); } +void inst_cloner::set_subs(tinytc_value_t in_val, tinytc_value_t out_val) { + subs_map_[in_val] = out_val; +} +auto inst_cloner::subs(tinytc_value_t val) -> tinytc_value_t { + if (auto it = subs_map_.find(val); it != subs_map_.end()) { + return it->second; + } + return val; +} + +auto inst_cloner::operator()(alloca_inst &in) -> std::unique_ptr { + return std::make_unique(in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(axpby_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.beta()), + subs(&in.B()), in.atomic(), in.loc()); +} +auto inst_cloner::operator()(arith_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.a()), subs(&in.b()), in.loc()); +} +auto inst_cloner::operator()(arith_unary_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.a()), in.loc()); +} +auto inst_cloner::operator()(barrier_inst &in) -> std::unique_ptr { + return std::make_unique(in.fence_flags(), in.loc()); +} +auto inst_cloner::operator()(cast_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(compare_inst &in) -> std::unique_ptr { + return std::make_unique(in.cond(), subs(&in.a()), subs(&in.b()), in.loc()); +} +auto inst_cloner::operator()(constant_inst &in) -> std::unique_ptr { + return std::make_unique(in.value(), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_load_inst &in) -> std::unique_ptr { + return std::make_unique(in.t(), in.checked(), subs(&in.operand()), + subs(&in.pos0()), subs(&in.pos1()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_mul_add_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.a()), subs(&in.b()), subs(&in.c()), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_scale_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), subs(&in.b()), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_store_inst &in) -> std::unique_ptr { + return std::make_unique(in.checked(), in.flag(), subs(&in.val()), + subs(&in.operand()), subs(&in.pos0()), + subs(&in.pos1()), in.loc()); +} +auto inst_cloner::operator()(expand_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.expanded_mode(), + in.static_expand_shape(), + subs_value_range(in.expand_shape()), in.loc()); +} +auto inst_cloner::operator()(fuse_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.from(), in.to(), in.loc()); +} + +auto inst_cloner::operator()(lifetime_stop_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.object()), in.loc()); +} +auto inst_cloner::operator()(load_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), subs_value_range(in.index_list()), + in.loc()); +} +auto inst_cloner::operator()(group_id_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(group_size_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(gemm_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), in.tB(), subs(&in.alpha()), subs(&in.A()), + subs(&in.B()), subs(&in.beta()), subs(&in.C()), in.atomic(), + in.loc()); +} + +auto inst_cloner::operator()(gemv_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(ger_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} +auto inst_cloner::operator()(for_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.from()), subs(&in.to()), in.has_step() ? subs(&in.step()) : nullptr, + subs_value_range(in.iter_init()), in.body().param(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(foreach_inst &in) -> std::unique_ptr { + return std::make_unique(subs_value_range(in.from()), subs_value_range(in.to()), + in.body().param(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(hadamard_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(if_inst &in) -> std::unique_ptr { + auto return_types = std::vector(in.num_results()); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + return_types[i] = in.result(i).ty(); + } + return std::make_unique(subs(&in.condition()), return_types, in.loc()); +} + +auto inst_cloner::operator()(num_subgroups_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(parallel_inst &in) -> std::unique_ptr { + return std::make_unique(in.loc()); +} + +auto inst_cloner::operator()(size_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.mode(), in.loc()); +} + +auto inst_cloner::operator()(subgroup_id_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(subgroup_local_id_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(subgroup_size_inst &in) -> std::unique_ptr { + return std::make_unique(in.context(), in.loc()); +} + +auto inst_cloner::operator()(subview_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.static_offsets(), + in.static_sizes(), subs_value_range(in.offsets()), + subs_value_range(in.sizes()), in.loc()); +} + +auto inst_cloner::operator()(store_inst &in) -> std::unique_ptr { + return std::make_unique(in.flag(), subs(&in.val()), subs(&in.operand()), + subs_value_range(in.index_list()), in.loc()); +} + +auto inst_cloner::operator()(sum_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.beta()), + subs(&in.B()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(work_group_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.operand()), in.loc()); +} + +auto inst_cloner::operator()(yield_inst &in) -> std::unique_ptr { + return std::make_unique(subs_value_range(std::views::all(in.operands())), in.loc()); +} + +auto inst_cloner::clone_instruction(inst_node &in) -> std::unique_ptr { + auto cloned = visit(*this, in); + for (auto res_orig = in.result_begin(), res_cloned = cloned->result_begin(); + res_orig != in.result_end() && res_cloned != cloned->result_end(); + ++res_orig, ++res_cloned) { + set_subs(&(*res_orig), &(*res_cloned)); + } + for (auto reg_orig = in.child_regions_begin(), reg_cloned = cloned->child_regions_begin(); + reg_orig != in.child_regions_end() && reg_cloned != cloned->child_regions_end(); + ++reg_orig, ++reg_cloned) { + for (auto p_orig = reg_orig->param_begin(), p_cloned = reg_cloned->param_begin(); + p_orig != reg_orig->param_end() && p_cloned != reg_cloned->param_end(); + ++p_orig, ++p_cloned) { + set_subs(&(*p_orig), &(*p_cloned)); + } + clone_region(*reg_orig, *reg_cloned); + } + return cloned; +} + +void inst_cloner::clone_region(region_node &source, region_node &target) { + for (auto &in_orig : source.insts()) { + target.insts().push_back(clone_instruction(in_orig).release()); + } +} + +} // namespace tinytc diff --git a/src/pass/clone.hpp b/src/pass/clone.hpp new file mode 100644 index 00000000..d7855216 --- /dev/null +++ b/src/pass/clone.hpp @@ -0,0 +1,77 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CLONE_20241118_HPP +#define CLONE_20241118_HPP + +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" + +#include +#include + +namespace tinytc { + +class inst_cloner { + public: + auto operator()(alloca_inst &in) -> std::unique_ptr; + auto operator()(axpby_inst &in) -> std::unique_ptr; + auto operator()(arith_inst &in) -> std::unique_ptr; + auto operator()(arith_unary_inst &in) -> std::unique_ptr; + auto operator()(barrier_inst &in) -> std::unique_ptr; + auto operator()(cast_inst &in) -> std::unique_ptr; + auto operator()(compare_inst &in) -> std::unique_ptr; + auto operator()(constant_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_load_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_mul_add_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_scale_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_store_inst &in) -> std::unique_ptr; + auto operator()(expand_inst &in) -> std::unique_ptr; + auto operator()(fuse_inst &in) -> std::unique_ptr; + auto operator()(load_inst &in) -> std::unique_ptr; + auto operator()(group_id_inst &in) -> std::unique_ptr; + auto operator()(group_size_inst &in) -> std::unique_ptr; + auto operator()(lifetime_stop_inst &in) -> std::unique_ptr; + auto operator()(gemm_inst &in) -> std::unique_ptr; + auto operator()(gemv_inst &in) -> std::unique_ptr; + auto operator()(ger_inst &in) -> std::unique_ptr; + auto operator()(for_inst &in) -> std::unique_ptr; + auto operator()(foreach_inst &in) -> std::unique_ptr; + auto operator()(hadamard_inst &in) -> std::unique_ptr; + auto operator()(if_inst &in) -> std::unique_ptr; + auto operator()(num_subgroups_inst &in) -> std::unique_ptr; + auto operator()(parallel_inst &in) -> std::unique_ptr; + auto operator()(size_inst &in) -> std::unique_ptr; + auto operator()(subgroup_id_inst &in) -> std::unique_ptr; + auto operator()(subgroup_local_id_inst &in) -> std::unique_ptr; + auto operator()(subgroup_size_inst &in) -> std::unique_ptr; + auto operator()(subview_inst &in) -> std::unique_ptr; + auto operator()(store_inst &in) -> std::unique_ptr; + auto operator()(sum_inst &in) -> std::unique_ptr; + auto operator()(work_group_inst &in) -> std::unique_ptr; + auto operator()(yield_inst &in) -> std::unique_ptr; + + void reset_subs(); + void set_subs(tinytc_value_t in_val, tinytc_value_t out_val); + auto subs(tinytc_value_t val) -> tinytc_value_t; + + auto clone_instruction(inst_node &in) -> std::unique_ptr; + void clone_region(region_node &source, region_node &target); + + private: + template auto subs_value_range(T &&range) { + auto vec = std::vector(); + vec.reserve(range.size()); + for (auto &r : range) { + vec.emplace_back(subs(&r)); + } + return vec; + } + + std::unordered_map subs_map_; +}; + +} // namespace tinytc + +#endif // CLONE_20241118_HPP diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 9e354ab8..e7d76071 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -1240,7 +1240,8 @@ std::vector convert_to_opencl_pass::operator()(for_inst const &in) { } std::vector convert_to_opencl_pass::operator()(foreach_inst const &p) { - auto lv = declare(p.loop_var()); + throw compilation_error(p.loc(), status::not_implemented); + /*auto lv = declare(p.loop_var()); auto lv_ty = visit(*this, *p.loop_var().ty()); auto from = val(p.from()); auto to = val(p.to()); @@ -1254,7 +1255,7 @@ std::vector convert_to_opencl_pass::operator()(foreach_inst const &p bb.add(clir::declaration_assignment(lv_ty, lv, std::move(block) + m + from)); bb.add(run_on_region(p.body())); }); - return {bb.get_product()}; + return {bb.get_product()};*/ } std::vector convert_to_opencl_pass::operator()(hadamard_inst const &g) { diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp index 739ced1e..ae0c43c8 100644 --- a/src/pass/dead_code_elimination.cpp +++ b/src/pass/dead_code_elimination.cpp @@ -20,7 +20,7 @@ class dead_code_analysis { public: auto operator()(inst_node &in) -> bool; auto operator()(if_inst &in) -> bool; - auto operator()(loop_inst &in) -> bool; + auto operator()(for_inst &in) -> bool; }; auto dead_code_analysis::operator()(inst_node &in) -> bool { @@ -51,7 +51,7 @@ auto dead_code_analysis::operator()(if_inst &in) -> bool { return false; } -auto dead_code_analysis::operator()(loop_inst &in) -> bool { +auto dead_code_analysis::operator()(for_inst &in) -> bool { constant_inst *from_const = dyn_cast(in.from().defining_inst()); constant_inst *to_const = dyn_cast(in.to().defining_inst()); if (in.num_results() == 0 && from_const && to_const) { diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 911a98d3..0e9e82bb 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -385,17 +385,18 @@ void dump_ir_pass::operator()(for_inst const &in) { dump_region(in.body()); } -void dump_ir_pass::operator()(foreach_inst const &p) { - *os_ << "foreach "; - dump_val(p.loop_var()); - *os_ << "="; - dump_val(p.from()); +void dump_ir_pass::operator()(foreach_inst const &in) { + *os_ << "foreach ("; + do_with_infix(in.loop_vars().begin(), in.loop_vars().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << ")="; + do_with_infix(in.from().begin(), in.from().end(), [this](auto const &i) { dump_val(i); }); *os_ << ","; - dump_val(p.to()); + do_with_infix(in.to().begin(), in.to().end(), [this](auto const &i) { dump_val(i); }); *os_ << " : "; - visit(*this, *p.loop_var().ty()); + visit(*this, *in.loop_vars().begin()->ty()); *os_ << " "; - dump_region(p.body()); + dump_region(in.body()); } void dump_ir_pass::operator()(hadamard_inst const &g) { diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp new file mode 100644 index 00000000..f61d5d57 --- /dev/null +++ b/src/pass/lower_foreach.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_foreach.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "pass/clone.hpp" +#include "support/visit.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" + +namespace tinytc { + +template +void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, int num_tiles, + F &&make_body, location const &loc) { + auto ctx = compiler_context{sg_id->context(), true}; + auto index_ty = get_scalar(ctx, scalar_type::index); + auto sg_lid_i32 = bb.add(make_subgroup_local_id(ctx)); + auto sg_lid = bb.add(make_cast(sg_lid_i32, index_ty)); + auto size = bb.add(make_arith(arithmetic::sub, to, from, loc)); + auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid)); + tile_loop_by_sgs_new( + bb, size, sgs, num_tiles, sg_id, + [&](region_builder &bb, value block, bool is_remainder, value trip_count) { + auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset)); + if (is_remainder) { + auto cond = bb.add(make_cmp(cmp_condition::lt, sg_lid, trip_count)); + bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }); + } else { + make_body(bb, loop_var0); + } + }); +} + +class foreach_generator { + public: + foreach_generator(local_tiling tiling, core_config core_cfg) + : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} + auto operator()(inst_node &) -> inst { return inst{}; } + auto operator()(foreach_inst &in) -> inst; + + private: + local_tiling tiling_ = {}; + core_config core_cfg_ = {}; +}; + +auto foreach_generator::operator()(foreach_inst &in) -> inst { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto ctx = compiler_context{in.context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto sg_id = bb.add(make_subgroup_id(ctx, in.loc())); + + auto cloner = inst_cloner{}; + auto loop_vars = in.loop_vars().begin(); + auto from = in.from().begin(); + auto to = in.to().begin(); + + if (in.dim() > 1) { + auto const make_inner_loop_nest = [&](region_builder &bb, value from1, value to1) { + tinytc_region_t current_region = bb.get_region().get(); + for (std::int64_t i = in.dim() - 1; i > 1; --i) { + auto for_i = std::make_unique( + &from[i], &to[i], nullptr, array_view{}, index_ty, in.loc()); + cloner.set_subs(&loop_vars[i], &for_i->loop_var()); + tinytc_region_t next_region = &for_i->body(); + current_region->insts().push_back(for_i.release()); + current_region = next_region; + } + region_builder{current_region}.for_loop( + from1, to1, index_ty, + [&](region_builder &bb, value loop_var1) { + cloner.set_subs(&loop_vars[1], loop_var1.get()); + cloner.clone_region(in.body(), *bb.get_region()); + }, + in.loc()); + }; + + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); + auto sg_id1 = bb.add(make_arith(arithmetic::div, sg_id, c_m_tiles, in.loc())); + auto sg_id0 = bb.add(make_arith(arithmetic::rem, sg_id, c_m_tiles, in.loc())); + + auto size1 = bb.add(make_arith(arithmetic::sub, &to[1], &from[1], in.loc())); + tile_loop_uniformly_new( + bb, size1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_id1, + [&](region_builder &bb, value block, value trip_count1) { + auto from1 = bb.add(make_arith(arithmetic::add, &from[1], block)); + auto to1 = bb.add(make_arith(arithmetic::add, from1, trip_count1)); + make_loop0( + bb, &from[0], &to[0], sg_id0, core_cfg_.subgroup_size, tiling_.m_tiles(), + [&](region_builder &bb, value loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0.get()); + make_inner_loop_nest(bb, from1, to1); + }, + in.loc()); + }); + } else if (in.dim() == 1) { + make_loop0( + bb, &from[0], &to[0], sg_id, core_cfg_.subgroup_size, + tiling_.m_tiles() * tiling_.n_tiles(), + [&](region_builder &bb, value loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0.get()); + cloner.clone_region(in.body(), *bb.get_region()); + }, + in.loc()); + } + + return parallel; +} + +lower_foreach_pass::lower_foreach_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_foreach_pass::run_on_function(function_node &fn) { + auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; + try { + core_cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + auto const work_group_size = fn.work_group_size(); + local_tiling tiling = {}; + tiling[0] = work_group_size[0] / subgroup_size; + tiling[1] = work_group_size[1]; + + walk(fn, [&](region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + auto lowered_inst = visit(foreach_generator{tiling, core_cfg}, *it); + if (lowered_inst) { + it = reg.insts().erase(it); + it = reg.insts().insert(it, lowered_inst.release()); + } + } + }); +} + +} // namespace tinytc diff --git a/src/pass/lower_foreach.hpp b/src/pass/lower_foreach.hpp new file mode 100644 index 00000000..0486dd03 --- /dev/null +++ b/src/pass/lower_foreach.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_FOREACH_20241118_HPP +#define LOWER_FOREACH_20241118_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_foreach_pass { + public: + lower_foreach_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_FOREACH_20241118_HPP diff --git a/src/passes.def b/src/passes.def index 2e3da426..b0849aff 100644 --- a/src/passes.def +++ b/src/passes.def @@ -10,5 +10,6 @@ FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) +FUNCTION_PASS_WITH_INFO("lower-foreach", [](tinytc_core_info const* info) { return lower_foreach_pass{info}; }) FUNCTION_PASS_WITH_INFO("lower-linalg", [](tinytc_core_info const* info) { return lower_linalg_pass{info}; }) FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass{info}; }) From b167983c1387a1ec301f8956e5e6ec1f17f971ec Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 19 Nov 2024 13:29:41 +0100 Subject: [PATCH 114/297] Fix lit tests Signed-off-by: Carsten Uphoff --- test/opt/check-ir/nesting0.ir | 2 +- test/opt/check-ir/nesting1.ir | 6 +++--- test/opt/check-ir/nesting3.ir | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/opt/check-ir/nesting0.ir b/test/opt/check-ir/nesting0.ir index aa30d186..c1bf6f07 100644 --- a/test/opt/check-ir/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -5,7 +5,7 @@ func @illegal_nesting(%c: f32, %A: memref, %B: memref, %C: memref) { %lb = constant 1 -> index %ub = constant 16 -> index - foreach %i=%lb,%ub { + foreach (%i)=(%lb),(%ub) { gemm.n.n %c, %A, %B, %c, %C : f32, memref, memref, f32, memref } ; CHECK: 9.9-97: Collective instruction must not be called from SPMD region diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index a3b13333..e49e9fbb 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -5,9 +5,9 @@ func @illegal_nesting() { %lb = constant 1 -> index %ub = constant 16 -> index - foreach %i=%lb,%ub { - foreach %j=%lb,%ub { + foreach (%i)=(%lb),(%ub) { + foreach (%j)=(%lb),(%ub) { } -; CHECK: 9.9-26: Collective instruction must not be called from SPMD region +; CHECK: 9.9-32: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index 28005dcc..ac787966 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -6,8 +6,8 @@ func @illegal_nesting() { %lb = constant 1 -> index %ub = constant 16 -> index parallel { - foreach %j=%lb,%ub { + foreach (%j)=(%lb),(%ub) { } -; CHECK: 9.9-26: Collective instruction must not be called from SPMD region +; CHECK: 9.9-32: Collective instruction must not be called from SPMD region } } From ddf18bcee1b51c6402b9b7b7788528bc6deee8f8 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Wed, 20 Nov 2024 16:40:01 +0100 Subject: [PATCH 115/297] Lower linalg to foreach (except for GEMM) Signed-off-by: Carsten Uphoff --- src/codegen_tools.cpp | 20 --- src/codegen_tools.hpp | 4 - src/pass/dump_ir.cpp | 6 +- src/pass/lower_linalg.cpp | 315 +++++++++++++++++--------------------- src/support/util.hpp | 7 + 5 files changed, 152 insertions(+), 200 deletions(-) diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index fe21f4bc..962ac435 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -520,26 +520,6 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in }); } -void tile_loop_by_sgs_standard(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, - value sg_id, sgs_loop_body_builder_standard const &body) { - auto ctx = compiler_context{sg_id->context(), true}; - auto index_ty = get_scalar(ctx, scalar_type::index); - auto m = bb.add(make_subgroup_local_id(ctx)); - auto m_index = bb.add(make_cast(m, index_ty)); - tile_loop_by_sgs_new( - bb, loop_trip_count, sgs, num_tiles, sg_id, - [&m_index, &body](region_builder &bb, value block, bool is_remainder, value trip_count) { - auto mm = instant_constant_fold_add(bb, make_arith(arithmetic::add, block, m_index)); - if (is_remainder) { - auto cond = - instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, m_index, trip_count)); - bb.if_condition(cond, [&](region_builder &bb) { body(bb, mm); }); - } else { - body(bb, mm); - } - }); -} - void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 6807c0d8..6867d071 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -135,15 +135,11 @@ void write_matrix_block(clir::block_builder &bb, block_accessor const &block, // tools for tinytc lowering using sgs_loop_body_builder_new = std::function; -using sgs_loop_body_builder_standard = std::function; using uniform_loop_body_builder_new = std::function; void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, value sg_id, sgs_loop_body_builder_new const &body); -void tile_loop_by_sgs_standard(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, - value sg_id, sgs_loop_body_builder_standard const &body); - void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 0e9e82bb..7f11e04d 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -389,11 +389,11 @@ void dump_ir_pass::operator()(foreach_inst const &in) { *os_ << "foreach ("; do_with_infix(in.loop_vars().begin(), in.loop_vars().end(), [this](auto const &i) { dump_val(i); }); - *os_ << ")="; + *os_ << ")=("; do_with_infix(in.from().begin(), in.from().end(), [this](auto const &i) { dump_val(i); }); - *os_ << ","; + *os_ << "),("; do_with_infix(in.to().begin(), in.to().end(), [this](auto const &i) { dump_val(i); }); - *os_ << " : "; + *os_ << ") : "; visit(*this, *in.loop_vars().begin()->ty()); *os_ << " "; dump_region(in.body()); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index c6c4acc0..93ed63b8 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -134,19 +134,41 @@ class linalg_generator { public: linalg_generator(local_tiling tiling, core_config core_cfg) : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} - auto operator()(inst_node &) -> inst { return inst{}; } - auto operator()(axpby_inst &in) -> inst; - auto operator()(ger_inst &in) -> inst; - auto operator()(gemm_inst &in) -> inst; - auto operator()(gemv_inst &in) -> inst; - auto operator()(hadamard_inst &in) -> inst; - auto operator()(sum_inst &in) -> inst; + inline void operator()(inst_node &in) { + throw compilation_error(in.loc(), status::not_implemented); + } + void operator()(axpby_inst &in); + void operator()(ger_inst &in); + void operator()(gemm_inst &in); + void operator()(gemv_inst &in); + void operator()(hadamard_inst &in); + void operator()(sum_inst &in); + + inline auto insertion_point() const -> region_node::iterator { return ip_; } + inline auto insertion_point(region_node::iterator ip) { ip_ = ip; } + + inline auto add(inst in) -> value { + auto result = value{}; + in.get_values(result); + ip_ = ip_->parent()->insts().insert(++ip_, in.release()); + return result; + } private: auto get_memref_type(value_node const &v) const -> const memref_data_type *; + template + void add_foreach(array_view from, array_view to, + data_type loop_var_ty, F &&f, location const &loc = {}) { + auto fi = std::make_unique(std::move(from), std::move(to), loop_var_ty, loc); + auto bb = region_builder{&fi->body()}; + f(bb, fi->loop_vars()); + add(inst{fi.release()}); + } + local_tiling tiling_ = {}; core_config core_cfg_ = {}; + region_node::iterator ip_; }; auto linalg_generator::get_memref_type(value_node const &v) const -> const memref_data_type * { @@ -157,111 +179,77 @@ auto linalg_generator::get_memref_type(value_node const &v) const -> const memre return t; } -auto linalg_generator::operator()(axpby_inst &in) -> inst { - auto parallel = make_parallel(in.loc()); - tinytc_region_t body = ¶llel->child_region(0); - auto bb = region_builder{body}; - +void linalg_generator::operator()(axpby_inst &in) { auto ctx = compiler_context{in.alpha().context(), true}; auto index_ty = get_scalar(ctx, scalar_type::index); - auto i32_ty = get_scalar(ctx, scalar_type::i32); auto bt = get_memref_type(in.B()); - - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - - auto const inner_loop = [&](region_builder &bb, value Ab, value Bb, value trip_count, - int num_tiles, value sgid) { - tile_loop_by_sgs_standard(bb, trip_count, core_cfg_.subgroup_size, num_tiles, sgid, - [&](region_builder &bb, value mm) { - auto a = bb.add(make_load(Ab, {mm}, in.loc())); - blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), Bb, - {mm}, in.loc()); - }); - }; - if (bt->dim() == 0) { - auto m = bb.add(make_subgroup_local_id(ctx, in.loc())); + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto sg_id = bb.add(make_subgroup_id(ctx, in.loc())); + auto sg_lid = bb.add(make_subgroup_local_id(ctx, in.loc())); + auto i32_ty = get_scalar(ctx, scalar_type::i32); auto c0 = bb.add(make_constant(0, i32_ty)); - auto cond0 = bb.add(make_cmp(cmp_condition::eq, sgid, c0)); - auto cond1 = bb.add(make_cmp(cmp_condition::eq, m, c0)); + auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0)); + auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0)); auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1)); bb.if_condition(cond, [&](region_builder &bb) { auto a = bb.add(make_load(&in.A(), {}, in.loc())); blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); }); + + add(std::move(parallel)); } else if (bt->dim() == 1) { - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); - inner_loop(bb, &in.A(), &in.B(), c_shape0, tiling_.m_tiles() * tiling_.n_tiles(), sgid); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); + add_foreach( + {c0.get()}, {c_shape0.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {&loop_vars[0]}, + in.loc()); + }, + in.loc()); } else if (bt->dim() == 2) { - auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); - auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); - auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); - - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); - auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.B(), 1, in.loc())); - tile_loop_uniformly_new( - bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, - [&](region_builder &bb, value block, value trip_count) { - auto zero = bb.add(make_constant(0, index_ty)); - bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { - auto nn = bb.add(make_arith(arithmetic::add, block, n, in.loc())); - auto static_offset_list = std::array{0, dynamic}; - auto static_size_list = std::array{dynamic, 0}; - auto Bb = bb.add(make_subview(&in.B(), static_offset_list, static_size_list, - {nn}, {c_shape0}, in.loc())); - if (in.tA() == transpose::T) { - std::swap(static_offset_list[0], static_offset_list[1]); - std::swap(static_size_list[0], static_size_list[1]); - } - auto Ab = bb.add(make_subview(&in.A(), static_offset_list, static_size_list, - {nn}, {c_shape0}, in.loc())); - inner_loop(bb, Ab, Bb, c_shape0, tiling_.m_tiles(), sg_m); - }); - }); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); + auto c_shape1 = add(make_size(&in.B(), 1, in.loc())); + add_foreach( + {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { + auto a_idx = std::array{&loop_vars[0], &loop_vars[1]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto a = bb.add(make_load(&in.A(), a_idx, in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), + {&loop_vars[0], &loop_vars[1]}, in.loc()); + }, + in.loc()); } - - return parallel; } -auto linalg_generator::operator()(ger_inst &in) -> inst { - auto parallel = make_parallel(in.loc()); - tinytc_region_t body = ¶llel->child_region(0); - auto bb = region_builder{body}; - - auto ctx = compiler_context{in.alpha().context(), true}; - auto i32_ty = get_scalar(ctx, scalar_type::i32); - auto index_ty = get_scalar(ctx, scalar_type::index); - - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); - auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); - auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); - - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); - auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, in.loc())); - tile_loop_uniformly_new( - bb, c_shape1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_n, - [&](region_builder &bb, value block, value trip_count) { - auto zero = bb.add(make_constant(0, index_ty)); - bb.for_loop(zero, trip_count, index_ty, [&](region_builder &bb, value n) { - auto nn = bb.add(make_arith(arithmetic::add, block, n, in.loc())); - auto b = bb.add(make_load(&in.B(), {nn}, in.loc())); - tile_loop_by_sgs_standard(bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles(), - sg_m, [&](region_builder &bb, value mm) { - auto a = bb.add(make_load(&in.A(), {mm}, in.loc())); - auto ab = mixed_precision_arithmetic( - bb, arithmetic::mul, a, b, in.loc()); - blas_update(bb, in.atomic(), &in.alpha(), ab, - &in.beta(), &in.C(), {mm, nn}, in.loc()); - }); - }); - }); - - return parallel; +void linalg_generator::operator()(ger_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); + auto c_shape1 = add(make_size(&in.C(), 1, in.loc())); + add_foreach( + {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); + auto b = bb.add(make_load(&in.B(), {&loop_vars[1]}, in.loc())); + auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), + {&loop_vars[0], &loop_vars[1]}, in.loc()); + }, + in.loc()); } -auto linalg_generator::operator()(gemm_inst &in) -> inst { +void linalg_generator::operator()(gemm_inst &in) { auto parallel = make_parallel(in.loc()); tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; @@ -329,35 +317,22 @@ auto linalg_generator::operator()(gemm_inst &in) -> inst { }); } - return parallel; + add(std::move(parallel)); } -auto linalg_generator::operator()(gemv_inst &in) -> inst { - auto parallel = make_parallel(in.loc()); - tinytc_region_t body = ¶llel->child_region(0); - auto bb = region_builder{body}; - +void linalg_generator::operator()(gemv_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); auto ct = get_memref_type(in.C()); - - auto ctx = compiler_context{in.alpha().context(), true}; - auto index_ty = get_scalar(ctx, scalar_type::index); - - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); - auto K = instant_constant_fold_add( - bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); - - tile_loop_by_sgs_standard( - bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, - [&](region_builder &bb, value mm) { - auto c_zero = bb.add(make_constant(0, index_ty)); - auto c_step = bb.add(make_constant(1, index_ty)); + add_foreach( + {c0.get()}, {c_shape0.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); + auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); auto c_acc = bb.for_loop( - c_zero, K, c_step, {c_init}, index_ty, - [&](region_builder &bb, array_view p) { - auto a_idx = std::array{mm, p[0]}; + c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view p) { + auto a_idx = std::array{&loop_vars[0], p[0]}; if (in.tA() == transpose::T) { std::swap(a_idx[0], a_idx[1]); } @@ -367,35 +342,29 @@ auto linalg_generator::operator()(gemv_inst &in) -> inst { auto ab_c = mixed_precision_arithmetic(bb, arithmetic::add, p[1], ab, in.loc()); bb.add(make_yield({ab_c}, in.loc())); }); - blas_update(bb, in.atomic(), &in.alpha(), c_acc[0], &in.beta(), &in.C(), {mm}, - in.loc()); - }); - - return parallel; + blas_update(bb, in.atomic(), &in.alpha(), c_acc[0], &in.beta(), &in.C(), + {&loop_vars[0]}, in.loc()); + }, + in.loc()); } -auto linalg_generator::operator()(hadamard_inst &in) -> inst { - auto parallel = make_parallel(in.loc()); - tinytc_region_t body = ¶llel->child_region(0); - auto bb = region_builder{body}; - - auto ctx = compiler_context{in.alpha().context(), true}; - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); - tile_loop_by_sgs_standard( - bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, - [&](region_builder &bb, value mm) { - auto a = bb.add(make_load(&in.A(), {mm}, in.loc())); - auto b = bb.add(make_load(&in.B(), {mm}, in.loc())); +void linalg_generator::operator()(hadamard_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); + add_foreach( + {c0.get()}, {c_shape0.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); + auto b = bb.add(make_load(&in.B(), {&loop_vars[0]}, in.loc())); auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); - blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), {mm}, in.loc()); - }); - - return parallel; + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), {&loop_vars[0]}, + in.loc()); + }, + in.loc()); } -auto linalg_generator::operator()(sum_inst &in) -> inst { +void linalg_generator::operator()(sum_inst &in) { auto parallel = make_parallel(in.loc()); tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; @@ -405,9 +374,6 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { auto index_ty = get_scalar(ctx, scalar_type::index); auto bt = get_memref_type(in.B()); - - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - if (bt->dim() == 0) { auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); @@ -439,31 +405,32 @@ auto linalg_generator::operator()(sum_inst &in) -> inst { }, in.loc()); } else if (bt->dim() == 1) { - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.B(), 0, in.loc())); - auto c_trip_count = instant_constant_fold_add( - bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); - tile_loop_by_sgs_standard( - bb, c_shape0, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), sgid, - [&](region_builder &bb, value mm) { - auto from = bb.add(make_constant(0, index_ty)); - auto zero = bb.add(make_constant_zero(bt->element_data_ty())); - auto acc = - bb.for_loop(from, c_trip_count, {}, {zero}, index_ty, - [&](region_builder &bb, array_view args) { - auto index_list = std::array{mm, args[0]}; - if (in.tA() == transpose::T) { - std::swap(index_list[0], index_list[1]); - } - auto a = bb.add(make_load(&in.A(), index_list, in.loc())); - auto sum = mixed_precision_arithmetic(bb, arithmetic::add, - args[1], a, in.loc()); - bb.add(make_yield({sum}, in.loc())); - }); - blas_update(bb, in.atomic(), &in.alpha(), acc[0], &in.beta(), &in.B(), {mm}, - in.loc()); - }); + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); + add_foreach( + {c0.get()}, {c_shape0.get()}, index_ty, + [&](region_builder &bb, auto loop_vars) { + auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + auto c_init = bb.add(make_constant_zero(bt->element_data_ty())); + auto acc = bb.for_loop( + c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view args) { + auto index_list = std::array{&loop_vars[0], args[0]}; + if (in.tA() == transpose::T) { + std::swap(index_list[0], index_list[1]); + } + auto a = bb.add(make_load(&in.A(), index_list, in.loc())); + auto sum = + mixed_precision_arithmetic(bb, arithmetic::add, args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); + blas_update(bb, in.atomic(), &in.alpha(), acc[0], &in.beta(), &in.B(), + {&loop_vars[0]}, in.loc()); + }, + in.loc()); } - return parallel; + + add(std::move(parallel)); } lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { @@ -485,12 +452,14 @@ void lower_linalg_pass::run_on_function(function_node &fn) { tiling[0] = work_group_size[0] / subgroup_size; tiling[1] = work_group_size[1]; + auto gen = linalg_generator{tiling, core_cfg}; walk(fn, [&](region_node ®) { for (auto it = reg.begin(); it != reg.end(); ++it) { - auto lowered_inst = visit(linalg_generator{tiling, core_cfg}, *it); - if (lowered_inst) { - it = reg.insts().erase(it); - it = reg.insts().insert(it, lowered_inst.release()); + if (isa(*it) || isa(*it)) { + gen.insertion_point(it); + visit(gen, *it); + reg.insts().erase(it); + it = gen.insertion_point(); } } }); diff --git a/src/support/util.hpp b/src/support/util.hpp index ff4bc722..cc27ca0c 100644 --- a/src/support/util.hpp +++ b/src/support/util.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -21,6 +22,12 @@ template class iterator_range_wrapper { ItT begin() { return begin_; } ItT end() { return end_; } + auto &operator[](std::size_t n) + requires std::random_access_iterator + { + return begin()[n]; + } + private: ItT begin_, end_; }; From b2828e5f4b2b1833cb0b30033b84dd2d99f2e7ff Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 10:52:57 +0100 Subject: [PATCH 116/297] Update language spec to encode return types instead of operand types Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 651 +++++++++++++++++++++++--------------- 1 file changed, 394 insertions(+), 257 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index df833c53..d8fa95ea 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -161,9 +161,9 @@ Scalar types .. code:: abnf scalar-type = integer-type / floating-type / complex-type - integer-type = "i" ("8" / "16" / "32" / "64") / "index" - floating-type = "f" ("32" / "64") - complex-type = "c" ("32" / "64") + integer-type = "i8" / "i16" / "i32" / "i64" / "index" + floating-type = "f32" / "f64" + complex-type = "c32" / "c64" Scalar types are either signless integer ("i"), floating point ("f"), or complex floating point ("c"). @@ -178,6 +178,8 @@ A scalar type :math:`\alpha` is called *compatible to* a scalar type :math:`\bet If an arithmetic operation involves mixed types :math:`\alpha` and :math:`\beta` and :math:`\alpha \preceq \beta`, then :math:`\alpha` is casted to :math:`\beta` and the arithmetic operation is done with type :math:`\beta`. +The function :math:`\text{compatible_type}(a, b)` returns the highest ranking type, +e.g. :math:`\text{compatible_type}(\text{i32},\text{f32}) = \text{f32}`. Memref type @@ -222,6 +224,24 @@ The local memory space is shared by all work-items of the work-group but inacces The default address space is "global", memrefs with "local" address space are returned by the alloca instruction. +Definitions +........... + +Let V be a value of memref type. +The :math:`\text{order}(V)` operation returns the memref's order. +The :math:`\text{shape}(V)` returns the tensor shape as tuple. +:math:`\text{rows}(V)` and :math:`\text{columns}(V)` return the size of the first +and second mode, respectively. +The :math:`\text{element_type}(V)` operation gives the underlying scalar type. + +For example, let B be a value of memref type, then + +* :math:`\text{order}(B) = 3` +* :math:`\text{shape}(B) = (8,16,4)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{element_type}(B) = \text{f32}` + Memory layout ............. @@ -299,9 +319,28 @@ The supported matrix shapes may depend on data type, matrix use, and target hard An argument to any instruction that has coopmatrix type **must** be dynamically uniform. +Definitions +........... + +Let V be a value of coopmatrix type. +The :math:`\text{rows}(V)` and :math:`\text{columns}(V)` functions return the size of the first +and second mode, respectively, and :math:`\text{shape}(V)` returns rows and cols as tuple. +The :math:`\text{component_type}(V)` operation gives the underlying scalar type +and :math:`\text{use}(V)` returns the use. + +For example, let B be a value of coopmatrix type, then + +* :math:`\text{shape}(B) = (8,16)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{component_type}(B) = \text{f32}` +* :math:`\text{use}(B) = \text{matrix_acc}` + Instructions ============ +Instructions may return zero, one, or multiple values, and follow the following format: + .. code:: abnf value-instruction-assignment = local-identifier "=" value-instruction @@ -310,6 +349,14 @@ Instructions instruction = value-instruction-assignment / multi-value-instruction-assignment +That is, on the left-hand side we have list of values that are produced by the instruction followed by an equals sign, +or an empty string, if the instruction does not produce values. +On the right-hand side, after the equals sign or empty string, the name of the instruction is written, e.g. "arith", optionally followed by instruction modifiers, e.g. "arith.add". +Then, a list of operands follows that is usually comma-seperated but might also be printed in a custom format +(e.g. for "load", "store", "subview", etc.). +If the instruction produces values, then the types of the returned values must be annotated after a colon. + + Collective instructions ----------------------- @@ -319,23 +366,18 @@ Alloca .. code:: abnf - value-instruction = "alloca" "->" memref-type + value-instruction = "alloca" ":" memref-type Overview ~~~~~~~~ The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. -Returns -~~~~~~~ - -A memref of the memref-type. - Restrictions ~~~~~~~~~~~~ -- The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. -- The address space must be "local". +* The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +* The address space must be "local". Axpby ..... @@ -343,9 +385,8 @@ Axpby .. code:: abnf transpose = ".t" / ".n" - instruction =/ "axpby" transpose [".atomic"] - local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + instruction =/ "axpby" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview ~~~~~~~~ @@ -356,34 +397,36 @@ Axpby implements B := \alpha \text{op}(A) + \beta B -for vectors and matrices. -If the atomic flag is set, B is updated atomically. - -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha`, and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. - -The transpose modifier defines :math:`\text{op}` as following: +for vectors and matrices, where :math:`\text{op}(X)` is defined as .. math:: - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i= t \wedge \text{order}(X) = 2,\\ + \text{op}(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{transpose} = \text{".t"} \wedge \text{order}(X) = 2,\\ X & \text{ else. } - \end{array} - \right. + \end{array} + \right. -(Note that ".t" has no effect on vectors.) +If the atomic flag is set, B is updated atomically. + +Operands +~~~~~~~~ -The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) -or 2 (matrix). +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type B +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{shape}(B) = \text{shape}(\text{op}(A))` +* :math:`\text{order}(B) = 1 \lor \text{order}(B) = 2` * :math:`\text{type}(\alpha) \preceq \text{element_type}(A)` * :math:`\text{type}(\beta) \preceq \text{element_type}(B)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -393,9 +436,8 @@ Foreach .. code:: abnf - instruction =/ "foreach" "(" local-identifier-list ")" "=" - "(" local-identifier-list ")" "," "(" local-identifier-list ")" - [":" integer-type] region + instruction =/ "foreach" "(" local-identifier-list ")" [":" integer-type] "=" + "(" local-identifier-list ")" "," "(" local-identifier-list ")" region Overview ~~~~~~~~ @@ -421,8 +463,8 @@ The loop range is defined as the cartesian product of the half-open intervals (\text{var}_1, \dots, \text{var}_N) \in [\text{from}_1; \text{to}_1) \times \dots \times [\text{from}_N; \text{to}_N) -The integer type of the loop variable and the loop bounds is given after the colon and -the default integer type is ``index``. +The integer type of the loop variable and the loop bounds can be optionally set after the colon. +The default integer type is ``index``. The mapping of trip count to work-item is implementation-defined. @@ -431,9 +473,8 @@ GEMM .. code:: abnf - instruction =/ "gemm" transpose transpose [".atomic"] - "," local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "gemm" transpose transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview ~~~~~~~~ @@ -444,34 +485,41 @@ GEMM implements the well-known GEMM BLAS-3 operation. C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C -If the atomic flag is set, C is updated atomically. - -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, B, and C, respectively. - -The first transpose modifier defines :math:`\text{op}_1` and the second transpose modifier -defines :math:`\text{op}_2` as following: +The functions :math:`\text{op}_1` and :math:`\text{op}_2` are defined as .. math:: \text{op}_i(X) := \left\{ \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i = t,\\ - X & \text{ if } & \text{modifier}_i = n. + X^T & \text{ if } & \text{transpose}_i = \text{".t"},\\ + X & \text{ if } & \text{transpose}_i = \text{".n"}. \end{array} \right. +where transpose\ :sub:`1` and transpose\ :sub:`2` refer to the first and second transpose modifier, respectively. -If :math:`\text{op}_1(A)` has the shape MxK and -:math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. +If the atomic flag is set, C is updated atomically. + +Operands +~~~~~~~~ + +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type B +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{order}(A) = \text{order}(B) = \text{order}(C) = 2` +* :math:`\text{colums}(\text{op}_1(A)) = \text{rows}(\text{op}_2(B))` +* :math:`\text{rows}(C) = \text{rows}(\text{op}_1(A))` +* :math:`\text{columns}(C) = \text{columns}(\text{op}_2(B))` * :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(A), \text{element_type}(B))` * :math:`\text{type}(\beta) \preceq \text{element_type}(C)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -481,9 +529,8 @@ GEMV .. code:: abnf - instruction =/ "gemv" transpose [".atomic"] - "," local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "gemv" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview ~~~~~~~~ @@ -494,22 +541,30 @@ GEMV implements the well-known GEMM BLAS-2 operation. c := \alpha \text{op}_1(A) b + \beta c -If the atomic flag is set, c is updated atomically. +where :math:`\text{op}_1` is defined as in GEMM. -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, b, and c, respectively. +If the atomic flag is set, c is updated atomically. -The transpose modifier for A as in GEMM. +Operands +~~~~~~~~ -:math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type c +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{order}(A) = 2` +* :math:`\text{order}(b) = \text{order}(c) = 1` +* :math:`\text{colums}(\text{op}_1(A)) = \text{rows}(b)` +* :math:`\text{rows}(c) = \text{rows}(\text{op}_1(A))` * :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(A), \text{element_type}(b))` * :math:`\text{type}(\beta) \preceq \text{element_type}(C)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -519,9 +574,8 @@ GER .. code:: abnf - instruction =/ "ger" [".atomic"] - local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "ger" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview ~~~~~~~~ @@ -534,18 +588,26 @@ Computes the general rank-1 update: If the atomic flag is set, C is updated atomically. -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and C, respectively. +Operands +~~~~~~~~ -a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{order}(a) = \text{order}(b) = 1` +* :math:`\text{order}(C) = 2` +* :math:`\text{rows}(C) = \text{rows}(a)` +* :math:`\text{columns}(C) = \text{rows}(b)` * :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(a), \text{element_type}(b))` * :math:`\text{type}(\beta) \preceq \text{element_type}(C)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -556,9 +618,8 @@ Hadamard product .. code:: abnf - instruction =/ "hadamard_product" [".atomic"] - local-identifier "," local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "hadamard_product" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview ~~~~~~~~ @@ -572,18 +633,24 @@ That is, in index notation we have If the atomic flag is set, c is updated atomically. -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and c, respectively. +Operands +~~~~~~~~ -a, b, and c must be vectors and have equal shape. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type c +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{order}(a) = \text{order}(b) = \text{order}(c) = 1` +* :math:`\text{shape}(a) = \text{shape}(b) = \text{shape}(c)` * :math:`\text{type}(\alpha) \preceq \text{compatible_type}(\text{element_type}(a), \text{element_type}(b))` * :math:`\text{type}(\beta) \preceq \text{element_type}(c)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -605,43 +672,47 @@ Sum .. code:: abnf - instruction =/ "sum" transpose [".atomic"] - "," local-identifier "," local-identifier "," local-identifier "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + instruction =/ "sum" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview ~~~~~~~~ Computes the matrix-vector product or the dot product of A with a vector of ones. -That is, for matrices we have +That is, if the result is a vector we have .. math:: - B := \alpha \text{op}(A) \vec{1} + \beta B + b := \alpha \text{op}(A) \vec{1} + \beta b, -and for vectors we have +where :math:`\text{op}(A)` is defined as in the axpby instruction, +and if the result is a scalar we have .. math:: - b := \alpha \left + \beta b + b := \alpha \left + \beta b -If the atomic flag is set, B is updated atomically. +If the atomic flag is set, b is updated atomically. -Arguments -~~~~~~~~~ - -The first argument gives :math:`\alpha` and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. -If A is a matrix then B must be a vector. -The first mode size of :math:`\text{op}(A)` must match the size of B. -If A is a vector, then B must be a scalar memref. +Operands +~~~~~~~~ -The transpose op is defined as in the axpby instruction. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type b +======= =========== ============== Restrictions ~~~~~~~~~~~~ +* :math:`\text{order}(b) = 1 \lor \text{order}(b) = 0` +* :math:`\text{order}(A) = \text{order}(b)+1` +* :math:`\text{rows}(b) = \text{rows}(\text{op}(A)) \text{ if } \text{order}(b) = 1` * :math:`\text{type}(\alpha) \preceq \text{element_type}(A)` * :math:`\text{type}(\beta) \preceq \text{element_type}(B)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. @@ -706,7 +777,7 @@ Overview ~~~~~~~~ Unary arithmetic operation on scalars and cooperative matrices. -For integer and floating point input, the returned value has the same type as the operand. +For integer and floating point input, the operand must have the same type as the returned value. For complex input, the returned value has the component floating point type for ".abs", ".im", and ".re", and the returned value has the same type as the operand for ".neg" and ".conj". @@ -757,13 +828,13 @@ Cast .. code:: abnf - value-instruction =/ "cast" local-identifier ":" scalar-type "->" scalar-type - value-instruction =/ "cast" local-identifier ":" coopmatrix-type "->" coopmatrix-type + value-instruction =/ "cast" local-identifier ":" scalar-type + value-instruction =/ "cast" local-identifier ":" coopmatrix-type Overview ~~~~~~~~ -Cast scalar values or cooperative matrices. +Cast scalar values or cooperative matrices to type indicated after the colon. The shape and the use the coopmatrix types must match. Casts from complex types to non-complex types are forbidden. @@ -790,7 +861,7 @@ Comparison .. code:: abnf value-instruction =/ "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") - local-identifier "," local-identifier ":" scalar-type + local-identifier "," local-identifier ":" "bool" Overview ~~~~~~~~ @@ -817,7 +888,7 @@ Constant .. code:: abnf - value-instruction =/ "constant" constant "->" (boolean-type / scalar-type / coopmatrix-type) + value-instruction =/ "constant" constant ":" (boolean-type / scalar-type / coopmatrix-type) Overview ~~~~~~~~ @@ -835,7 +906,7 @@ Cooperative matrix load value-instruction =/ "cooperative_matrix_load" transpose checked-flag local-identifier "[" local-identifier "," local-identifier "]" - ":" memref-type "->" coopmatrix-type + ":" coopmatrix-type checked-flag = ".rows_checked" / ".cols_checked" / ".both_checked" Overview @@ -869,24 +940,31 @@ Flag Description .both_checked.t .rows_checked.t + .cols_checked.t =============== ======================================================================================================= -Arguments -~~~~~~~~~ +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 memref-type M +2 index x +3 index y +======= =============== =========== -The first operand must have memref type of dimension 2 with the same component type -as the coopmatrix type. -The indices must be of ``index`` type. +Restrictions +~~~~~~~~~~~~ -All arguments **must** be dynamically uniform. +* :math:`\text{order}(M) = 2` +* :math:`\text{component_type}(A) = \text{element_type}(M)` +* All arguments **must** be dynamically uniform. Cooperative matrix mul add .......................... .. code:: abnf - value-instruction =/ "cooperative_matrix_mul_add" - local-identifier "," local-identifier "," local-identifier - ":" coopmatrix-type "," coopmatrix-type "," coopmatrix-type - "->" coopmatrix-type + value-instruction =/ "cooperative_matrix_mul_add" local-identifier "," + local-identifier "," local-identifier ":" coopmatrix-type Overview ~~~~~~~~ @@ -899,15 +977,24 @@ Matrix mul add returns the value of where A, B, and C are matrices given by the three operands. -The operands must have cooperative matrix type, where the first operand has shape :math:`M\times K` -with use "matrix_a", the second operand has shape :math:`K\times N` with use "matrix_b", -and the third operand and the result have shape :math:`M\times N` with use "matrix_acc". +Operands +~~~~~~~~ -The component types of the operands and the result do not need to match. +======= =============== ========== =========== +Op.-No. Type Use Description +======= =============== ========== =========== +1 coopmatrix-type matrix_a A +2 coopmatrix-type matrix_b B +3 coopmatrix-type matrix_acc C +======= =============== ========== =========== Restrictions ~~~~~~~~~~~~ +* :math:`\text{columns}(A) = \text{rows}(B)` +* :math:`\text{rows}(C) = \text{rows}(A) \land \text{columns}(C) = \text{columns}(B)` +* :math:`\text{shape}(D) = \text{shape}(C)` +* :math:`\text{use}(D) = \text{matrix_acc}` * :math:`\text{compatible_type}(\text{component_type}(A), \text{component_type}(B)) \preceq \text{component_type}(C)` * Cast of :math:`\text{component_type}(C)` to :math:`\text{component_type}(D)` must be allowed @@ -916,24 +1003,39 @@ Cooperative matrix scale .. code:: abnf - value-instruction =/ "cooperative_matrix_scale" - local-identifier "," local-identifier - ":" scalar-type "," coopmatrix-type + value-instruction =/ "cooperative_matrix_scale" local-identifier "," local-identifier + ":" coopmatrix-type Overview ~~~~~~~~ -Scale a matrix by a scalar. -The scalar type of the scalar and the component type of the matrix must match. +Scale a coopmatrix by a scalar. +The scalar type of the scalar and the component type of the coopmatrix must match, +and the returned must have the same coopmatrix type as the matrix operand. + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 scalar-type scalar +2 coopmatrix-type matrix +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(scalar) = \text{component_type}(matrix)` +* :math:`\text{type}(result) = \text{type}(matrix)` Cooperative matrix store ........................ .. code:: abnf - instruction =/ "cooperative_matrix_store" checked-flag [store-flag] - local-identifier "," local-identifier "[" local-identifier "," local-identifier "]" - ":" coopmatrix-type "," memref-type + instruction =/ "cooperative_matrix_store" checked-flag [store-flag] local-identifier "," + local-identifier "[" local-identifier "," local-identifier "]" Overview ~~~~~~~~ @@ -963,13 +1065,23 @@ When the atomic_add flag is set, the coopmatrix is added to the memref atomicall When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used for the the real and imaginary separately. -Arguments -~~~~~~~~~ +Operands +~~~~~~~~ -The first operand must have cooperative matrix type with the same component type as the memref type. -The indices must be of ``index`` type. +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 coopmatrix-type A +2 memref-type M +3 index x +4 index y +======= =============== =========== -All arguments **must** be dynamically uniform. +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{component_type}(A) = \text{element_type}(B)` +* All arguments **must** be dynamically uniform. Expand ...... @@ -985,65 +1097,69 @@ Overview The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. -Arguments -~~~~~~~~~ +Operands +~~~~~~~~ The first argument must point to a value of memref type. The first integer constant before "->" gives the mode that shall be expanded. The expand shape coming after "->" gives the new shape of the mode. -Dynamic values in the expand shape must have index type. +Dynamic values in the expand shape must have `index` type. + +Restrictions +~~~~~~~~~~~~ -The output type is a memref type according to the following rules: +The memref type of the result must conform with the following rules: #. **Shape:** The mode size is replaced with the expand shape. The product of the expand shape must equal the size of the expanded mode. .. code:: - expand %0[1 -> 2x8] : memref ; -> memref - expand %0[1 -> 2x2x2x2] : memref ; -> memref + expand %0[1 -> 2x8] : memref ; %0: memref + expand %0[1 -> 2x2x2x2] : memref ; %0: memref #. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. The product of the dynamic expand shape must equal the size of the expanded mode. .. code:: - expand %0[1 -> %1 x 2] : memref ; -> memref - expand %0[1 -> 2 x %1] : memref ; -> memref - expand %0[1 -> %1 x 2] : memref ; -> memref - expand %0[1 -> %1 x 2] : memref ; -> memref - expand %0[1 -> %1 x %2 x 2] : memref ; -> memref - expand %0[1 -> %2 x 2 x %1] : memref ; -> memref - expand %0[1 -> %1 x %2] : memref ; -> memref - expand %0[1 -> %1 x %2] : memref ; -> memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x %2 x 2] : memref ; %0: memref + expand %0[1 -> %2 x 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref *Note:* In the third example above, %1 must be equal to 8. The output mode corresponding to %1 is still dynamic. #. **Stride:** A new stride entry is entered that follows the canonical stride computation. + It is also permissible to put '?' for a stride instead of the constant value. .. code:: - expand %0[0->4 x 8] : memref> ; -> memref> - expand %0[0->%1 x 4] : memref> ; -> memref> - expand %0[0->4 x %1] : memref> ; -> memref> + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->%1 x 4] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> -Restrictions -~~~~~~~~~~~~ +Further restrictions: -The product of the expand shape must be the same as the mode size. -If the product of the expand shape is only known at runtime, then it is undefined behaviour -if the dynamic product does not match the mode size. +* The product of the expand shape must be the same as the mode size. +* If the product of the expand shape is only known at runtime, then it is undefined behaviour + if the dynamic product does not match the mode size. For ... .. code:: abnf - multi-value-instruction = "for" local-identifier "=" + multi-value-instruction = "for" local-identifier [":" integer-type] "=" local-identifier "," local-identifier ["," local-identifier] - ["init" "(" init-value-list ")" "->" "(" return-type-list ")" ] - [":" integer-type] region + ["init" "(" init-value-list ")" "->" "(" return-type-list ")" ] region init-value-list = init-value *("," init-value) init-value = local-identifier "=" local-identifier return-type-list = return-type *("," return-type) @@ -1087,7 +1203,7 @@ Example: %to = constant 6 -> i32 %f0 = constant 0 -> i64 %f1 = constant 1 -> i64 - %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 } @@ -1098,67 +1214,69 @@ Fuse .. code:: abnf - value-instruction =& "fuse" local-identifier "[" integer-constant "," integer-constant "]" ":" memref-type + value-instruction =/ "fuse" local-identifier "[" integer-constant "," integer-constant "]" + ":" memref-type Overview ~~~~~~~~ The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. -Arguments -~~~~~~~~~ +Fused modes are specified as the interval [from, to], where counting starts from 0. +From and to must refer to existing modes, that is, we require :math:`0 \leq \text{from} < \text{to} < \text{order}(\text{tensor})`. +Moreover, the stride vector S and the shape vector s must satisify the following compatibility condition: -The first argument must point to a value of memref type. -The fused modes are specified as the interval [from, to], where from is given -by the first integer and to is given by the second integer. -Counting starts from 0 so we have +:math:`\forall k \in [\text{from},\text{to}): S_{k}s_{k} = S_{k+1}` -.. math:: - - 0 \leq from < to < order(memref) - -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: - -#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. - - .. code:: +If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility +condition is not satisfied. +If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatbility condition +is undefined beheaviour, e.g. - fuse %0[1,3] : memref ; -> memref - fuse %0[1,3] : memref> ; -> memref> +.. code:: -#. **Stride:** Strides remain unchanged. + ; Illegal, modes cannot be fused + fuse %0[0,1] : memref ; %0: memref> + ; Undefined behaviour if dynamic stride != 8 + fuse %0[0,1] : memref> ; %0: memref> - .. code:: +Operands +~~~~~~~~ - fuse %0[1,2] : memref> ; -> memref> - fuse %0[0,1] : memref> ; -> memref> +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 memref-type tensor +2 integer-constant from +3 integer-constant to +======= ================ =========== Restrictions ~~~~~~~~~~~~ -Let i be the first mode and j the last mode. -The stride vector S and the shape vector s must satisify the following compatibility condition: +The memref type of the result must conform with the following rules: -:math:`\forall k \in [i,j): S_{k}s_{k} = S_{k+1}` +#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. -If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility -condition is not satisfied. -If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatbility condition -is undefined beheaviour. + .. code:: -.. code:: + fuse %0[1,3] : memref ; %0: memref + fuse %0[1,3] : memref> ; %0: memref> + +#. **Stride:** Strides remain unchanged or are replaced by '?'. - fuse %0[0,1] : memref> ; Illegal, modes cannot be fused - fuse %0[0,1] : memref> ; Undefined behaviour if dynamic stride != 8 + .. code:: + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[0,1] : memref> ; %0: memref> Group id ........ .. code:: abnf - value-instruction =/ "group_id" + value-instruction =/ "group_id" ":" "index" Overview ~~~~~~~~ @@ -1170,7 +1288,7 @@ Group size .. code:: abnf - value-instruction =/ "group_size" + value-instruction =/ "group_size" ":" "index" Overview ~~~~~~~~ @@ -1191,7 +1309,7 @@ Overview An if statement. Both regions are *mixed regions*. -The condition must have boolean type. +The condition (first operand) must have boolean type. Returns ~~~~~~~ @@ -1219,8 +1337,8 @@ Load .. code:: abnf value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" - ":" memref-or-group-type - memref-or-group-type = memref-type / group-type + ":" scalar-or-memref-type + scalar-or-memref-type = scalar-type / memref-type Overview ~~~~~~~~ @@ -1229,11 +1347,15 @@ Load the element given by the index list from a memref or group. The number of indices must match the order of the memref and a single index must be given for a group. -Arguments +Operands ~~~~~~~~~ -The first operand must have memref or group type. -The indices must be of ``index`` type. +======= ======================== =========== +Op.-No. Type Description +======= ======================== =========== +1 memref-type / group-type tensor +2... index index list +======= ======================== =========== Returns ~~~~~~~ @@ -1241,17 +1363,17 @@ Returns A value of the memref's element type or the group's memref type. Examples: -#. ``load %0[] : memref`` returns a ``f32`` value. -#. ``load %0[5, %1] : memref`` returns a ``f32`` value. -#. ``load %0[%1] : group>`` returns a ``memref`` value. -#. ``load %0[%1] : group, offset: ?>`` returns a ``memref`` value. +#. ``load %0[] : f32 ; %0: memref`` +#. ``load %0[5, %1] : f32 ; %0: memref`` +#. ``load %0[%1] : memref ; %0: group>`` +#. ``load %0[%1] : memref ; %0: group, offset: ?>`` Number of subgroups ................... .. code:: abnf - value-instruction =/ "num_subgroups" + value-instruction =/ "num_subgroups" ":" "i32" Overview ~~~~~~~~ @@ -1263,34 +1385,31 @@ Size .. code:: abnf - value-instruction =/ "size" local-identifier "[" integer-constant "]" ":" memref-type + value-instruction =/ "size" local-identifier "[" integer-constant "]" ":" "index" Overview ~~~~~~~~ The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer constant in square brackets. +"i" must be in bounds, i.e. :math:`0 \leq i < \text{order}(tensor)`. -Arguments +Operands ~~~~~~~~~ -The first argument must point to a value of memref type. -The integer constant i gives the mode for which the size shall be returned. -It is required that - -.. math:: - - 0 \leq i < order(memref) - -The local identifier must have the memref type specified last. -The instruction returns an integer of index type. +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 memref-type tensor +2 integer-constant mode index +======= ================ =========== Subgroup size ............. .. code:: abnf - value-instruction =/ "subgroup_size" + value-instruction =/ "subgroup_size" ":" "i32" Overview ~~~~~~~~ @@ -1312,11 +1431,8 @@ Overview The subview instruction returns a view on a tensor. -Arguments -~~~~~~~~~ - The first argument must point to a value of memref type. -The number of indices in square brackets must match the order of the memref. +The number of indices in square brackets must match the order of the memref type. The indices are either given as single index or as a slice, where slices are given in offset plus size notation ("%offset : %size"). E.g. the slice "%0 : %1" extracts a block of %1 elements beginning from %0, which is equivalent @@ -1336,39 +1452,40 @@ A single index is syntactic sugar for offset plus size 0, e.g. %0 is syntactic s to the memref would be out-of-bounds. However, a one-sized rank, e.g. memref, might be desirable.) A dynamic size of zero is undefined behaviour. - - There is no run-time check whether the indices are within bounds. Offset and size must be of index type. Offset must be non-negative and size must be positive. -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: +Restrictions +~~~~~~~~~~~~ -#. **Invariant-stride:** The stride is not changed. +The memref type of the result must conform with the following rules: + +#. **Invariant-stride:** The stride is not changed or replaced with '?'. .. code:: - subview %0[4:8,8:4] : memref ; Returns memref> + subview %0[4:8,8:4] : memref> ; %0: memref + subview %0[4:8,8:4] : memref> ; %0: memref #. **Rank-reduction:** A mode accessed by offset only or a mode with size statically known to be 0 is removed from the output tensor. .. code:: - subview %0[2:4, %1] : memref ; Returns memref - subview %0[2:4, %1:0] : memref ; Returns memref - subview %0[2:4, %1:1] : memref ; Returns memref> + subview %0[2:4, %1] : memref ; %0: memref + subview %0[2:4, %1:0] : memref ; %0: memref + subview %0[2:4, %1:1] : memref> ; %0: memref #. **Output-mode size:** The size of the output mode is determined by the size field of a slice and may be dynamic. .. code:: - subview %0[%1:4] : memref ; Returns memref - subview %0[%2:%2] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref> ; Returns memref + subview %0[%1:4] : memref ; %0: memref + subview %0[%2:%2] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref> Store ..... @@ -1377,13 +1494,12 @@ Store instruction =/ "store" [store-flag] local-identifier "," local-identifier "[" [local-identifier-list] "]" - ":" memref-type store-flag = ".atomic" / ".atomic_add" Overview ~~~~~~~~ -Store a scalar value in a memref at the position given by the index list. +Store a scalar value (first operand) in a memref (second operand) at the position given by the index list. The number of indices must match the order of the memref. The store is atomic when the atomic flag is set with relaxed memory ordering. @@ -1397,11 +1513,21 @@ for the the real and imaginary separately. *Note:* Store should only be used in SPMD regions as otherwise the same memory location is written from all work-items. -Arguments -~~~~~~~~~ +Operands +~~~~~~~~ -The first operand must have the same scalar type as the memref type. -The indices must be of ``index`` type. +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 scalar-type value +2 memref-type tensor +3... index index list +======= ================ =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(value) = \text{element_type}(tensor)` Work group collectives ...................... @@ -1422,6 +1548,15 @@ Work group op Description .reduce_add Compute work group sum of value ============= ================================================================ +Operands +~~~~~~~~ + +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 scalar-type value +======= ================ =========== + Restrictions ~~~~~~~~~~~~ @@ -1439,10 +1574,14 @@ Overview Yield returns values from an if or for instruction. -Arguments -~~~~~~~~~ +Operands +~~~~~~~~ -The length of the local identifier list must equal the length of the return type list. +======= ============================================ =========== +Op.-No. Type Description +======= ============================================ =========== +1... boolean-type / scalar-type / coopmatrix-type value +======= ============================================ =========== Additional instructions ....................... @@ -1459,7 +1598,7 @@ Subgroup id .. code:: abnf - value-instruction =/ "subgroup_id" + value-instruction =/ "subgroup_id" ":" "i32" Overview ~~~~~~~~ @@ -1471,7 +1610,7 @@ Subgroup local id .. code:: abnf - value-instruction =/ "subgroup_local_id" + value-instruction =/ "subgroup_local_id" ":" "i32" Overview ~~~~~~~~ @@ -1500,14 +1639,12 @@ where B and C are constant matrices and A and D are matrix batches. %B: memref, %C: memref, %D: memref) { - %0 = group_id - %1 = load %A[%0] : group> ; Returns memref - %2 = subview %D[:,:,%0] : memref ; Returns memref - %tmp0 = alloca -> memref + %0 = group_id : index + %1 = load %A[%0] : memref + %2 = subview %D[:,:,%0] : memref + %tmp0 = alloca : memref %zero = constant 0.0 : f32 %one = constant 1.0 : f32 gemm.n.t %one, %1, %B, %zero, %tmp0 - : f32, memref, memref, f32, memref gemm.n.n %alpha, %tmp0, %C, %one, %2 - : f32, memref, memref, f32, memref } From ee72ff6e55037401655c2053e00e473e2aba6cc2 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 10:57:23 +0100 Subject: [PATCH 117/297] Implement language changes for BLAS ops and arith Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 7 ++ docs/api/builder_capi.yaml | 1 + docs/api/core_capi.rst | 35 ++++++++++ docs/api/core_cxxapi.rst | 14 ++++ examples/matrix_chain/test_ader.cpp | 4 +- include/tinytc/tinytc.h | 15 +++- include/tinytc/tinytc.hpp | 17 ++++- include/tinytc/types.h | 2 + include/tinytc/types.hpp | 1 + src/codegen_tools.cpp | 93 +++++++++++++------------ src/error.cpp | 2 + src/inst.cpp | 7 +- src/node/inst_node.cpp | 28 ++++---- src/node/inst_node.hpp | 3 +- src/parser/lexer.re | 2 +- src/parser/parser_impl.yy | 83 ++++++---------------- src/pass/clone.cpp | 3 +- src/pass/dump_ir.cpp | 22 +----- src/pass/lower_foreach.cpp | 33 ++++----- src/pass/lower_linalg.cpp | 19 ++--- src/recipe/tall_and_skinny.cpp | 6 +- src/value.cpp | 9 ++- test/codegen/axpby0.ir | 4 +- test/codegen/axpby1.ir | 8 +-- test/opt/check-ir/nesting0.ir | 4 +- test/opt/dead-code-elimination.ir | 8 +-- test/opt/insert-barrier.ir | 103 ++++++++++++++-------------- test/opt/insert-lifetime-stop.ir | 31 ++++----- test/opt/work-group-size.ir | 13 ++-- test/spv/alloca.ir | 6 +- 30 files changed, 314 insertions(+), 269 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 233af067..8194f2f3 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -780,6 +780,8 @@ Value * :ref:`tinytc_value_get_name` + * :ref:`tinytc_value_get_type` + * :ref:`tinytc_value_set_name` * :ref:`tinytc_value_set_name_n` @@ -792,6 +794,11 @@ tinytc_value_get_name .. doxygenfunction:: tinytc_value_get_name +tinytc_value_get_type +..................... + +.. doxygenfunction:: tinytc_value_get_type + tinytc_value_set_name ..................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 50abd572..27a0017c 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -120,5 +120,6 @@ Builder C-API: Value: function: - tinytc_value_get_name + - tinytc_value_get_type - tinytc_value_set_name - tinytc_value_set_name_n diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 056d5316..44a2668f 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -279,6 +279,10 @@ Compiler * :ref:`tinytc_prog_compile_to_spirv` + * :ref:`tinytc_prog_compile_to_spirv_and_assemble` + + * :ref:`tinytc_spirv_assemble` + Compiler Enumerations --------------------- @@ -315,6 +319,16 @@ tinytc_prog_compile_to_spirv .. doxygenfunction:: tinytc_prog_compile_to_spirv +tinytc_prog_compile_to_spirv_and_assemble +......................................... + +.. doxygenfunction:: tinytc_prog_compile_to_spirv_and_assemble + +tinytc_spirv_assemble +..................... + +.. doxygenfunction:: tinytc_spirv_assemble + Compiler Context ================ @@ -633,6 +647,12 @@ SPIR-V module * Functions + * :ref:`tinytc_spv_mod_dump` + + * :ref:`tinytc_spv_mod_print_to_file` + + * :ref:`tinytc_spv_mod_print_to_string` + * :ref:`tinytc_spv_mod_release` * :ref:`tinytc_spv_mod_retain` @@ -640,6 +660,21 @@ SPIR-V module SPIR-V module Functions ----------------------- +tinytc_spv_mod_dump +................... + +.. doxygenfunction:: tinytc_spv_mod_dump + +tinytc_spv_mod_print_to_file +............................ + +.. doxygenfunction:: tinytc_spv_mod_print_to_file + +tinytc_spv_mod_print_to_string +.............................. + +.. doxygenfunction:: tinytc_spv_mod_print_to_string + tinytc_spv_mod_release ...................... diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index dccbd1a0..15ce00d5 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -164,6 +164,10 @@ Compiler * :ref:`compile_to_spirv` + * :ref:`compile_to_spirv_and_assemble` + + * :ref:`spirv_assemble` + Compiler Functions ------------------ @@ -187,6 +191,16 @@ compile_to_spirv .. doxygenfunction:: tinytc::compile_to_spirv +compile_to_spirv_and_assemble +............................. + +.. doxygenfunction:: tinytc::compile_to_spirv_and_assemble + +spirv_assemble +.............. + +.. doxygenfunction:: tinytc::spirv_assemble + Compiler Context ================ diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 19e14d42..60d24253 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -114,10 +114,10 @@ auto test_ader::make_optimized_kernel(bool dump) auto cnum = c1; auto const static_offsets2 = std::array{0, 0}; for (std::int64_t n = 1; n <= N_; ++n) { - cnum = bb.add(make_arith(arithmetic::mul, cnum, dt)); + cnum = bb.add(make_arith(arithmetic::mul, cnum, dt, dt.get_type())); denom *= n + 1; auto cdenom = bb.add(make_constant(static_cast(denom), element_ty)); - auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom)); + auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom, cnum.get_type())); auto bn = Bd_aligned(N_ - n); auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); auto dq_nextv = bb.add(make_subview(dq_next, static_offsets2, {bn, P_})); diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 1f1bc5ca..30924a2c 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -171,6 +171,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, uint32_ */ TINYTC_EXPORT tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name); +/** + * @brief Get type of value + * + * @param vl [in] value object + * @param ty [out] pointer to data type + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_get_type(const_tinytc_value_t vl, + tinytc_data_type_t *ty); + //////////////////////////// /////// Instructions /////// //////////////////////////// @@ -197,18 +208,20 @@ TINYTC_EXPORT char const *tinytc_work_group_operation_to_string(tinytc_work_grou /** * @brief Create arithmetic instruction (binary) * - * @code %value = arith. %a, %b : type(%a) ; type(%a) == type(%b) @endcode + * @code %value = arith. %a, %b : ty ; ty == type(%a) and ty == type(%b) @endcode * * @param instr [out] pointer to the inst object created * @param op [in] arithmetic operation type * @param a [in] left-hand operand * @param b [in] right-hand operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, tinytc_value_t a, tinytc_value_t b, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 98ccc2b8..d356defe 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -664,6 +664,17 @@ class value : public handle { inline void set_name(std::string_view name) { CHECK_STATUS(tinytc_value_set_name_n(obj_, name.size(), name.data())); } + + /** + * @brief Get type + * + * @return Data type + */ + inline auto get_type() -> data_type { + tinytc_data_type_t ty; + CHECK_STATUS(tinytc_value_get_type(obj_, &ty)); + return ty; + } }; static_assert(std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t)); @@ -863,14 +874,16 @@ static_assert(std::is_standard_layout_v && sizeof(region) == sizeof(tiny * @param op Arithmetic operation type * @param a First operand * @param b Second operand + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_arith(arithmetic op, value a, value b, location const &loc = {}) { +inline inst make_arith(arithmetic op, value a, value b, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( - tinytc_arith_inst_create(&instr, static_cast(op), a, b, &loc), loc); + tinytc_arith_inst_create(&instr, static_cast(op), a, b, ty, &loc), + loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 270e839e..eb5dd1b6 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -91,6 +91,8 @@ typedef enum { tinytc_status_ir_yield_in_else_branch_missing = 0x130, ///< Must have yield instruction in else branch tinytc_status_ir_from_to_mismatch = 0x131, ///< size(from) != size(to) in foreach + tinytc_status_ir_operand_type_must_match_return_type = + 0x132, /// Operand type must match return type // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 47ba1677..48bed60f 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -98,6 +98,7 @@ enum class status { ir_must_have_yield = tinytc_status_ir_must_have_yield, ir_yield_in_else_branch_missing = tinytc_status_ir_yield_in_else_branch_missing, ir_from_to_mismatch = tinytc_status_ir_from_to_mismatch, + ir_operand_type_must_match_return_type = tinytc_status_ir_operand_type_must_match_return_type, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 962ac435..c68c4e5b 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -489,32 +489,35 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, value sg_id, sgs_loop_body_builder_new const &body) { - auto index_ty = scalar_data_type::get(sg_id->context(), scalar_type::index); - auto c_sgs = bb.add(make_constant(sgs, index_ty)); - auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, index_ty)); - auto c0 = bb.add(make_constant(0, index_ty)); - auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, index_ty)); + auto ity = loop_trip_count->ty(); + auto c_sgs = bb.add(make_constant(sgs, ity)); + auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, ity)); + auto c0 = bb.add(make_constant(0, ity)); + auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, ity)); auto blocks = - instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, c_sgs)); - auto rem = instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, c_sgs)); + instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, c_sgs, ity)); + auto rem = + instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, c_sgs, ity)); - auto sg_id_index = instant_constant_fold_add(bb, make_cast(sg_id, index_ty)); + auto sg_id_cast = instant_constant_fold_add(bb, make_cast(sg_id, ity)); auto is_blocks_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, blocks, c0)); bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { auto block_start = - instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_index)); - auto block_end = instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, blocks)); - bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, index_ty, + instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_cast, ity)); + auto block_end = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, blocks, ity)); + bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, ity, [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); }); auto condition0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(condition0, [&](region_builder &bb) { auto condition1 = - instant_constant_fold_add(bb, make_cmp(cmp_condition::eq, sg_id_index, c_tiles_1)); + instant_constant_fold_add(bb, make_cmp(cmp_condition::eq, sg_id_cast, c_tiles_1)); bb.if_condition(condition1, [&](region_builder &bb) { - auto block = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks, c_sgs)); + auto block = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks, c_sgs, ity)); body(bb, block, true, rem); }); }); @@ -523,49 +526,55 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { - auto index_ty = scalar_data_type::get(loop_trip_count->context(), scalar_type::index); - auto c0 = bb.add(make_constant(0, index_ty)); - auto c1 = bb.add(make_constant(1, index_ty)); - auto c_tiles = bb.add(make_constant(num_tiles, index_ty)); + auto ity = loop_trip_count->ty(); + auto c0 = bb.add(make_constant(0, ity)); + auto c1 = bb.add(make_constant(1, ity)); + auto c_tiles = bb.add(make_constant(num_tiles, ity)); // Here we compute // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * // num_tiles - auto c_block_size = bb.add(make_constant(block_size, index_ty)); - auto blocks0 = instant_constant_fold_add(bb, make_arith(arithmetic::sub, loop_trip_count, c1)); + auto c_block_size = bb.add(make_constant(block_size, ity)); + auto blocks0 = + instant_constant_fold_add(bb, make_arith(arithmetic::sub, loop_trip_count, c1, ity)); auto blocks1 = - instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks0, c_block_size)); - auto blocks2 = instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks1, c_tiles)); - auto blocks3 = instant_constant_fold_add(bb, make_arith(arithmetic::add, c1, blocks2)); - auto blocks = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks3, c_tiles)); - - auto bs = instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, blocks)); - auto bs_1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, bs, c1)); - auto rem = instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, blocks)); - - auto sg_id_index = instant_constant_fold_add(bb, make_cast(sg_id, index_ty)); + instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks0, c_block_size, ity)); + auto blocks2 = + instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks1, c_tiles, ity)); + auto blocks3 = instant_constant_fold_add(bb, make_arith(arithmetic::add, c1, blocks2, ity)); + auto blocks = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks3, c_tiles, ity)); + + auto bs = + instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, blocks, ity)); + auto bs_1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, bs, c1, ity)); + auto rem = + instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, blocks, ity)); + + auto sg_id_cast = instant_constant_fold_add(bb, make_cast(sg_id, ity)); // The following if makes it easy to eliminate the remainder handler in optimization if rem // == 0 is known at compile time. Without the if, we would need to prove that block_start_1 // is non-negative to eliminate the for-loop. auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { auto block_start_1 = - instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, sg_id_index)); - auto block_end_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem)); - auto step_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, c_tiles)); - bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), index_ty, + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, sg_id_cast, ity)); + auto block_end_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); + auto step_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, c_tiles, ity)); + bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), ity, [&](region_builder &bb, value block) { body(bb, block, bs_1); }); }); - auto tmp0 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, rem, c_tiles)); - auto tmp1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, sg_id_index, tmp0)); - auto sg_id_1 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, tmp1, c_tiles)); - auto tmp2 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, sg_id_1)); - auto tmp3 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem)); - auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::add, tmp3, tmp2)); - auto step = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, c_tiles)); - bb.for_loop(std::move(block_start), loop_trip_count, std::move(step), index_ty, + auto tmp0 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, rem, c_tiles, ity)); + auto tmp1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, sg_id_cast, tmp0, ity)); + auto sg_id_1 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, tmp1, c_tiles, ity)); + auto tmp2 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, sg_id_1, ity)); + auto tmp3 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); + auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::add, tmp3, tmp2, ity)); + auto step = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, c_tiles, ity)); + bb.for_loop(std::move(block_start), loop_trip_count, std::move(step), ity, [&](region_builder &bb, value block) { body(bb, block, bs); }); } @@ -587,7 +596,7 @@ auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value b = bb.add(make_cast(b, compatible_ty, loc)); } } - return bb.add(make_arith(operation, a, b, loc)); + return bb.add(make_arith(operation, a, b, a->ty(), loc)); } auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, location const &loc) -> value { diff --git a/src/error.cpp b/src/error.cpp index 2e4dfe26..dc630ca2 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -211,6 +211,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Else-branch must have yield instruction if then-branch has yield instruction"; case tinytc_status_ir_from_to_mismatch: return "length(from) must equal length(to) and length must be greater than 0"; + case tinytc_status_ir_operand_type_must_match_return_type: + return "Type of operand must match return type"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/inst.cpp b/src/inst.cpp index 570574b8..46c22b98 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -143,14 +143,15 @@ char const *tinytc_work_group_operation_to_string(tinytc_work_group_operation_t } tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, - tinytc_value_t a, tinytc_value_t b, + tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), a, b, get_optional(loc)) - .release(); + *instr = + std::make_unique(enum_cast(op), a, b, ty, get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index ed2e4660..d81c8658 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -186,13 +186,20 @@ axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, t } arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b0, - location const &lc) + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::arith}, operation_(operation) { op(op_a, a0); op(op_b, b0); loc(lc); - if (isa(*a().ty())) { + if (a().ty() != ty) { + throw compilation_error(a().loc(), status::ir_operand_type_must_match_return_type); + } + if (b().ty() != ty) { + throw compilation_error(b().loc(), status::ir_operand_type_must_match_return_type); + } + + if (isa(*ty)) { auto const inst_supports_bool = [&] { switch (operation) { case arithmetic::and_: @@ -206,10 +213,7 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b if (!inst_supports_bool) { throw compilation_error(loc(), status::ir_boolean_unsupported); } - } else if (isa(*a().ty())) { - if (!isa(*b().ty())) { - throw compilation_error(loc(), status::ir_expected_coopmatrix); - } + } else if (isa(*ty)) { bool inst_supports_coopmatrix = false; switch (operation) { case arithmetic::add: @@ -225,12 +229,8 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b throw compilation_error(loc(), status::ir_coopmatrix_unsupported); } } else { - auto a_ty = get_scalar_type(loc(), a())->ty(); - auto b_ty = get_scalar_type(loc(), b())->ty(); + auto sty = get_scalar_type(loc(), ty)->ty(); - if (a_ty != b_ty) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } bool inst_supports_fp = true; bool inst_supports_complex = true; switch (operation) { @@ -254,15 +254,15 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b inst_supports_complex = false; break; } - if (!inst_supports_fp && is_floating_type(a_ty)) { + if (!inst_supports_fp && is_floating_type(sty)) { throw compilation_error(loc(), status::ir_fp_unsupported); } - if (!inst_supports_complex && is_complex_type(a_ty)) { + if (!inst_supports_complex && is_complex_type(sty)) { throw compilation_error(loc(), status::ir_complex_unsupported); } } - result(0) = value_node{a().ty(), this, lc}; + result(0) = value_node{ty, this, lc}; } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index e089ad84..4e1700d0 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -367,7 +367,8 @@ class arith_inst : public standard_inst<2, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith; } enum op_number { op_a = 0, op_b = 1 }; - arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); + arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + location const &lc = {}); inline arithmetic operation() const { return operation_; } inline auto a() -> tinytc_value & { return op(op_a); } diff --git a/src/parser/lexer.re b/src/parser/lexer.re index c77e2991..84d82230 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -90,7 +90,7 @@ lex: "func" { adv_loc(); return parser::make_FUNC(loc_); } "work_group_size" { adv_loc(); return parser::make_WORK_GROUP_SIZE(loc_); } "subgroup_size" { adv_loc(); return parser::make_SUBGROUP_SIZE(loc_); } - "->" { adv_loc(); return parser::make_RETURNS(loc_); } + "->" { adv_loc(); return parser::make_ARROW(loc_); } "?" { adv_loc(); return parser::make_DYNAMIC(loc_); } ".n" { adv_loc(); return parser::make_NOTRANS(loc_); } ".t" { adv_loc(); return parser::make_TRANS(loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 39fc1f95..d0c377ca 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -90,7 +90,7 @@ FUNC "func" WORK_GROUP_SIZE "work_group_size" SUBGROUP_SIZE "subgroup_size" - RETURNS "->" + ARROW "->" DYNAMIC "?" NOTRANS ".n" TRANS ".t" @@ -466,13 +466,7 @@ instruction: ; axpby_inst: - AXPBY transpose[ta] atomic - var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); + AXPBY transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { try { $$ = inst { std::make_unique($ta, std::move($alpha), std::move($a), @@ -529,15 +523,7 @@ optional_local_attr: ; gemm_inst: - GEMM transpose[ta] transpose[tb] atomic - var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GEMM transpose[ta] transpose[tb] atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique($ta, $tb, std::move($alpha), std::move($a), @@ -553,15 +539,7 @@ gemm_inst: ; gemv_inst: - GEMV transpose[ta] atomic - var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GEMV transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique($ta, std::move($alpha), std::move($a), std::move($b), @@ -581,15 +559,7 @@ transpose: ; ger_inst: - GER atomic - var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GER atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique(std::move($alpha), std::move($a), std::move($b), @@ -648,7 +618,7 @@ optional_step: optional_loop_carried_values: %empty { $$ = {}; } - | INIT LPAREN init_value_list RPAREN RETURNS LPAREN return_type_list RPAREN { + | INIT LPAREN init_value_list RPAREN ARROW LPAREN return_type_list RPAREN { $$ = std::make_tuple(std::move($init_value_list.first), std::move($init_value_list.second), std::move($return_type_list)); } @@ -723,15 +693,7 @@ identifier_list: hadamard_inst: - HADAMARD atomic - var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + HADAMARD atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique(std::move($alpha), std::move($a), std::move($b), @@ -747,13 +709,7 @@ hadamard_inst: ; sum_inst: - SUM transpose[ta] atomic - var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); + SUM transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { try { $$ = inst { std::make_unique($ta, std::move($alpha), std::move($a), std::move($beta), @@ -808,7 +764,7 @@ valued_inst: ; alloca_inst: - ALLOCA RETURNS memref_type { + ALLOCA COLON memref_type { try { $$ = inst { std::make_unique(std::move($memref_type), @alloca_inst).release() @@ -826,7 +782,8 @@ arith_inst: check_type($b, $ty, @b, @ty); try { $$ = inst { - std::make_unique($ARITHMETIC, std::move($a), std::move($b), @arith_inst) + std::make_unique($ARITHMETIC, std::move($a), std::move($b), std::move($ty), + @arith_inst) .release() }; } catch (compilation_error const &e) { @@ -854,7 +811,7 @@ arith_unary_inst: cast_inst: - CAST var[a] COLON data_type[from] RETURNS data_type[to] { + CAST var[a] COLON data_type[from] ARROW data_type[to] { check_type($a, $from, @a, @from); try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; @@ -883,7 +840,7 @@ compare_inst: ; constant_inst: - CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR RETURNS data_type { + CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR ARROW data_type { try { $$ = inst { std::make_unique(std::complex{$re, $im}, $data_type, @constant_inst) @@ -894,7 +851,7 @@ constant_inst: YYERROR; } } - | CONSTANT FLOATING_CONSTANT RETURNS data_type { + | CONSTANT FLOATING_CONSTANT ARROW data_type { try { $$ = inst { std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() @@ -904,7 +861,7 @@ constant_inst: YYERROR; } } - | CONSTANT INTEGER_CONSTANT RETURNS data_type { + | CONSTANT INTEGER_CONSTANT ARROW data_type { try { $$ = inst { std::make_unique($INTEGER_CONSTANT, $data_type, @constant_inst).release() @@ -914,7 +871,7 @@ constant_inst: YYERROR; } } - | CONSTANT BOOLEAN_CONSTANT RETURNS data_type { + | CONSTANT BOOLEAN_CONSTANT ARROW data_type { try { $$ = inst { std::make_unique($BOOLEAN_CONSTANT, $data_type, @constant_inst).release() @@ -927,7 +884,7 @@ constant_inst: ; cooperative_matrix_load_inst: - COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[op_ty] RETURNS data_type[result_ty] { + COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[op_ty] ARROW data_type[result_ty] { check_type($op, $op_ty, @op, @op_ty); try { $$ = inst { @@ -949,7 +906,7 @@ checked: ; cooperative_matrix_mul_add_inst: - COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[a_ty] COMMA data_type[b_ty] COMMA data_type[c_ty] RETURNS data_type[to_ty] { + COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[a_ty] COMMA data_type[b_ty] COMMA data_type[c_ty] ARROW data_type[to_ty] { check_type($a, $a_ty, @a, @a_ty); check_type($b, $b_ty, @b, @b_ty); check_type($c, $c_ty, @c, @c_ty); @@ -1003,7 +960,7 @@ cooperative_matrix_store_inst: ; expand_inst: - EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] RETURNS expand_shape RSQBR COLON memref_type { + EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] ARROW expand_shape RSQBR COLON memref_type { if ($var->ty() != $memref_type) { auto loc = @var; loc.end = @memref_type.end; @@ -1156,7 +1113,7 @@ else_region: optional_returned_values: %empty { $$ = {}; } - | RETURNS LPAREN optional_return_type_list[tys] RPAREN { $$ = std::move($tys); } + | ARROW LPAREN optional_return_type_list[tys] RPAREN { $$ = std::move($tys); } ; optional_return_type_list: diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index 3ec3dcbf..aefc2265 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -25,7 +25,8 @@ auto inst_cloner::operator()(axpby_inst &in) -> std::unique_ptr { subs(&in.B()), in.atomic(), in.loc()); } auto inst_cloner::operator()(arith_inst &in) -> std::unique_ptr { - return std::make_unique(in.operation(), subs(&in.a()), subs(&in.b()), in.loc()); + return std::make_unique(in.operation(), subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(arith_unary_inst &in) -> std::unique_ptr { return std::make_unique(in.operation(), subs(&in.a()), in.loc()); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 7f11e04d..3a667d14 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -80,14 +80,6 @@ void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { dump_val(g.beta()); *os_ << ", "; dump_val(g.B()); - *os_ << " : "; - visit(*this, *g.alpha().ty()); - *os_ << ", "; - visit(*this, *g.A().ty()); - *os_ << ", "; - visit(*this, *g.beta().ty()); - *os_ << ", "; - visit(*this, *g.B().ty()); } void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { @@ -104,21 +96,11 @@ void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { dump_val(g.beta()); *os_ << ", "; dump_val(g.C()); - *os_ << " : "; - visit(*this, *g.alpha().ty()); - *os_ << ", "; - visit(*this, *g.A().ty()); - *os_ << ", "; - visit(*this, *g.B().ty()); - *os_ << ", "; - visit(*this, *g.beta().ty()); - *os_ << ", "; - visit(*this, *g.C().ty()); } void dump_ir_pass::operator()(alloca_inst const &a) { dump_val(a.result(0)); - *os_ << " = alloca -> "; + *os_ << " = alloca : "; visit(*this, *a.result()->ty()); } @@ -135,7 +117,7 @@ void dump_ir_pass::operator()(arith_inst const &a) { *os_ << ", "; dump_val(a.b()); *os_ << " : "; - visit(*this, *a.a().ty()); + visit(*this, *a.result(0).ty()); } void dump_ir_pass::operator()(arith_unary_inst const &a) { diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp index f61d5d57..979c3140 100644 --- a/src/pass/lower_foreach.cpp +++ b/src/pass/lower_foreach.cpp @@ -17,16 +17,16 @@ namespace tinytc { template void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, int num_tiles, F &&make_body, location const &loc) { + auto ity = from->ty(); auto ctx = compiler_context{sg_id->context(), true}; - auto index_ty = get_scalar(ctx, scalar_type::index); auto sg_lid_i32 = bb.add(make_subgroup_local_id(ctx)); - auto sg_lid = bb.add(make_cast(sg_lid_i32, index_ty)); - auto size = bb.add(make_arith(arithmetic::sub, to, from, loc)); - auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid)); + auto sg_lid = bb.add(make_cast(sg_lid_i32, ity)); + auto size = bb.add(make_arith(arithmetic::sub, to, from, ity, loc)); + auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity)); tile_loop_by_sgs_new( bb, size, sgs, num_tiles, sg_id, [&](region_builder &bb, value block, bool is_remainder, value trip_count) { - auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset)); + auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset, ity)); if (is_remainder) { auto cond = bb.add(make_cmp(cmp_condition::lt, sg_lid, trip_count)); bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }); @@ -53,30 +53,27 @@ auto foreach_generator::operator()(foreach_inst &in) -> inst { tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; - auto ctx = compiler_context{in.context(), true}; - auto i32_ty = get_scalar(ctx, scalar_type::i32); - auto index_ty = get_scalar(ctx, scalar_type::index); - - auto sg_id = bb.add(make_subgroup_id(ctx, in.loc())); + auto sg_id = bb.add(make_subgroup_id(compiler_context{in.context(), true}, in.loc())); auto cloner = inst_cloner{}; auto loop_vars = in.loop_vars().begin(); auto from = in.from().begin(); auto to = in.to().begin(); + auto ity = (*from).ty(); if (in.dim() > 1) { auto const make_inner_loop_nest = [&](region_builder &bb, value from1, value to1) { tinytc_region_t current_region = bb.get_region().get(); for (std::int64_t i = in.dim() - 1; i > 1; --i) { auto for_i = std::make_unique( - &from[i], &to[i], nullptr, array_view{}, index_ty, in.loc()); + &from[i], &to[i], nullptr, array_view{}, ity, in.loc()); cloner.set_subs(&loop_vars[i], &for_i->loop_var()); tinytc_region_t next_region = &for_i->body(); current_region->insts().push_back(for_i.release()); current_region = next_region; } region_builder{current_region}.for_loop( - from1, to1, index_ty, + from1, to1, ity, [&](region_builder &bb, value loop_var1) { cloner.set_subs(&loop_vars[1], loop_var1.get()); cloner.clone_region(in.body(), *bb.get_region()); @@ -84,16 +81,16 @@ auto foreach_generator::operator()(foreach_inst &in) -> inst { in.loc()); }; - auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); - auto sg_id1 = bb.add(make_arith(arithmetic::div, sg_id, c_m_tiles, in.loc())); - auto sg_id0 = bb.add(make_arith(arithmetic::rem, sg_id, c_m_tiles, in.loc())); + auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), sg_id->ty(), in.loc())); + auto sg_id1 = bb.add(make_arith(arithmetic::div, sg_id, c_m_tiles, sg_id->ty(), in.loc())); + auto sg_id0 = bb.add(make_arith(arithmetic::rem, sg_id, c_m_tiles, sg_id->ty(), in.loc())); - auto size1 = bb.add(make_arith(arithmetic::sub, &to[1], &from[1], in.loc())); + auto size1 = bb.add(make_arith(arithmetic::sub, &to[1], &from[1], ity, in.loc())); tile_loop_uniformly_new( bb, size1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_id1, [&](region_builder &bb, value block, value trip_count1) { - auto from1 = bb.add(make_arith(arithmetic::add, &from[1], block)); - auto to1 = bb.add(make_arith(arithmetic::add, from1, trip_count1)); + auto from1 = bb.add(make_arith(arithmetic::add, &from[1], block, ity, in.loc())); + auto to1 = bb.add(make_arith(arithmetic::add, from1, trip_count1, ity, in.loc())); make_loop0( bb, &from[0], &to[0], sg_id0, core_cfg_.subgroup_size, tiling_.m_tiles(), [&](region_builder &bb, value loop_var0) { diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 93ed63b8..5e23c06b 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -93,8 +93,10 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto c_zero = bb.add(make_constant_zero(index_ty, loc)); auto c_k_block_size = bb.add(make_constant(k_block_size, index_ty, loc)); - auto tmp = instant_constant_fold_add(bb, make_arith(arithmetic::div, K, c_k_block_size, loc)); - auto K0 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, tmp, c_k_block_size, loc)); + auto tmp = instant_constant_fold_add( + bb, make_arith(arithmetic::div, K, c_k_block_size, index_ty, loc)); + auto K0 = instant_constant_fold_add( + bb, make_arith(arithmetic::mul, tmp, c_k_block_size, index_ty, loc)); c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc); auto needs_remainder = instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, loc)); auto r = get_bool_constant(needs_remainder); @@ -124,7 +126,8 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto c_load = bb.add(make_cooperative_matrix_load(transpose::N, check_c, C, m_block, n_block, coopmatrix_c_ty)); auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load, loc); - auto alpha_ab_plus_beta_c = bb.add(make_arith(arithmetic::add, alpha_ab, beta_c, loc)); + auto alpha_ab_plus_beta_c = + bb.add(make_arith(arithmetic::add, alpha_ab, beta_c, alpha_ab->ty(), loc)); bb.add(make_cooperative_matrix_store(check_c, store_flag::regular, alpha_ab_plus_beta_c, C, m_block, n_block, loc)); } @@ -195,7 +198,7 @@ void linalg_generator::operator()(axpby_inst &in) { auto c0 = bb.add(make_constant(0, i32_ty)); auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0)); auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0)); - auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1)); + auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1, cond0->ty())); bb.if_condition(cond, [&](region_builder &bb) { auto a = bb.add(make_load(&in.A(), {}, in.loc())); blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); @@ -263,8 +266,8 @@ void linalg_generator::operator()(gemm_inst &in) { auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); - auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, in.loc())); - auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, in.loc())); + auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, i32_ty, in.loc())); + auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, i32_ty, in.loc())); auto [max_rows, max_cols] = max_register_block_gemm( size(ct->element_ty()), core_cfg_.subgroup_size, core_cfg_.register_space, @@ -378,8 +381,8 @@ void linalg_generator::operator()(sum_inst &in) { auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); auto m = bb.add(make_subgroup_local_id(ctx, in.loc())); - auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, in.loc())); - auto from1 = bb.add(make_arith(arithmetic::add, from0, m, in.loc())); + auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, i32_ty, in.loc())); + auto from1 = bb.add(make_arith(arithmetic::add, from0, m, i32_ty, in.loc())); auto from_index = bb.add(make_cast(from1, index_ty, in.loc())); auto c_zero = bb.add(make_constant_zero(i32_ty, in.loc())); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 77d307a8..51445277 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -108,7 +108,8 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( bool is_beta_nonzero, value beta_arg, value C) { auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); auto gid = bb.add(make_group_id(ctx_, my_loc())); - auto m = bb.add(make_arith(arithmetic::mul, gid, c_M_block_size, my_loc())); + auto m = bb.add( + make_arith(arithmetic::mul, gid, c_M_block_size, gid.get_type(), my_loc())); auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant_zero(ty_, my_loc())); auto const static_offsets = std::array{dynamic, 0}; @@ -141,7 +142,8 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( } else { auto M_val = bb.add(make_size(C, 0, my_loc())); - auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); + auto M_val_sub_m = + bb.add(make_arith(arithmetic::sub, M_val, m, m.get_type(), my_loc())); auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, my_loc())); bb.ifelse( diff --git a/src/value.cpp b/src/value.cpp index 3f87165b..beca1e41 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -30,6 +30,13 @@ tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name if (vl == nullptr || name == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { return vl->name(); }); + return exception_to_status_code([&] { *name = vl->name(); }); +} + +tinytc_status_t tinytc_value_get_type(const_tinytc_value_t vl, tinytc_data_type_t *ty) { + if (vl == nullptr || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ty = vl->ty(); }); } } diff --git a/test/codegen/axpby0.ir b/test/codegen/axpby0.ir index 23ccfefc..33bc2196 100644 --- a/test/codegen/axpby0.ir +++ b/test/codegen/axpby0.ir @@ -4,6 +4,6 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @axpby(%alpha: f32, %A: memref, %B: memref) { %zero = constant 0.0 -> f32 - axpby.n %alpha, %A, %zero, %B : f32, memref, f32, memref -; CHECK: 7.5-79: Incompatible tensor shapes + axpby.n %alpha, %A, %zero, %B +; CHECK: 7.5-33: Incompatible tensor shapes } diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir index 9ebd146c..41c3a9e4 100644 --- a/test/codegen/axpby1.ir +++ b/test/codegen/axpby1.ir @@ -4,17 +4,17 @@ ; RUN: %tinytc-oc < %s func @axpby0(%alpha: f32, %A: memref, %B: memref) { %z = constant 0.0 -> f32 - axpby.n %alpha, %A, %z, %B : f32, memref, f32, memref + axpby.n %alpha, %A, %z, %B } func @axpby1(%alpha: f32, %A: memref>, %B: memref) { %z = constant 0.0 -> f32 - axpby.n %alpha, %A, %z, %B : f32, memref>, f32, memref + axpby.n %alpha, %A, %z, %B } func @axpby2(%alpha: f32, %A: memref, %B: memref) { %z = constant 0.0 -> f32 - axpby.n %alpha, %A, %z, %B : f32, memref, f32, memref + axpby.n %alpha, %A, %z, %B } func @axpby3(%alpha: f32, %A: memref, %B: memref) { @@ -28,7 +28,7 @@ func @axpby3(%alpha: f32, %A: memref, %B: memref) for %j=%lb,%ub1 { %A1 = subview %A0[0:48,0:48,%j] : memref %B1 = subview %B0[0:48,0:48,%j] : memref - axpby.t %alpha, %A1, %z, %B1 : f32, memref, f32, memref + axpby.t %alpha, %A1, %z, %B1 } } } diff --git a/test/opt/check-ir/nesting0.ir b/test/opt/check-ir/nesting0.ir index c1bf6f07..4821a7be 100644 --- a/test/opt/check-ir/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -6,7 +6,7 @@ func @illegal_nesting(%c: f32, %A: memref, %B: memref, %C: mem %lb = constant 1 -> index %ub = constant 16 -> index foreach (%i)=(%lb),(%ub) { - gemm.n.n %c, %A, %B, %c, %C : f32, memref, memref, f32, memref + gemm.n.n %c, %A, %B, %c, %C } -; CHECK: 9.9-97: Collective instruction must not be called from SPMD region +; CHECK: 9.9-35: Collective instruction must not be called from SPMD region } diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index cca54f1b..6cdb62f3 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -65,12 +65,12 @@ func @dead_loop(%a: memref) { } func @unused_alloca(%a: memref) { - %0 = alloca -> memref - %1 = alloca -> memref + %0 = alloca : memref + %1 = alloca : memref %one = constant 1.0 -> f64 - axpby.n %one, %1, %one, %a : f64, memref, f64, memref + axpby.n %one, %1, %one, %a ; CHECK-LABEL: func @unused_alloca({{.*}} -; CHECK-NEXT: %0 = alloca -> memref +; CHECK-NEXT: %0 = alloca : memref ; CHECK-NEXT: %one{{.*}} ; CHECK-NEXT: axpby.n %one, %0{{.*}} } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 06ebd842..3e398bfc 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -3,76 +3,76 @@ ; RUN: %tinytc-opt -pinsert-barrier < %s | filecheck %s func @rar(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref - gemm.n.n %a, %A, %B, %b, %D : f32, memref, memref, f32, memref + gemm.n.n %a, %A, %B, %b, %C + gemm.n.n %a, %A, %B, %b, %D ; CHECK-LABEL: func @rar({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} -; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %D{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C +; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %D } func @raw(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref - gemm.n.n %a, %C, %B, %b, %D : f32, memref, memref, f32, memref + gemm.n.n %a, %A, %B, %b, %C + gemm.n.n %a, %C, %B, %b, %D ; CHECK-LABEL: func @raw({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C ; CHECK-NEXT: barrier.global -; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D{{.*}} +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D } func @war(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref - gemm.n.n %a, %D, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %C, %B, %b, %A + gemm.n.n %a, %D, %B, %b, %C ; CHECK-LABEL: func @war({{.*}} -; CHECK: gemm.n.n %a, %C, %B, %b, %A{{.*}} +; CHECK: gemm.n.n %a, %C, %B, %b, %A ; CHECK-NEXT: barrier.global -; CHECK-NEXT: gemm.n.n %a, %D, %B, %b, %C{{.*}} +; CHECK-NEXT: gemm.n.n %a, %D, %B, %b, %C } func @waw(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + gemm.n.n %a, %A, %B, %b, %C + gemm.n.n %a, %A, %B, %b, %C ; CHECK-LABEL: func @waw({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C ; CHECK-NEXT: barrier.global -; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK-NEXT: gemm.n.n %a, %A, %B, %b, %C } func @raw_local(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { - %C = alloca -> memref - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref - gemm.n.n %a, %C, %B, %b, %D : f32, memref, memref, f32, memref + %C = alloca : memref + gemm.n.n %a, %A, %B, %b, %C + gemm.n.n %a, %C, %B, %b, %D ; CHECK-LABEL: func @raw_local({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C ; CHECK-NEXT: barrier.local -; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D{{.*}} +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %D } func @raw_local_war_global(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { - %C = alloca -> memref - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref - gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref + %C = alloca : memref + gemm.n.n %a, %A, %B, %b, %C + gemm.n.n %a, %C, %B, %b, %A ; CHECK-LABEL: func @raw_local_war_global({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C ; CHECK-NEXT: barrier.global.local -; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A{{.*}} +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A } func @respect_manual_barrier(%a: f32, %b: f32, %A: memref, %B: memref, %D: memref) { - %C = alloca -> memref - gemm.n.n %a, %A, %B, %b, %C : f32, memref, memref, f32, memref + %C = alloca : memref + gemm.n.n %a, %A, %B, %b, %C barrier.global.local - gemm.n.n %a, %C, %B, %b, %A : f32, memref, memref, f32, memref + gemm.n.n %a, %C, %B, %b, %A ; CHECK-LABEL: func @respect_manual_barrier({{.*}} -; CHECK: gemm.n.n %a, %A, %B, %b, %C{{.*}} +; CHECK: gemm.n.n %a, %A, %B, %b, %C ; CHECK-NEXT: barrier.global.local -; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A{{.*}} +; CHECK-NEXT: gemm.n.n %a, %C, %B, %b, %A } func @war_alias(%a: f32, %b: f32, %A: memref, %C: memref) { - %B = alloca -> memref + %B = alloca : memref %0 = subview %B[0:8,0:8] : memref - axpby.n %a, %B, %b, %C : f32, memref, f32, memref - axpby.n %a, %A, %b, %0 : f32, memref, f32, memref + axpby.n %a, %B, %b, %C + axpby.n %a, %A, %b, %0 ; CHECK-LABEL: func @war_alias({{.*}} ; CHECK: axpby.n %a, %B, %b, %C{{.*}} ; CHECK-NEXT: barrier.local @@ -83,12 +83,12 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< %c42 = constant 42.0 -> f32 %0 = cmp.gt %a, %c42 : f32 if %0 { - axpby.n %a, %A, %b, %B : f32, memref, f32, memref - axpby.n %a, %B, %b, %C : f32, memref, f32, memref + axpby.n %a, %A, %b, %B + axpby.n %a, %B, %b, %C } else { - axpby.n %a, %C, %b, %D : f32, memref, f32, memref + axpby.n %a, %C, %b, %D } - axpby.n %a, %A, %b, %B : f32, memref, f32, memref + axpby.n %a, %A, %b, %B ; CHECK-LABEL: func @if({{.*}} ; CHECK: if %0 { ; CHECK-NEXT: axpby.n %a, %A, %b, %B{{.*}} @@ -104,14 +104,14 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { %c42 = constant 42.0 -> f32 %0 = cmp.gt %a, %c42 : f32 - axpby.n %a, %B, %b, %A : f32, memref, f32, memref + axpby.n %a, %B, %b, %A if %0 { - axpby.n %a, %A, %b, %B : f32, memref, f32, memref - axpby.n %a, %B, %b, %C : f32, memref, f32, memref + axpby.n %a, %A, %b, %B + axpby.n %a, %B, %b, %C } else { - axpby.n %a, %C, %b, %D : f32, memref, f32, memref + axpby.n %a, %C, %b, %D } - axpby.n %a, %A, %b, %B : f32, memref, f32, memref + axpby.n %a, %A, %b, %B ; CHECK-LABEL: func @if2({{.*}} ; CHECK: if %0 { ; CHECK-NEXT: barrier.global @@ -128,26 +128,25 @@ func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref func @region1() { %one = constant 1.0 -> f32 %zero = constant 0.0 -> f32 - %0 = alloca -> memref + %0 = alloca : memref %lb = constant 0 -> index %ub = constant 4 -> index for %i=%lb,%ub : index { - %1 = alloca -> memref + %1 = alloca : memref for %k=%lb,%ub : index { - %2 = alloca -> memref + %2 = alloca : memref gemm.n.n %one, %0, %1, %zero, %2 - : f32, memref, memref, f32, memref - axpby.n %one, %1, %zero, %0 : f32, memref, f32, memref + axpby.n %one, %1, %zero, %0 } - axpby.n %one, %0, %zero, %1 : f32, memref, f32, memref + axpby.n %one, %0, %zero, %1 } ; CHECK-LABEL: func @region1({{.*}} ; CHECK: for %i=%lb,%ub : index { -; CHECK-NEXT: %1 = alloca -> memref +; CHECK-NEXT: %1 = alloca : memref ; CHECK-NEXT: for %k=%lb,%ub : index { -; CHECK-NEXT: %2 = alloca -> memref +; CHECK-NEXT: %2 = alloca : memref ; CHECK-NEXT: barrier.local -; CHECK-NEXT: gemm.n.n %one, %0, %1, %zero, %2{{.*}} +; CHECK-NEXT: gemm.n.n %one, %0, %1, %zero, %2 ; CHECK-NEXT: barrier.local ; CHECK-NEXT: axpby.n %one, %1, %zero, %0{{.*}} ; CHECK-NEXT: } diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index d06605f0..3bc8c7bf 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -3,16 +3,16 @@ ; RUN: %tinytc-opt -pinsert-lifetime-stop < %s | filecheck %s func @basic() { - %0 = alloca -> memref -; CHECK: %0 = alloca -> memref + %0 = alloca : memref +; CHECK: %0 = alloca : memref ; CHECK-NEXT: lifetime_stop %0 } func @use1(%A: memref, %C: memref) { ; CHECK-LABEL: func @use1{{.*}} - %B = alloca -> memref + %B = alloca : memref %one = constant 1.0 -> f32 - gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref + gemm.n.n %one, %A, %B, %one, %C ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } @@ -20,11 +20,11 @@ func @use1(%A: memref, %C: memref) { func @use2(%A: memref, %C: memref) { ; CHECK-LABEL: func @use2{{.*}} %one = constant 1.0 -> f32 - %B = alloca -> memref - gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref - %B2 = alloca -> memref - gemm.n.n %one, %A, %B, %one, %C : f32, memref, memref, f32, memref - gemm.n.n %one, %A, %B2, %one, %C : f32, memref, memref, f32, memref + %B = alloca : memref + gemm.n.n %one, %A, %B, %one, %C + %B2 = alloca : memref + gemm.n.n %one, %A, %B, %one, %C + gemm.n.n %one, %A, %B2, %one, %C ; CHECK: %B2 = {{.*}} ; CHECK-NEXT: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B @@ -34,10 +34,10 @@ func @use2(%A: memref, %C: memref) { func @use_alias(%a: f32, %A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} - %B = alloca -> memref + %B = alloca : memref %0 = fuse %B[1,3] : memref %1 = subview %0[0:8,0:8] : memref - gemm.n.n %a, %A, %1, %a, %C : f32, memref, memref,local>, f32, memref + gemm.n.n %a, %A, %1, %a, %C ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B } @@ -45,16 +45,15 @@ func @use_alias(%a: f32, %A: memref, %C: memref) { func @region1() { ; CHECK-LABEL: func @region1{{.*}} %one = constant 1.0 -> f32 - %0 = alloca -> memref + %0 = alloca : memref %lb = constant 0 -> index %ub = constant 4 -> index for %i=%lb,%ub : index { - %1 = alloca -> memref + %1 = alloca : memref for %k=%lb,%ub : index { - %2 = alloca -> memref + %2 = alloca : memref gemm.n.n %one, %0, %1, %one, %2 - : f32, memref, memref, f32, memref - axpby.n %one, %0, %one, %1 : f32, memref, f32, memref + axpby.n %one, %0, %one, %1 } } ; CHECK: gemm.n.n{{.*}} diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir index 98086c96..0a048528 100644 --- a/test/opt/work-group-size.ir +++ b/test/opt/work-group-size.ir @@ -8,28 +8,27 @@ func @default_pvc() { func @f32_blas() { ; CHECK: func @f32_blas() subgroup_size(32) work_group_size(128,2) { - %0 = alloca -> memref - %1 = alloca -> memref + %0 = alloca : memref + %1 = alloca : memref %one = constant 1.0 -> f32 %zero = constant 0.0 -> f32 %lb = constant 0 -> index %ub = constant 4 -> index for %i=%lb,%ub { - axpby.n %one, %0, %zero, %1 : f32, memref, f32, memref + axpby.n %one, %0, %zero, %1 } } func @f64_blas() { ; CHECK: func @f64_blas() subgroup_size(16) work_group_size(128,8) { - %0 = alloca -> memref - %1 = alloca -> memref - %2 = alloca -> memref + %0 = alloca : memref + %1 = alloca : memref + %2 = alloca : memref %one = constant 1.0 -> f64 %zero = constant 0.0 -> f64 %lb = constant 0 -> index %ub = constant 4 -> index for %i=%lb,%ub { gemm.n.n %one, %0, %1, %zero, %2 - : f64, memref, memref, f64, memref } } diff --git a/test/spv/alloca.ir b/test/spv/alloca.ir index d48f54a5..9d0aaac0 100644 --- a/test/spv/alloca.ir +++ b/test/spv/alloca.ir @@ -26,9 +26,9 @@ func @alloca() { %c0 = constant 0 -> index - %0 = alloca -> memref - %1 = alloca -> memref - %2 = alloca -> memref + %0 = alloca : memref + %1 = alloca : memref + %2 = alloca : memref %3 = load %0[%c0] : memref %4 = load %1[%c0,%c0] : memref %5 = load %2[] : memref From dcd25895202fd6ddda54be39e2961b6df47d0054 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 15:17:30 +0100 Subject: [PATCH 118/297] Implement language change for arith_unary and improve error reporting Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 2 +- include/tinytc/tinytc.h | 2 + include/tinytc/tinytc.hpp | 9 +- include/tinytc/types.h | 64 +++---- include/tinytc/types.hpp | 8 +- src/compiler_context.cpp | 20 ++- src/compiler_context.hpp | 3 + src/error.cpp | 43 +++-- src/error.hpp | 12 +- src/inst.cpp | 5 +- src/node/inst_node.cpp | 213 ++++++++++++++---------- src/node/inst_node.hpp | 3 +- src/parser/parser_impl.yy | 108 ++++++------ src/pass/clone.cpp | 3 +- src/pass/convert_to_opencl.cpp | 4 +- src/pass/dump_ir.cpp | 2 +- test/codegen/scalar_arithmetic.ir | 6 +- test/codegen/scalar_arithmetic_error.ir | 6 +- test/opt/constant-propagation.ir | 12 +- test/spv/arith_unary.ir | 6 +- tools/offline_compiler/main.cpp | 3 + tools/opt/main.cpp | 3 + 22 files changed, 324 insertions(+), 213 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index d8fa95ea..2d2fb113 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -426,7 +426,7 @@ Restrictions ~~~~~~~~~~~~ * :math:`\text{shape}(B) = \text{shape}(\text{op}(A))` -* :math:`\text{order}(B) = 1 \lor \text{order}(B) = 2` +* :math:`\text{order}(B) = 0 \lor \text{order}(B) = 1 \lor \text{order}(B) = 2` * :math:`\text{type}(\alpha) \preceq \text{element_type}(A)` * :math:`\text{type}(\beta) \preceq \text{element_type}(B)` * If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 30924a2c..cf2f82a9 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -232,6 +232,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tin * @param instr [out] pointer to the inst object created * @param op [in] unary arithmetic operation type * @param a [in] operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -239,6 +240,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tin TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_unary_t op, tinytc_value_t a, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index d356defe..c796ae03 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -892,15 +892,16 @@ inline inst make_arith(arithmetic op, value a, value b, data_type ty, location c * * @param op Arithmetic operation type * @param a Operand + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_arith(arithmetic_unary op, value a, location const &loc = {}) { +inline inst make_arith(arithmetic_unary op, value a, data_type ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_arith_unary_inst_create(&instr, static_cast(op), a, &loc), - loc); + CHECK_STATUS_LOC(tinytc_arith_unary_inst_create( + &instr, static_cast(op), a, ty, &loc), + loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index eb5dd1b6..686671ff 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -60,39 +60,43 @@ typedef enum { tinytc_status_ir_expected_memref = 0x10c, ///< Expected a value of memref type tinytc_status_ir_expected_memref_or_scalar = 0x10d, ///< Expected memref or scalar type tinytc_status_ir_expected_memref_or_group = 0x10e, ///< Expected a value of memref or group type - tinytc_status_ir_expected_matrix = 0x10f, ///< Expected a marix - tinytc_status_ir_expected_vector_or_matrix = 0x110, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x111, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x112, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x113, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x114, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x115, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x116, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x117, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x118, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x119, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x11a, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x11b, ///< Expected global address space - tinytc_status_ir_invalid_offset = 0x11c, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x11d, ///< Instruction does not support int type - tinytc_status_ir_boolean_unsupported = 0x11e, ///< Instruction does not support boolean type - tinytc_status_ir_complex_unsupported = 0x11f, ///< Instruction does not support complex type + tinytc_status_ir_expected_memref_order_0 = 0x10f, ///< Expected memref of order 0 + tinytc_status_ir_expected_memref_order_1 = 0x110, ///< Expected memref of order 1 + tinytc_status_ir_expected_memref_order_2 = 0x111, ///< Expected memref of order 2 + tinytc_status_ir_expected_memref_order_0_or_1 = 0x112, ///< Expected memref of order 0 or 1 + tinytc_status_ir_expected_memref_order_1_or_2 = 0x113, ///< Expected memref of order 1 or 2 + tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x114, ///< Expected memref of order 0, 1 or 2 + tinytc_status_ir_unexpected_yield = 0x115, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x116, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x117, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x118, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x119, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x11a, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x11b, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x11c, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x11d, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x11e, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x11f, ///< Expected global address space + tinytc_status_ir_invalid_offset = 0x120, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x121, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x122, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x123, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x120, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x121, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x122, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x123, ///< Mismatch of init values and returned values - tinytc_status_ir_invalid_matrix_use = 0x124, ///< Invalid matrix use - tinytc_status_ir_unsupported_coopmatrix_shape = 0x125, ///< Unsupported coopmatrix shape - tinytc_status_ir_incompatible_scalar_types = 0x126, ///< Incompatible scalar types - tinytc_status_ir_constant_mismatch = 0x127, ///< Constant mismatch - tinytc_status_ir_insufficient_alignment = 0x128, ///< Insufficient alignment - tinytc_status_ir_must_have_yield = 0x129, ///< Must have yield instruction + 0x124, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x125, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x126, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x127, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x128, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x129, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x12a, ///< Incompatible scalar types + tinytc_status_ir_constant_mismatch = 0x12b, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x12c, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x12d, ///< Must have yield instruction tinytc_status_ir_yield_in_else_branch_missing = - 0x130, ///< Must have yield instruction in else branch - tinytc_status_ir_from_to_mismatch = 0x131, ///< size(from) != size(to) in foreach + 0x12e, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x12f, ///< size(from) != size(to) in foreach tinytc_status_ir_operand_type_must_match_return_type = - 0x132, /// Operand type must match return type + 0x130, /// Operand type must match return type // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 48bed60f..f7c53aa0 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -69,8 +69,12 @@ enum class status { ir_expected_memref = tinytc_status_ir_expected_memref, ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, - ir_expected_matrix = tinytc_status_ir_expected_matrix, - ir_expected_vector_or_matrix = tinytc_status_ir_expected_vector_or_matrix, + ir_expected_memref_order_0 = tinytc_status_ir_expected_memref_order_0, + ir_expected_memref_order_1 = tinytc_status_ir_expected_memref_order_1, + ir_expected_memref_order_2 = tinytc_status_ir_expected_memref_order_2, + ir_expected_memref_order_0_or_1 = tinytc_status_ir_expected_memref_order_0_or_1, + ir_expected_memref_order_1_or_2 = tinytc_status_ir_expected_memref_order_1_or_2, + ir_expected_memref_order_0_1_or_2 = tinytc_status_ir_expected_memref_order_0_1_or_2, ir_unexpected_yield = tinytc_status_ir_unexpected_yield, ir_yield_mismatch = tinytc_status_ir_yield_mismatch, ir_subview_mismatch = tinytc_status_ir_subview_mismatch, diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp index 604ff9e4..ca49fdd8 100644 --- a/src/compiler_context.cpp +++ b/src/compiler_context.cpp @@ -4,17 +4,16 @@ #include "compiler_context.hpp" #include "compiler_context_cache.hpp" #include "error.hpp" +#include "node/value_node.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" #include -#include +#include namespace tinytc { -void default_error_reporter(char const *what, const tinytc_location_t *, void *) { - std::cerr << what << std::endl; -} +void default_error_reporter(char const *, const tinytc_location_t *, void *) {} } // namespace tinytc using namespace tinytc; @@ -43,10 +42,23 @@ auto tinytc_compiler_context::source_text(std::int32_t source_id) return {"", 0}; } void tinytc_compiler_context::report_error(location const &l, char const *what) { + report_error(l, {}, what); +} + +void tinytc_compiler_context::report_error(tinytc_location const &l, + array_view const &ref_values, + char const *what) { auto [name, name_size] = source_name(l.begin.source_id); auto [text, text_size] = source_text(l.begin.source_id); auto err = report_error_with_context(text, text_size, name, l, what); reporter_(err.c_str(), &l, user_data_); + for (auto &ref_value : ref_values) { + if (ref_value) { + auto err = report_error_with_context(text, text_size, name, ref_value->loc(), + "value defined here"); + reporter_(err.c_str(), &ref_value->loc(), user_data_); + } + } } auto tinytc_compiler_context::opt_flag(tinytc_optflag_t flag) const -> bool { diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp index 1bb2a568..ba1b5f74 100644 --- a/src/compiler_context.hpp +++ b/src/compiler_context.hpp @@ -5,6 +5,7 @@ #define COMPILER_CONTEXT_20240924_HPP #include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -50,6 +51,8 @@ struct tinytc_compiler_context : tinytc::reference_counted { auto source_name(std::int32_t source_id) -> std::pair; auto source_text(std::int32_t source_id) -> std::pair; void report_error(tinytc_location const &l, char const *what); + void report_error(tinytc_location const &l, + tinytc::array_view const &ref_values, char const *what); auto opt_flag(tinytc_optflag_t flag) const -> bool; inline void opt_flag(tinytc_optflag_t flag, std::int32_t state) { opt_flags_[flag] = state; } diff --git a/src/error.cpp b/src/error.cpp index dc630ca2..13c63e4b 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -5,6 +5,7 @@ #include "location.hpp" #include "tinytc/tinytc.h" +#include #include #include #include @@ -12,12 +13,27 @@ namespace tinytc { compilation_error::compilation_error(location const &loc, status code, std::string extra_info) - : loc_(loc), code_(code), extra_info_(std::move(extra_info)) {} + : loc_(loc), ref_values_{}, num_ref_values_{0}, code_(code), + extra_info_(std::move(extra_info)) {} + +compilation_error::compilation_error(location const &loc, array_view ref_values, + status code, std::string extra_info) + : loc_(loc), code_(code), extra_info_(std::move(extra_info)) { + num_ref_values_ = std::min(error_max_ref, ref_values.size()); + for (std::size_t i = 0; i < num_ref_values_; ++i) { + ref_values_[i] = ref_values[i]; + } +} auto report_error_with_context(char const *code, std::size_t code_len, std::string const &file_name, location const &l, std::string const &what) -> std::string { constexpr int additional_context_lines = 2; + auto oerr = std::ostringstream{}; + oerr << file_name << ":"; + print_range(oerr, l.begin, l.end); + oerr << ": " << what << std::endl; + int cur_line = 1; const char *begin = code; const char *limit = begin + code_len; @@ -27,7 +43,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++begin; } - auto oerr = std::ostringstream{}; + char const *end = begin; int start_col = -1; while (cur_line <= l.end.line && *end != '\0' && end <= limit) { @@ -63,7 +79,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri col_begin = l.begin.column > 1 ? l.begin.column - 1 : 0; num_col = l.end.column > l.begin.column ? l.end.column - l.begin.column : 1; } - oerr << std::string(col_begin, ' ') << std::string(num_col, '~') << std::endl; + oerr << std::string(col_begin, ' ') << std::string(num_col, '~'); } ++cur_line; start_col = -1; @@ -71,9 +87,6 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++end; } - oerr << file_name << ":"; - print_range(oerr, l.begin, l.end); - oerr << ": " << what; return std::move(oerr).str(); } @@ -150,10 +163,18 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Expected memref type or scalar type"; case tinytc_status_ir_expected_memref_or_group: return "Expected memref or group operand"; - case tinytc_status_ir_expected_matrix: - return "Expected matrix input"; - case tinytc_status_ir_expected_vector_or_matrix: - return "Expected vector or matrix input"; + case tinytc_status_ir_expected_memref_order_0: + return "Expected memref of order 0 (scalar)"; + case tinytc_status_ir_expected_memref_order_1: + return "Expected memref of order 1 (vector)"; + case tinytc_status_ir_expected_memref_order_2: + return "Expected memref of order 2 (matrix)"; + case tinytc_status_ir_expected_memref_order_0_or_1: + return "Expected memref of order 0 or 1 (scalar or vector)"; + case tinytc_status_ir_expected_memref_order_1_or_2: + return "Expected memref of order 1 or 2 (vector or matrix)"; + case tinytc_status_ir_expected_memref_order_0_1_or_2: + return "Expected memref of order 0, 1, or 2 (scalar, vector, or matrix)"; case tinytc_status_ir_unexpected_yield: return "Yield encountered in non-yielding region"; case tinytc_status_ir_yield_mismatch: @@ -200,7 +221,7 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Unsupported coopmatrix shape for the combination of scalar type, matrix use, and " "target architecture"; case tinytc_status_ir_incompatible_scalar_types: - return "Scalar types violate compatibility rules"; + return "Scalar type violates compatibility rules"; case tinytc_status_ir_constant_mismatch: return "Type of constant does not match type of returned value"; case tinytc_status_ir_insufficient_alignment: diff --git a/src/error.hpp b/src/error.hpp index 6784d180..4f122a13 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -9,6 +9,7 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include #include @@ -24,12 +25,19 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri //! Compilation error class compilation_error : public std::exception { public: + constexpr static std::size_t error_max_ref = 4; + //! ctor; taking location, status code, and expanatory string compilation_error(location const &loc, status code, std::string extra_info = {}); + compilation_error(location const &loc, array_view ref_values, status code, + std::string extra_info = {}); //! Get status code inline auto code() const noexcept { return code_; } //! Get location inline auto loc() const noexcept -> location const & { return loc_; } + inline auto ref_values() const noexcept -> array_view { + return array_view(ref_values_.data(), num_ref_values_); + } //! Get explanatory string inline char const *what() const noexcept override { return error_string(code_); } //! Get additional information @@ -37,6 +45,8 @@ class compilation_error : public std::exception { private: location loc_; + std::array ref_values_; + std::size_t num_ref_values_; status code_; std::string extra_info_; }; @@ -66,7 +76,7 @@ auto exception_to_status_code(F &&f, auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); } else { - context->report_error(e.loc(), e.what()); + context->report_error(e.loc(), e.ref_values(), e.what()); } } return static_cast(e.code()); diff --git a/src/inst.cpp b/src/inst.cpp index 46c22b98..477107d9 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -156,12 +156,13 @@ tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic } tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_unary_t op, - tinytc_value_t a, const tinytc_location_t *loc) { + tinytc_value_t a, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), a, + *instr = std::make_unique(enum_cast(op), a, ty, get_optional(loc)) .release(); }); diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index d81c8658..533523cd 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -115,10 +115,12 @@ blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinyt auto betat = get_scalar_type(loc(), op(op_beta)); if (compatible_type(alphat->ty(), At->element_ty()) != At->element_ty()) { - throw compilation_error(loc(), status::ir_incompatible_scalar_types); + throw compilation_error(loc(), {&op(op_alpha), &op(op_A)}, + status::ir_incompatible_scalar_types); } if (compatible_type(betat->ty(), Bt->element_ty()) != Bt->element_ty()) { - throw compilation_error(loc(), status::ir_incompatible_scalar_types); + throw compilation_error(loc(), {&op(op_beta), &op(op_B)}, + status::ir_incompatible_scalar_types); } } @@ -140,10 +142,12 @@ blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinyt const auto AB_ty = compatible_type(At->element_ty(), Bt->element_ty()); if (compatible_type(alphat->ty(), AB_ty) != AB_ty) { - throw compilation_error(loc(), status::ir_incompatible_scalar_types); + throw compilation_error(loc(), {&op(op_alpha), &op(op_A), &op(op_B)}, + status::ir_incompatible_scalar_types); } if (compatible_type(betat->ty(), Ct->element_ty()) != Ct->element_ty()) { - throw compilation_error(loc(), status::ir_incompatible_scalar_types); + throw compilation_error(loc(), {&op(op_beta), &op(op_C)}, + status::ir_incompatible_scalar_types); } } @@ -169,6 +173,10 @@ axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, t auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); + if (b->dim() < 0 || b->dim() > 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_1_or_2); + } + bool shape_equal = false; if (tA_ == transpose::T && a->dim() == 2 && b->dim() == 2) { shape_equal = a->shape()[1] == b->shape()[0] && a->shape()[0] == b->shape()[1]; @@ -177,11 +185,7 @@ axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, t } if (!shape_equal) { - throw compilation_error(loc(), status::ir_incompatible_shapes); - } - - if (b->dim() > 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); } } @@ -193,10 +197,10 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b loc(lc); if (a().ty() != ty) { - throw compilation_error(a().loc(), status::ir_operand_type_must_match_return_type); + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); } if (b().ty() != ty) { - throw compilation_error(b().loc(), status::ir_operand_type_must_match_return_type); + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); } if (isa(*ty)) { @@ -266,67 +270,71 @@ arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b } arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, - location const &lc) + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::arith_unary}, operation_(operation) { op(op_a, a0); loc(lc); - tinytc_data_type_t to_ty = [&]() -> tinytc_data_type_t { - if (isa(*a().ty())) { - if (operation_ != arithmetic_unary::not_) { - throw compilation_error(loc(), status::ir_boolean_unsupported); - } - return a().ty(); - } else if (isa(*a().ty())) { - if (operation_ != arithmetic_unary::neg) { - throw compilation_error(loc(), status::ir_coopmatrix_unsupported); - } - return a().ty(); - } else { - auto a_ty = get_scalar_type(loc(), a()); - tinytc_data_type_t to_ty = a_ty; - - bool inst_supports_int = true; - bool inst_supports_fp = true; - bool inst_supports_complex = true; - switch (operation_) { - case arithmetic_unary::abs: - case arithmetic_unary::neg: - break; - case arithmetic_unary::not_: - inst_supports_fp = false; - inst_supports_complex = false; - break; - case arithmetic_unary::conj: - case arithmetic_unary::im: - case arithmetic_unary::re: - inst_supports_int = false; - inst_supports_fp = false; - break; - } - if (!inst_supports_int && is_integer_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_int_unsupported); - } - if (!inst_supports_fp && is_floating_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - if (!inst_supports_complex && is_complex_type(a_ty->ty())) { - throw compilation_error(loc(), status::ir_complex_unsupported); - } - switch (operation_) { - case arithmetic_unary::abs: - case arithmetic_unary::im: - case arithmetic_unary::re: - to_ty = scalar_data_type::get(a_ty->context(), element_type(a_ty->ty())); - break; - default: - break; - } - return to_ty; + result(0) = value_node{ty, this, lc}; + + // Check if inst is supported for combination of a type and result type + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::im: + case arithmetic_unary::re: { + auto a_ty = get_scalar_type(a().loc(), a()); + auto r_ty = get_scalar_type(loc(), result(0)); + if (r_ty->ty() != element_type(a_ty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_incompatible_scalar_types); + } + break; + } + default: + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); } - }(); + break; + } - result(0) = value_node{to_ty, this, lc}; + if (isa(*ty)) { + if (operation_ != arithmetic_unary::not_) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + } else if (isa(*ty)) { + if (operation_ != arithmetic_unary::neg) { + throw compilation_error(loc(), status::ir_coopmatrix_unsupported); + } + } else { + auto a_ty = get_scalar_type(loc(), a()); + + bool inst_supports_int = true; + bool inst_supports_fp = true; + bool inst_supports_complex = true; + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::neg: + break; + case arithmetic_unary::not_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic_unary::conj: + case arithmetic_unary::im: + case arithmetic_unary::re: + inst_supports_int = false; + inst_supports_fp = false; + break; + } + if (!inst_supports_int && is_integer_type(a_ty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_int_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(a_ty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + } } cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const &lc) @@ -456,7 +464,7 @@ cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, checked_ throw compilation_error(loc(), status::ir_scalar_mismatch); } if (ot->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_matrix); + throw compilation_error(loc(), status::ir_expected_memref_order_2); } check_index_ty(lc, pos0().ty()); @@ -556,7 +564,7 @@ cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, throw compilation_error(loc(), status::ir_scalar_mismatch); } if (ot->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_matrix); + throw compilation_error(loc(), status::ir_expected_memref_order_2); } check_index_ty(lc, pos0().ty()); @@ -765,9 +773,14 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_v auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); - if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_matrix, - "gemm only supported for memref of order 2 (matrices)"); + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_2); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); } auto ak = tA_ == transpose::T ? 0 : 1; @@ -781,7 +794,8 @@ gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_v oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); } } @@ -794,9 +808,14 @@ gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tin auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); - if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemv only supports matrix-vector products"); + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 1) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1); } auto ak = tA_ == transpose::T ? 0 : 1; @@ -808,7 +827,8 @@ gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tin oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; oss << "b=" << b->shape(0) << ", "; oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); } } @@ -820,9 +840,14 @@ ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "ger requires two vectors as input and one matrix as output"); + if (a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); } auto M = c->shape(0); @@ -833,7 +858,8 @@ ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, oss << "a=" << a->shape(0) << ", "; oss << "b=" << b->shape(0) << ", "; oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); } } @@ -846,9 +872,14 @@ hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_va auto b = get_memref_type(loc(), B()); auto c = get_memref_type(loc(), C()); - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "hadamard requires two vectors as input and one vector as output"); + if (a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 1) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1); } auto M = c->shape(0); @@ -858,7 +889,8 @@ hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_va oss << "a=" << a->shape(0) << ", "; oss << "b=" << b->shape(0) << ", "; oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); } } @@ -985,14 +1017,19 @@ sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinyt auto a = get_memref_type(loc(), A()); auto b = get_memref_type(loc(), B()); - bool const size_ok = (a->dim() == 2 && b->dim() == 1) || (a->dim() == 1 && b->dim() == 0); - if (!size_ok) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); + if (b->dim() == 1 && a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() == 0 && a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 0 && b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_or_1); } if (a->dim() == 2) { if (a->shape(tA_ == transpose::T ? 1 : 0) != b->shape(0)) { - throw compilation_error(loc(), status::ir_incompatible_shapes); + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); } } } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 4e1700d0..746aabdc 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -384,7 +384,8 @@ class arith_unary_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith_unary; } enum op_number { op_a = 0 }; - arith_unary_inst(arithmetic_unary op, tinytc_value_t a, location const &lc = {}); + arith_unary_inst(arithmetic_unary op, tinytc_value_t a, tinytc_data_type_t ty, + location const &lc = {}); inline arithmetic_unary operation() const { return operation_; } inline auto a() -> tinytc_value & { return op(op_a); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index d0c377ca..e075f877 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -26,6 +26,9 @@ using int_or_val = std::variant; using unique_ptr_to_if_inst = std::unique_ptr; + + using identifier = std::variant; + using identifier_and_location = std::pair; } } @@ -44,6 +47,7 @@ #include #include #include + #include #include #include @@ -55,7 +59,16 @@ loc.end = loc2.end; throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); } - }; + } + + void report_error(compiler_context const& cctx, compilation_error const& e) { + if (e.extra_info().size() > 0) { + auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); + cctx.get()->report_error(e.loc(), e.ref_values(), what.c_str()); + } else { + cctx.get()->report_error(e.loc(), e.ref_values(), e.what()); + } + } } // namespace tinytc } @@ -142,7 +155,7 @@ WORK_GROUP "work_group" YIELD "yield" ; -%token > LOCAL_IDENTIFIER +%token LOCAL_IDENTIFIER %token GLOBAL_IDENTIFIER %token BOOLEAN_CONSTANT %token INTEGER_CONSTANT @@ -159,8 +172,8 @@ %nterm prog %nterm > func_list %nterm func -%nterm >,std::vector>> parameters -%nterm ,tinytc_data_type_t>> parameter +%nterm ,std::vector>> parameters +%nterm > parameter %nterm >> attributes %nterm > attribute %nterm data_type @@ -190,9 +203,9 @@ %nterm ger_inst %nterm transpose %nterm for_inst -%nterm >, std::vector, std::vector>> optional_loop_carried_values -%nterm >, std::vector>> init_value_list -%nterm , tinytc_value_t>> init_value +%nterm , std::vector, std::vector>> optional_loop_carried_values +%nterm , std::vector>> init_value_list +%nterm > init_value %nterm optional_step %nterm foreach_inst %nterm hadamard_inst @@ -204,7 +217,7 @@ %nterm yield_inst %nterm for_loop_var_type %nterm var_definition -%nterm >> identifier_list +%nterm > identifier_list %nterm valued_inst %nterm alloca_inst %nterm arith_inst @@ -269,13 +282,13 @@ func: ctx.push_scope(); auto name_it = $parameters.first.begin(); for (auto &p : func_node->params()) { - ctx.val(*name_it, p, @parameters); + ctx.val(name_it->first, p, name_it->second); ++name_it; } ctx.push_region(&func_node->body()); $$ = func{func_node.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } }[prototype] region { @@ -301,7 +314,7 @@ parameters: parameter: LOCAL_IDENTIFIER COLON data_type { - $$ = std::make_pair($LOCAL_IDENTIFIER, $data_type); + $$ = std::make_pair(std::make_pair($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER), $data_type); } ; @@ -354,7 +367,7 @@ coopmatrix_type: try { $$ = get_coopmatrix($scalar_type, $rows, $cols, $MATRIX_USE, @coopmatrix_type); } catch (compilation_error const& e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -365,7 +378,7 @@ memref_type: try { $$ = get_memref($scalar_type, $mode_list, {}, $optional_address_space, @memref_type); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -379,7 +392,7 @@ memref_type: $$ = get_memref($scalar_type, $mode_list, $optional_stride_list, $optional_address_space, @memref_type); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -474,7 +487,7 @@ axpby_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -506,7 +519,7 @@ barrier_inst: try { $$ = inst { std::make_unique(fence_flags, @barrier_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -532,7 +545,7 @@ gemm_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -547,7 +560,7 @@ gemv_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -567,7 +580,7 @@ ger_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -601,7 +614,7 @@ for_inst: ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } }[loop_header] region { @@ -656,7 +669,7 @@ foreach_inst: ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } }[loop_header] region { @@ -702,7 +715,7 @@ hadamard_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -717,7 +730,7 @@ sum_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -770,7 +783,7 @@ alloca_inst: std::make_unique(std::move($memref_type), @alloca_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -778,8 +791,6 @@ alloca_inst: arith_inst: ARITH ARITHMETIC var[a] COMMA var[b] COLON data_type[ty] { - check_type($a, $ty, @a, @ty); - check_type($b, $ty, @b, @ty); try { $$ = inst { std::make_unique($ARITHMETIC, std::move($a), std::move($b), std::move($ty), @@ -787,7 +798,7 @@ arith_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -795,15 +806,14 @@ arith_inst: arith_unary_inst: ARITH ARITHMETIC_UNARY var[a] COLON data_type[ty] { - check_type($a, $ty, @a, @ty); try { $$ = inst { - std::make_unique($ARITHMETIC_UNARY, std::move($a), + std::make_unique($ARITHMETIC_UNARY, std::move($a), std::move($ty), @arith_unary_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -816,7 +826,7 @@ cast_inst: try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -833,7 +843,7 @@ compare_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -847,7 +857,7 @@ constant_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -857,7 +867,7 @@ constant_inst: std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -867,7 +877,7 @@ constant_inst: std::make_unique($INTEGER_CONSTANT, $data_type, @constant_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -877,7 +887,7 @@ constant_inst: std::make_unique($BOOLEAN_CONSTANT, $data_type, @constant_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -894,7 +904,7 @@ cooperative_matrix_load_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -918,7 +928,7 @@ cooperative_matrix_mul_add_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -935,7 +945,7 @@ cooperative_matrix_scale_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -953,7 +963,7 @@ cooperative_matrix_store_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -986,7 +996,7 @@ expand_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } catch (std::exception const& e) { error(@expand_inst, e.what()); @@ -1023,7 +1033,7 @@ fuse_inst: std::make_unique(std::move($var), $from, $to, @fuse_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -1043,7 +1053,7 @@ load_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -1063,7 +1073,7 @@ store_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -1094,7 +1104,7 @@ if_inst: ctx.push_region(&inode->then()); $$ = std::move(inode); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } }[header] region { @@ -1139,7 +1149,7 @@ parallel_inst: ctx.push_region(&inode->body()); $$ = inst{inode.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } }[header] region { @@ -1158,7 +1168,7 @@ size_inst: try { $$ = inst { std::make_unique(std::move($var), $mode, @size_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -1216,7 +1226,7 @@ subview_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } catch (std::exception const& e) { error(@subview_inst, e.what()); @@ -1258,7 +1268,7 @@ work_group_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index aefc2265..f1b76e7a 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -29,7 +29,8 @@ auto inst_cloner::operator()(arith_inst &in) -> std::unique_ptr { in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(arith_unary_inst &in) -> std::unique_ptr { - return std::make_unique(in.operation(), subs(&in.a()), in.loc()); + return std::make_unique(in.operation(), subs(&in.a()), in.result(0).ty(), + in.loc()); } auto inst_cloner::operator()(barrier_inst &in) -> std::unique_ptr { return std::make_unique(in.fence_flags(), in.loc()); diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index e7d76071..953458a2 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -317,7 +317,7 @@ std::vector convert_to_opencl_pass::operator()(axpby_inst const &ins }); return {bb.get_product()}; } - throw compilation_error(inst.loc(), status::ir_expected_vector_or_matrix); + return {}; } std::vector convert_to_opencl_pass::operator()(barrier_inst const &b) { @@ -1516,8 +1516,6 @@ std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) inner_loop(bb); } }); - } else { - throw compilation_error(inst.loc(), status::ir_expected_vector_or_matrix); } return {bb.get_product()}; } diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 3a667d14..3dd7ff50 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -125,7 +125,7 @@ void dump_ir_pass::operator()(arith_unary_inst const &a) { *os_ << " = arith." << to_string(a.operation()) << " "; dump_val(a.a()); *os_ << " : "; - visit(*this, *a.a().ty()); + visit(*this, *a.result(0).ty()); } void dump_ir_pass::operator()(barrier_inst const &b) { diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir index bc9ee44a..a63d65f6 100644 --- a/test/codegen/scalar_arithmetic.ir +++ b/test/codegen/scalar_arithmetic.ir @@ -90,10 +90,10 @@ func @t6(%a: c32, %b: c32) { %2 = arith.mul %a, %b : c32 %3 = arith.div %a, %b : c32 %4 = arith.neg %a : c32 - %5 = arith.abs %a : c32 + %5 = arith.abs %a : f32 %6 = arith.conj %a : c32 - %7 = arith.im %a : c32 - %8 = arith.re %a : c32 + %7 = arith.im %a : f32 + %8 = arith.re %a : f32 ; CHECK: float2 x = a + b; ; CHECK-NEXT: float2 x1 = a - b; ; CHECK-NEXT: float2 x2 = a * b.x + (float2) (-a.y, a.x) * b.y; diff --git a/test/codegen/scalar_arithmetic_error.ir b/test/codegen/scalar_arithmetic_error.ir index e51c6a0c..6cef12e9 100644 --- a/test/codegen/scalar_arithmetic_error.ir +++ b/test/codegen/scalar_arithmetic_error.ir @@ -4,7 +4,7 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @t1(%a: f32, %b: f32) { %1 = arith.and %a, %b : f32 -; CHECK: %1 = arith.and %a, %b : f32 -; CHECK-NEXT: ~~~~~~~~~~~~~~~~~~~~~~ -; CHECK-NEXT::6.8-29: Floating point type unsupported by instruction +; CHECK: :6.8-29: Floating point type unsupported by instruction +; CHECK: %1 = arith.and %a, %b : f32 +; CHECK-NEXT: ~~~~~~~~~~~~~~~~~~~~~~ } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index a03c6d7a..4cecfdff 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -132,9 +132,9 @@ func @known_arith_complex() { %3 = arith.div %a, %b : c32 %4 = arith.neg %a : c32 %5 = arith.conj %a : c32 - %6 = arith.abs %a : c32 - %7 = arith.im %a : c32 - %8 = arith.re %a : c32 + %6 = arith.abs %a : f32 + %7 = arith.im %a : f32 + %8 = arith.re %a : f32 ; CHECK-LABEL: func @known_arith_complex({{.*}} ; CHECK: %0 = constant [0x1p+1,0x1.cp+2] -> c32 ; CHECK-NEXT: %1 = arith.add %a, %b : c32 @@ -149,9 +149,9 @@ func @known_arith_complex() { ; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] -> c32 ; CHECK-NEXT: %11 = arith.conj %a : c32 ; CHECK-NEXT: %12 = constant 0x1.cd82b4p+1 -> f32 -; CHECK-NEXT: %13 = arith.abs %a : c32 +; CHECK-NEXT: %13 = arith.abs %a : f32 ; CHECK-NEXT: %14 = constant 0x1p+1 -> f32 -; CHECK-NEXT: %15 = arith.im %a : c32 +; CHECK-NEXT: %15 = arith.im %a : f32 ; CHECK-NEXT: %16 = constant 0x1.8p+1 -> f32 -; CHECK-NEXT: %17 = arith.re %a : c32 +; CHECK-NEXT: %17 = arith.re %a : f32 } diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index f4bafe53..30f79255 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -32,11 +32,11 @@ func @tfloat(%a: f32) { } func @tcomplex(%a: c32) { - %0 = arith.abs %a : c32 + %0 = arith.abs %a : f32 %1 = arith.neg %a : c32 %2 = arith.conj %a : c32 - %3 = arith.im %a : c32 - %4 = arith.re %a : c32 + %3 = arith.im %a : f32 + %4 = arith.re %a : f32 ; CHECK: %[[#A2:]] = OpFMul %[[#C32]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#A2_0:]] = OpCompositeExtract %[[#F32]] %[[#A2]] 0 ; CHECK-NEXT: %[[#A2_1:]] = OpCompositeExtract %[[#F32]] %[[#A2]] 1 diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 6375df75..8ae9aa6c 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -92,6 +92,9 @@ int main(int argc, char **argv) { auto ctx = compiler_context{}; try { ctx = make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); ctx.set_optimization_level(opt_level); cmd::set_optflags(ctx, flags); info.set_core_features(core_features); diff --git a/tools/opt/main.cpp b/tools/opt/main.cpp index 543dfcc1..2c99ef3f 100644 --- a/tools/opt/main.cpp +++ b/tools/opt/main.cpp @@ -87,6 +87,9 @@ int main(int argc, char **argv) { auto ctx = compiler_context{}; try { ctx = make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); ctx.set_optimization_level(opt_level); cmd::set_optflags(ctx, flags); info.set_core_features(core_features); From aa95e598f76add23a4cc266141c72ce942439b09 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 15:45:00 +0100 Subject: [PATCH 119/297] Update cast and compare Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 3 ++- include/tinytc/tinytc.hpp | 6 +++-- src/codegen_tools.cpp | 13 ++++++---- src/error.cpp | 5 ++-- src/error.hpp | 2 +- src/inst.cpp | 8 +++--- src/node/inst_node.cpp | 40 +++++++++++++++-------------- src/node/inst_node.hpp | 3 ++- src/parser/parser_impl.yy | 9 +++---- src/pass/clone.cpp | 3 ++- src/pass/dump_ir.cpp | 6 ++--- src/pass/lower_foreach.cpp | 13 +++++----- src/pass/lower_linalg.cpp | 12 ++++++--- src/recipe/tall_and_skinny.cpp | 5 ++-- test/codegen/cast.ir | 14 +++++----- test/codegen/coopmatrix_basic.ir | 2 +- test/codegen/if.ir | 12 ++++----- test/codegen/scalar_arithmetic.ir | 26 +++++++++---------- test/opt/check-ir/cast_forbidden.ir | 4 +-- test/opt/constant-propagation.ir | 28 ++++++++++---------- test/opt/insert-barrier.ir | 8 +++--- test/spv/cast.ir | 18 ++++++------- test/spv/compare.ir | 28 ++++++++++---------- test/spv/if.ir | 2 +- 24 files changed, 141 insertions(+), 129 deletions(-) diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index cf2f82a9..3539bdb9 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -268,13 +268,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tiny * @param cond [in] compare type * @param a [in] left-hand operand * @param b [in] right-hand operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, tinytc_value_t a, - tinytc_value_t b, + tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index c796ae03..d59c6e26 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -926,14 +926,16 @@ inline inst make_cast(value a, data_type to_ty, location const &loc = {}) { * @param cond Condition type * @param a First operand * @param b Second operand + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_cmp(cmp_condition cond, value a, value b, location const &loc = {}) { +inline inst make_cmp(cmp_condition cond, value a, value b, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( - tinytc_cmp_inst_create(&instr, static_cast(cond), a, b, &loc), loc); + tinytc_cmp_inst_create(&instr, static_cast(cond), a, b, ty, &loc), + loc); return inst(instr); } diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index c68c4e5b..310d3a8b 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -490,6 +490,7 @@ void write_matrix_block(block_builder &bb, block_accessor const &block, void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, value sg_id, sgs_loop_body_builder_new const &body) { auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_data_type::get(ity->context()); auto c_sgs = bb.add(make_constant(sgs, ity)); auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, ity)); auto c0 = bb.add(make_constant(0, ity)); @@ -501,7 +502,8 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, c_sgs, ity)); auto sg_id_cast = instant_constant_fold_add(bb, make_cast(sg_id, ity)); - auto is_blocks_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, blocks, c0)); + auto is_blocks_gt_0 = + instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, blocks, c0, bool_ty)); bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_cast, ity)); @@ -511,10 +513,10 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); }); - auto condition0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); + auto condition0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0, bool_ty)); bb.if_condition(condition0, [&](region_builder &bb) { - auto condition1 = - instant_constant_fold_add(bb, make_cmp(cmp_condition::eq, sg_id_cast, c_tiles_1)); + auto condition1 = instant_constant_fold_add( + bb, make_cmp(cmp_condition::eq, sg_id_cast, c_tiles_1, bool_ty)); bb.if_condition(condition1, [&](region_builder &bb) { auto block = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks, c_sgs, ity)); @@ -527,6 +529,7 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc int num_tiles, value sg_id, uniform_loop_body_builder_new const &body) { auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_data_type::get(ity->context()); auto c0 = bb.add(make_constant(0, ity)); auto c1 = bb.add(make_constant(1, ity)); auto c_tiles = bb.add(make_constant(num_tiles, ity)); @@ -555,7 +558,7 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc // The following if makes it easy to eliminate the remainder handler in optimization if rem // == 0 is known at compile time. Without the if, we would need to prove that block_start_1 // is non-negative to eliminate the for-loop. - auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0)); + auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0, bool_ty)); bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { auto block_start_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, sg_id_cast, ity)); diff --git a/src/error.cpp b/src/error.cpp index 13c63e4b..c11c6412 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -16,8 +16,9 @@ compilation_error::compilation_error(location const &loc, status code, std::stri : loc_(loc), ref_values_{}, num_ref_values_{0}, code_(code), extra_info_(std::move(extra_info)) {} -compilation_error::compilation_error(location const &loc, array_view ref_values, - status code, std::string extra_info) +compilation_error::compilation_error(location const &loc, + array_view ref_values, status code, + std::string extra_info) : loc_(loc), code_(code), extra_info_(std::move(extra_info)) { num_ref_values_ = std::min(error_max_ref, ref_values.size()); for (std::size_t i = 0; i < num_ref_values_; ++i) { diff --git a/src/error.hpp b/src/error.hpp index 4f122a13..defdd385 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -29,7 +29,7 @@ class compilation_error : public std::exception { //! ctor; taking location, status code, and expanatory string compilation_error(location const &loc, status code, std::string extra_info = {}); - compilation_error(location const &loc, array_view ref_values, status code, + compilation_error(location const &loc, array_view ref_values, status code, std::string extra_info = {}); //! Get status code inline auto code() const noexcept { return code_; } diff --git a/src/inst.cpp b/src/inst.cpp index 477107d9..86ba5a2f 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -178,15 +178,15 @@ tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, } tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, - tinytc_value_t a, tinytc_value_t b, + tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(enum_cast(cond), a, b, get_optional(loc)) - .release(); + *instr = std::make_unique(enum_cast(cond), a, b, ty, + get_optional(loc)) + .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 533523cd..0b565c38 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -81,7 +81,7 @@ namespace tinytc { scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { auto m = dyn_cast(v.ty()); if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_scalar); + throw compilation_error(loc, {&v}, status::ir_expected_scalar); } return m; } @@ -89,7 +89,7 @@ scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { auto m = dyn_cast(v.ty()); if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_memref); + throw compilation_error(loc, {&v}, status::ir_expected_memref); } return m; } @@ -342,47 +342,50 @@ cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const op(op_a, a0); loc(lc); - auto const check_scalar_casting_rules = [](scalar_type a_ty, scalar_type r_ty, - location const &lc) { + auto const check_scalar_casting_rules = [&](scalar_type a_ty, scalar_type r_ty) { if (is_complex_type(a_ty) && !is_complex_type(r_ty)) { - throw compilation_error(lc, status::ir_forbidden_cast); + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); } }; - if (auto ct = dyn_cast(a().ty()); ct) { - auto rt = dyn_cast(to_ty); - if (!rt) { - throw compilation_error(loc(), status::ir_expected_coopmatrix); + if (auto rt = dyn_cast(to_ty); rt) { + auto ct = dyn_cast(a().ty()); + if (!ct) { + throw compilation_error(loc(), {&a()}, status::ir_expected_coopmatrix); } if (ct->rows() != rt->rows() || ct->cols() != rt->cols() || ct->use() != rt->use()) { - throw compilation_error(lc, status::ir_forbidden_cast); + throw compilation_error(lc, {&a()}, status::ir_forbidden_cast); } - check_scalar_casting_rules(ct->component_ty(), rt->component_ty(), loc()); + check_scalar_casting_rules(ct->component_ty(), rt->component_ty()); } else { - auto rt = dyn_cast(to_ty); - if (rt == nullptr) { + auto to_ty_scalar = dyn_cast(to_ty); + if (to_ty_scalar == nullptr) { throw compilation_error(lc, status::ir_expected_scalar); } auto at = get_scalar_type(loc(), a()); - check_scalar_casting_rules(at->ty(), rt->ty(), loc()); + check_scalar_casting_rules(at->ty(), to_ty_scalar->ty()); } result(0) = value_node{to_ty, this, loc()}; } compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, - location const &lc) + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::compare}, cond_(cond) { op(op_a, a0); op(op_b, b0); loc(lc); + if (!isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_boolean); + } + auto at = get_scalar_type(loc(), a()); auto bt = get_scalar_type(loc(), b()); if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), {&a(), &b()}, status::ir_scalar_mismatch); } bool inst_supports_complex = true; @@ -398,11 +401,10 @@ compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t break; } if (!inst_supports_complex && is_complex_type(at->ty())) { - throw compilation_error(loc(), status::ir_complex_unsupported); + throw compilation_error(loc(), {&a(), &b()}, status::ir_complex_unsupported); } - auto result_ty = boolean_data_type::get(at->context()); - result(0) = value_node{result_ty, this, lc}; + result(0) = value_node{ty, this, lc}; } constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 746aabdc..2c4417ec 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -427,7 +427,8 @@ class compare_inst : public standard_inst<2, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::compare; } enum op_number { op_a = 0, op_b = 1 }; - compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, location const &lc = {}); + compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + location const &lc = {}); inline cmp_condition cond() const { return cond_; } inline auto a() -> tinytc_value & { return op(op_a); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index e075f877..279141e8 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -821,8 +821,7 @@ arith_unary_inst: cast_inst: - CAST var[a] COLON data_type[from] ARROW data_type[to] { - check_type($a, $from, @a, @from); + CAST var[a] COLON data_type[to] { try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; } catch (compilation_error const &e) { @@ -833,13 +832,11 @@ cast_inst: ; compare_inst: - CMP CMP_CONDITION var[a] COMMA var[b] COLON scalar_type[ty] { - check_type($a, $ty, @a, @ty); - check_type($b, $ty, @b, @ty); + CMP CMP_CONDITION var[a] COMMA var[b] COLON boolean_type { try { $$ = inst { std::make_unique($CMP_CONDITION, std::move($a), std::move($b), - @compare_inst) + std::move($boolean_type), @compare_inst) .release() }; } catch (compilation_error const &e) { diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index f1b76e7a..8a542118 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -39,7 +39,8 @@ auto inst_cloner::operator()(cast_inst &in) -> std::unique_ptr { return std::make_unique(subs(&in.a()), in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(compare_inst &in) -> std::unique_ptr { - return std::make_unique(in.cond(), subs(&in.a()), subs(&in.b()), in.loc()); + return std::make_unique(in.cond(), subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(constant_inst &in) -> std::unique_ptr { return std::make_unique(in.value(), in.result(0).ty(), in.loc()); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 3dd7ff50..51858ccb 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -143,9 +143,7 @@ void dump_ir_pass::operator()(cast_inst const &c) { *os_ << " = cast "; dump_val(c.a()); *os_ << " : "; - visit(*this, *c.a().ty()); - *os_ << " -> "; - visit(*this, *c.result()->ty()); + visit(*this, *c.result(0).ty()); } void dump_ir_pass::operator()(compare_inst const &a) { @@ -155,7 +153,7 @@ void dump_ir_pass::operator()(compare_inst const &a) { *os_ << ", "; dump_val(a.b()); *os_ << " : "; - visit(*this, *a.a().ty()); + visit(*this, *a.result(0).ty()); } void dump_ir_pass::operator()(constant_inst const &c) { diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp index 979c3140..7d14b0fd 100644 --- a/src/pass/lower_foreach.cpp +++ b/src/pass/lower_foreach.cpp @@ -19,17 +19,18 @@ void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, F &&make_body, location const &loc) { auto ity = from->ty(); auto ctx = compiler_context{sg_id->context(), true}; - auto sg_lid_i32 = bb.add(make_subgroup_local_id(ctx)); - auto sg_lid = bb.add(make_cast(sg_lid_i32, ity)); + auto bool_ty = get_boolean(ctx); + auto sg_lid_i32 = bb.add(make_subgroup_local_id(ctx, loc)); + auto sg_lid = bb.add(make_cast(sg_lid_i32, ity, loc)); auto size = bb.add(make_arith(arithmetic::sub, to, from, ity, loc)); - auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity)); + auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity, loc)); tile_loop_by_sgs_new( bb, size, sgs, num_tiles, sg_id, [&](region_builder &bb, value block, bool is_remainder, value trip_count) { - auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset, ity)); + auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset, ity, loc)); if (is_remainder) { - auto cond = bb.add(make_cmp(cmp_condition::lt, sg_lid, trip_count)); - bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }); + auto cond = bb.add(make_cmp(cmp_condition::lt, sg_lid, trip_count, bool_ty, loc)); + bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }, loc); } else { make_body(bb, loop_var0); } diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 5e23c06b..5adf7df5 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -36,6 +36,7 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi std::int64_t n_block_size, bool n_check, data_type a_ty, data_type b_ty, data_type c_ty, location const &loc) { auto ctx = m_block->context(); + auto bool_ty = boolean_data_type::get(ctx); auto index_ty = scalar_data_type::get(ctx, scalar_type::index); const auto check_a = m_check ? checked_flag::rows : checked_flag::none; @@ -98,7 +99,8 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi auto K0 = instant_constant_fold_add( bb, make_arith(arithmetic::mul, tmp, c_k_block_size, index_ty, loc)); c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc); - auto needs_remainder = instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, loc)); + auto needs_remainder = + instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, bool_ty, loc)); auto r = get_bool_constant(needs_remainder); if (r) { if (*r != 0) { @@ -184,6 +186,7 @@ auto linalg_generator::get_memref_type(value_node const &v) const -> const memre void linalg_generator::operator()(axpby_inst &in) { auto ctx = compiler_context{in.alpha().context(), true}; + auto bool_ty = get_boolean(ctx); auto index_ty = get_scalar(ctx, scalar_type::index); auto bt = get_memref_type(in.B()); @@ -196,8 +199,8 @@ void linalg_generator::operator()(axpby_inst &in) { auto sg_lid = bb.add(make_subgroup_local_id(ctx, in.loc())); auto i32_ty = get_scalar(ctx, scalar_type::i32); auto c0 = bb.add(make_constant(0, i32_ty)); - auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0)); - auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0)); + auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0, bool_ty, in.loc())); + auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0, bool_ty, in.loc())); auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1, cond0->ty())); bb.if_condition(cond, [&](region_builder &bb) { auto a = bb.add(make_load(&in.A(), {}, in.loc())); @@ -373,6 +376,7 @@ void linalg_generator::operator()(sum_inst &in) { auto bb = region_builder{body}; auto ctx = compiler_context{in.alpha().context(), true}; + auto bool_ty = get_boolean(ctx); auto i32_ty = get_scalar(ctx, scalar_type::i32); auto index_ty = get_scalar(ctx, scalar_type::index); @@ -386,7 +390,7 @@ void linalg_generator::operator()(sum_inst &in) { auto from_index = bb.add(make_cast(from1, index_ty, in.loc())); auto c_zero = bb.add(make_constant_zero(i32_ty, in.loc())); - auto is_from_0 = bb.add(make_cmp(cmp_condition::eq, from1, c_zero, in.loc())); + auto is_from_0 = bb.add(make_cmp(cmp_condition::eq, from1, c_zero, bool_ty, in.loc())); auto c_trip_count = instant_constant_fold_add(bb, make_size(&in.A(), 0, in.loc())); auto c_step = bb.add(make_constant( diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 51445277..c4879507 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -92,6 +92,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return exception_to_status_code( [&] { auto const ty_ = get_scalar(ctx_, enum_cast(ty)); + auto const bool_ty = get_boolean(ctx_); auto const index_ty = get_scalar(ctx_, scalar_type::index); auto const shapes = @@ -144,8 +145,8 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto M_val = bb.add(make_size(C, 0, my_loc())); auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, m.get_type(), my_loc())); - auto cond = - bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, my_loc())); + auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, + bool_ty, my_loc())); bb.ifelse( cond, [&](region_builder &bb) { dynamic_gemm(bb, M_val_sub_m); }, [&](region_builder &bb) { static_gemm(bb); }, {}, my_loc()); diff --git a/test/codegen/cast.ir b/test/codegen/cast.ir index d417eb8f..5750e72d 100644 --- a/test/codegen/cast.ir +++ b/test/codegen/cast.ir @@ -4,43 +4,43 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @cast_ii() { %0 = constant 2 -> index - %1 = cast %0 : index -> i32 + %1 = cast %0 : i32 ; CHECK-LABEL: void cast_ii() { ; CHECK: int x1 = (int) x; } func @cast_ff() { %0 = constant 2.0 -> f32 - %1 = cast %0 : f32 -> f64 + %1 = cast %0 : f64 ; CHECK-LABEL: void cast_ff() { ; CHECK: double x1 = (double) x; } func @cast_cc() { %0 = constant [2.0, 0.0] -> c32 - %1 = cast %0 : c32 -> c64 + %1 = cast %0 : c64 ; CHECK-LABEL: void cast_cc() { ; CHECK: double2 x1 = convert_double2(x); } func @cast_if() { %0 = constant 2 -> i32 - %1 = cast %0 : i32 -> f32 + %1 = cast %0 : f32 ; CHECK-LABEL: void cast_if() { ; CHECK: float x1 = (float) x; } func @cast_fi() { %0 = constant 2.0 -> f32 - %1 = cast %0 : f32 -> i16 + %1 = cast %0 : i16 ; CHECK-LABEL: void cast_fi() { ; CHECK: short x1 = (short) x; } func @cast_ic() { %0 = constant 2 -> i8 - %1 = cast %0 : i8 -> c32 + %1 = cast %0 : c32 ; CHECK-LABEL: void cast_ic() { ; CHECK: float2 x1 = (float2) (x, 0); } func @cast_fc() { %0 = constant 2.0 -> f64 - %1 = cast %0 : f64 -> c32 + %1 = cast %0 : c32 ; CHECK-LABEL: void cast_fc() { ; CHECK: float2 x1 = (float2) (x, 0); } diff --git a/test/codegen/coopmatrix_basic.ir b/test/codegen/coopmatrix_basic.ir index f3d690cc..fdd18ad0 100644 --- a/test/codegen/coopmatrix_basic.ir +++ b/test/codegen/coopmatrix_basic.ir @@ -42,7 +42,7 @@ func @coopmatrix_neg() subgroup_size(16) { func @coopmatrix_cast() subgroup_size(16) { %0 = constant 1 -> coopmatrix - %1 = cast %0 : coopmatrix -> coopmatrix + %1 = cast %0 : coopmatrix ; CHECK-LABEL: void coopmatrix_cast({{.*}} ; CHECK: float2 x1[8]; ; CHECK-NEXT: x1[0] = (float2) (x[0], 0); diff --git a/test/codegen/if.ir b/test/codegen/if.ir index dcd0b1ad..177a1212 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -5,8 +5,8 @@ func @if0(%0: i32) { %c16 = constant 16 -> i32 %c0 = constant 0 -> i32 - %1 = cmp.lt %0, %c16 : i32 - %2 = cmp.ge %0, %c0 : i32 + %1 = cmp.lt %0, %c16 : bool + %2 = cmp.ge %0, %c0 : bool %3 = arith.and %1, %2 : bool if %3 { } else { @@ -20,7 +20,7 @@ func @if0(%0: i32) { func @if1(%0: i32) { %c16 = constant 16 -> i32 - %1 = cmp.lt %0, %c16 : i32 + %1 = cmp.lt %0, %c16 : bool if %1 { } else { } @@ -30,7 +30,7 @@ func @if1(%0: i32) { func @if2(%0: i32) { %c16 = constant 16 -> i32 - %1 = cmp.lt %0, %c16 : i32 + %1 = cmp.lt %0, %c16 : bool if %1 -> () { yield : } else { @@ -43,7 +43,7 @@ func @if2(%0: i32) { func @if3(%0: i32) { %c16 = constant 16 -> i32 - %1 = cmp.lt %0, %c16 : i32 + %1 = cmp.lt %0, %c16 : bool %x = if %1 -> (i32) { yield %0 : i32 } else { @@ -59,7 +59,7 @@ func @if3(%0: i32) { func @if4(%0: i32) { %c16 = constant 16 -> i32 - %1 = cmp.lt %0, %c16 : i32 + %1 = cmp.lt %0, %c16 : bool %x, %y = if %1 -> (i32, f32) { if %1 { } diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir index a63d65f6..da8faa6c 100644 --- a/test/codegen/scalar_arithmetic.ir +++ b/test/codegen/scalar_arithmetic.ir @@ -37,12 +37,12 @@ func @t1(%a: i32, %b: i32, %a1: bool, %b1: bool) { ; CHECK-NEXT: int x15 = abs(a); } func @t2(%a: i32, %b: i32) { - %1 = cmp.eq %a, %b : i32 - %2 = cmp.ne %a, %b : i32 - %3 = cmp.gt %a, %b : i32 - %4 = cmp.ge %a, %b : i32 - %5 = cmp.lt %a, %b : i32 - %6 = cmp.le %a, %b : i32 + %1 = cmp.eq %a, %b : bool + %2 = cmp.ne %a, %b : bool + %3 = cmp.gt %a, %b : bool + %4 = cmp.ge %a, %b : bool + %5 = cmp.lt %a, %b : bool + %6 = cmp.le %a, %b : bool ; CHECK: bool x = a == b; ; CHECK-NEXT: bool x1 = a != b; ; CHECK-NEXT: bool x2 = a > b; @@ -67,12 +67,12 @@ func @t3(%a: f32, %b: f32) { ; CHECK-NEXT: float x6 = fabs(a); } func @t4(%a: f32, %b: f32) { - %1 = cmp.eq %a, %b : f32 - %2 = cmp.ne %a, %b : f32 - %3 = cmp.gt %a, %b : f32 - %4 = cmp.ge %a, %b : f32 - %5 = cmp.lt %a, %b : f32 - %6 = cmp.le %a, %b : f32 + %1 = cmp.eq %a, %b : bool + %2 = cmp.ne %a, %b : bool + %3 = cmp.gt %a, %b : bool + %4 = cmp.ge %a, %b : bool + %5 = cmp.lt %a, %b : bool + %6 = cmp.le %a, %b : bool ; CHECK: bool x = a == b; ; CHECK-NEXT: bool x1 = a != b; ; CHECK-NEXT: bool x2 = a > b; @@ -81,7 +81,7 @@ func @t4(%a: f32, %b: f32) { ; CHECK-NEXT: bool x5 = a <= b; } func @t5(%a: i32) { - %b = cast %a : i32 -> index + %b = cast %a : index ; CHECK: long b = (long) a; } func @t6(%a: c32, %b: c32) { diff --git a/test/opt/check-ir/cast_forbidden.ir b/test/opt/check-ir/cast_forbidden.ir index d09c0fe4..31e785f9 100644 --- a/test/opt/check-ir/cast_forbidden.ir +++ b/test/opt/check-ir/cast_forbidden.ir @@ -4,6 +4,6 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @cast_cf() { %0 = constant [2.0, 1.0] -> c32 - %1 = cast %0 : c32 -> i32 -; CHECK: :7.8-27: Forbidden cast + %1 = cast %0 : i32 +; CHECK: :7.8-20: Forbidden cast } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 4cecfdff..39a40243 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -93,34 +93,34 @@ func @known_arith() { func @known_cast() { %c0 = constant 32768 -> i32 %c1 = constant [3.0, -2.0] -> c32 - %0 = cast %c0 : i32 -> i16 - %1 = cast %c0 : i32 -> f32 - %2 = cast %c0 : i32 -> c32 - %3 = cast %c0 : i32 -> c32 - %4 = cast %c1 : c32 -> c64 + %0 = cast %c0 : i16 + %1 = cast %c0 : f32 + %2 = cast %c0 : c32 + %3 = cast %c0 : c32 + %4 = cast %c1 : c64 ; CHECK-LABEL: func @known_cast({{.*}} ; CHECK: %0 = constant -32768 -> i16 -; CHECK-NEXT: %1 = cast %c0 : i32 -> i16 +; CHECK-NEXT: %1 = cast %c0 : i16 ; CHECK-NEXT: %2 = constant 0x1p+15 -> f32 -; CHECK-NEXT: %3 = cast %c0 : i32 -> f32 +; CHECK-NEXT: %3 = cast %c0 : f32 ; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %5 = cast %c0 : i32 -> c32 +; CHECK-NEXT: %5 = cast %c0 : c32 ; CHECK-NEXT: %6 = constant [0x1p+15,0x0p+0] -> c32 -; CHECK-NEXT: %7 = cast %c0 : i32 -> c32 +; CHECK-NEXT: %7 = cast %c0 : c32 ; CHECK-NEXT: %8 = constant [0x1.8p+1,-0x1p+1] -> c64 -; CHECK-NEXT: %9 = cast %c1 : c32 -> c64 +; CHECK-NEXT: %9 = cast %c1 : c64 } func @known_compare() { %0 = constant 1.0 -> f32 %1 = constant 2.0 -> f32 - %2 = cmp.eq %0, %0 : f32 - %3 = cmp.eq %0, %1 : f32 + %2 = cmp.eq %0, %0 : bool + %3 = cmp.eq %0, %1 : bool ; CHECK-LABEL: func @known_compare({{.*}} ; CHECK: %2 = constant true -> bool -; CHECK-NEXT: %3 = cmp.eq %0, %0 : f32 +; CHECK-NEXT: %3 = cmp.eq %0, %0 : bool ; CHECK-NEXT: %4 = constant false -> bool -; CHECK-NEXT: %5 = cmp.eq %0, %1 : f32 +; CHECK-NEXT: %5 = cmp.eq %0, %1 : bool } func @known_arith_complex() { diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 3e398bfc..c0cd1f58 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -81,7 +81,7 @@ func @war_alias(%a: f32, %b: f32, %A: memref, %C: memref) { func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { %c42 = constant 42.0 -> f32 - %0 = cmp.gt %a, %c42 : f32 + %0 = cmp.gt %a, %c42 : bool if %0 { axpby.n %a, %A, %b, %B axpby.n %a, %B, %b, %C @@ -103,7 +103,7 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { %c42 = constant 42.0 -> f32 - %0 = cmp.gt %a, %c42 : f32 + %0 = cmp.gt %a, %c42 : bool axpby.n %a, %B, %b, %A if %0 { axpby.n %a, %A, %b, %B @@ -161,7 +161,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %c4 = constant 4 -> index parallel { %0 = subgroup_id - %1 = cmp.eq %0, %c0 : i32 + %1 = cmp.eq %0, %c0 : bool if %1 { %2 = load %A[%c3,%c4] : memref store %2, %A[%c3,%c4] : memref @@ -171,7 +171,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { ; CHECK-NEXT: %0 = subgroup_id -; CHECK-NEXT: %1 = cmp.eq %0, %c0 : i32 +; CHECK-NEXT: %1 = cmp.eq %0, %c0 : bool ; CHECK-NEXT: if %1 { ; CHECK-NEXT: %2 = load %A[%c3,%c4] : memref ; CHECK-NEXT: store %2, %A[%c3,%c4] : memref diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 64fa56a7..0c0f5a42 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -13,10 +13,10 @@ ; CHECK: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] func @tint(%a: i64) { - %0 = cast %a : i64 -> i64 - %1 = cast %a : i64 -> i8 - %2 = cast %a : i64 -> f32 - %3 = cast %a : i64 -> c64 + %0 = cast %a : i64 + %1 = cast %a : i8 + %2 = cast %a : f32 + %3 = cast %a : c64 ; CHECK: %[[#]] = OpCopyObject %[[#I64]] %[[#]] ; CHECK-NEXT: %[[#]] = OpSConvert %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpConvertSToF %[[#F32]] %[[#]] @@ -25,9 +25,9 @@ func @tint(%a: i64) { } func @tfloat(%a: f32) { - %1 = cast %a : f32 -> i8 - %2 = cast %a : f32 -> f64 - %3 = cast %a : f32 -> c64 + %1 = cast %a : i8 + %2 = cast %a : f64 + %3 = cast %a : c64 ; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFConvert %[[#F64]] %[[#]] ; CHECK-NEXT: %[[#F32_TO_F64:]] = OpFConvert %[[#F64]] %[[#]] @@ -35,13 +35,13 @@ func @tfloat(%a: f32) { } func @tcomplex(%a: c32) { - %1 = cast %a : c32 -> c64 + %1 = cast %a : c64 ; CHECK: %[[#]] = OpFConvert %[[#C64]] %[[#]] } func @tfloatcoopmatrix() subgroup_size(16) { %0 = constant 1.0 -> coopmatrix - %2 = cast %0 : coopmatrix -> coopmatrix + %2 = cast %0 : coopmatrix ; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] diff --git a/test/spv/compare.ir b/test/spv/compare.ir index 7c28f699..1957c0cf 100644 --- a/test/spv/compare.ir +++ b/test/spv/compare.ir @@ -7,12 +7,12 @@ ; CHECK: %[[#BOOL2:]] = OpTypeVector %[[#BOOL]] 2 func @tint(%a: i64, %b: i64) { - %0 = cmp.eq %a, %b : i64 - %1 = cmp.ne %a, %b : i64 - %2 = cmp.gt %a, %b : i64 - %3 = cmp.ge %a, %b : i64 - %4 = cmp.lt %a, %b : i64 - %5 = cmp.le %a, %b : i64 + %0 = cmp.eq %a, %b : bool + %1 = cmp.ne %a, %b : bool + %2 = cmp.gt %a, %b : bool + %3 = cmp.ge %a, %b : bool + %4 = cmp.lt %a, %b : bool + %5 = cmp.le %a, %b : bool ; CHECK: %[[#]] = OpIEqual %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpINotEqual %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpSGreaterThan %[[#BOOL]] %[[#]] %[[#]] @@ -22,12 +22,12 @@ func @tint(%a: i64, %b: i64) { } func @tfloat(%a: f32, %b: f32) { - %0 = cmp.eq %a, %b : f32 - %1 = cmp.ne %a, %b : f32 - %2 = cmp.gt %a, %b : f32 - %3 = cmp.ge %a, %b : f32 - %4 = cmp.lt %a, %b : f32 - %5 = cmp.le %a, %b : f32 + %0 = cmp.eq %a, %b : bool + %1 = cmp.ne %a, %b : bool + %2 = cmp.gt %a, %b : bool + %3 = cmp.ge %a, %b : bool + %4 = cmp.lt %a, %b : bool + %5 = cmp.le %a, %b : bool ; CHECK: %[[#]] = OpFOrdEqual %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFUnordNotEqual %[[#BOOL]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFOrdGreaterThan %[[#BOOL]] %[[#]] %[[#]] @@ -37,8 +37,8 @@ func @tfloat(%a: f32, %b: f32) { } func @tcomplex(%a: c32, %b: c32) { - %0 = cmp.eq %a, %b : c32 - %1 = cmp.ne %a, %b : c32 + %0 = cmp.eq %a, %b : bool + %1 = cmp.ne %a, %b : bool ; CHECK: %[[#COMPONENTS_EQUAL:]] = OpFOrdEqual %[[#BOOL2]] %[[#]] %[[#]] ; CHECK-NEXT: %[[#]] = OpAll %[[#BOOL]] %[[#COMPONENTS_EQUAL]] ; CHECK-NEXT: %[[#COMPONENTS_NOT_EQUAL:]] = OpFUnordNotEqual %[[#BOOL2]] %[[#]] %[[#]] diff --git a/test/spv/if.ir b/test/spv/if.ir index 43a3be15..420bc70b 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -13,7 +13,7 @@ func @if0(%0: i32) { %c42 = constant 42 -> i32 - %1 = cmp.lt %0, %c42 : i32 + %1 = cmp.lt %0, %c42 : bool if %1 { %2 = arith.neg %0 : i32 } else { From e75997db4670f00f44ba8a338f9f9fa9456ce91e Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 16:17:45 +0100 Subject: [PATCH 120/297] Update constant and coopmatrix insts Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 4 +- include/tinytc/tinytc.hpp | 6 +- src/codegen_tools.cpp | 2 +- src/inst.cpp | 5 +- src/node/inst_node.cpp | 97 ++++++++++++---------- src/node/inst_node.hpp | 3 +- src/parser/parser_impl.yy | 28 +++---- src/pass/clone.cpp | 3 +- src/pass/dump_ir.cpp | 19 +---- test/codegen/atomic.ir | 8 +- test/codegen/axpby0.ir | 2 +- test/codegen/axpby1.ir | 14 ++-- test/codegen/cast.ir | 14 ++-- test/codegen/coopmatrix_basic.ir | 10 +-- test/codegen/coopmatrix_load.ir | 20 ++--- test/codegen/coopmatrix_mul_add.ir | 40 ++++----- test/codegen/coopmatrix_store.ir | 24 +++--- test/codegen/dope_vector_group0.ir | 4 +- test/codegen/expand.ir | 26 +++--- test/codegen/for.ir | 16 ++-- test/codegen/fuse.ir | 8 +- test/codegen/if.ir | 18 ++-- test/codegen/load.ir | 2 +- test/codegen/store.ir | 2 +- test/codegen/type_mismatch1.ir | 2 +- test/codegen/work_group.ir | 4 +- test/opt/check-ir/cast_forbidden.ir | 2 +- test/opt/check-ir/expand.ir | 26 +++--- test/opt/check-ir/nesting0.ir | 4 +- test/opt/check-ir/nesting1.ir | 4 +- test/opt/check-ir/nesting3.ir | 4 +- test/opt/check-ir/subview.ir | 16 ++-- test/opt/constant-propagation-safe.ir | 46 +++++----- test/opt/constant-propagation-unsafe.ir | 16 ++-- test/opt/constant-propagation.ir | 106 ++++++++++++------------ test/opt/dead-code-elimination.ir | 36 ++++---- test/opt/dump-def-use.ir | 12 +-- test/opt/insert-barrier.ir | 18 ++-- test/opt/insert-lifetime-stop.ir | 10 +-- test/opt/work-group-size.ir | 16 ++-- test/spv/alloca.ir | 2 +- test/spv/arith.ir | 4 +- test/spv/arith_unary.ir | 2 +- test/spv/cast.ir | 2 +- test/spv/cooperative_matrix_load.ir | 8 +- test/spv/cooperative_matrix_mul_add.ir | 40 ++++----- test/spv/cooperative_matrix_scale.ir | 6 +- test/spv/cooperative_matrix_store.ir | 4 +- test/spv/expand.ir | 2 +- test/spv/for.ir | 20 ++--- test/spv/fuse.ir | 2 +- test/spv/if.ir | 14 ++-- test/spv/load.ir | 4 +- test/spv/store.ir | 8 +- test/spv/work_group.ir | 6 +- 55 files changed, 400 insertions(+), 421 deletions(-) diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 3539bdb9..0fb05610 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -411,12 +411,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create( * @param instr [out] pointer to the inst object created * @param a [in] %a * @param b [in] %b + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( - tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, const tinytc_location_t *loc); + tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + const tinytc_location_t *loc); /** * @brief Create cooperative matrix store instruction diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index d59c6e26..9dafba64 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1090,13 +1090,15 @@ inline inst make_cooperative_matrix_mul_add(value a, value b, value c, data_type * * @param a %a * @param b %b + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_cooperative_matrix_scale(value a, value b, location const &loc = {}) { +inline inst make_cooperative_matrix_scale(value a, value b, data_type ty, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_cooperative_matrix_scale_inst_create(&instr, a, b, &loc), loc); + CHECK_STATUS_LOC(tinytc_cooperative_matrix_scale_inst_create(&instr, a, b, ty, &loc), loc); return inst(instr); } diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 310d3a8b..4382887f 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -621,7 +621,7 @@ auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, a = bb.add(make_cast(a, compatible_ty, loc)); } } - return bb.add(make_cooperative_matrix_scale(a, b, loc)); + return bb.add(make_cooperative_matrix_scale(a, b, bt, loc)); } auto get_atomic_store_flag(value beta) -> std::optional { diff --git a/src/inst.cpp b/src/inst.cpp index 86ba5a2f..4052f608 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -352,13 +352,14 @@ tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create(tinytc_inst_t *ins } tinytc_status_t tinytc_cooperative_matrix_scale_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_value_t b, + tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || a == nullptr || b == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(a, b, get_optional(loc)).release(); + *instr = + std::make_unique(a, b, ty, get_optional(loc)).release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 0b565c38..91502b42 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -78,6 +78,14 @@ auto tinytc_inst::kind() const -> tinytc::inst_execution_kind { namespace tinytc { +coopmatrix_data_type *get_coopmatrix_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); + if (m == nullptr) { + throw compilation_error(loc, {&v}, status::ir_expected_coopmatrix); + } + return m; +} + scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { auto m = dyn_cast(v.ty()); if (m == nullptr) { @@ -94,9 +102,9 @@ memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { return m; } -void check_index_ty(location const &loc, tinytc_data_type_t ty) { - if (auto sty = dyn_cast(ty); !sty || sty->ty() != scalar_type::index) { - throw compilation_error(loc, status::ir_expected_index); +void check_index_ty(location const &loc, tinytc_value const &v) { + if (auto sty = dyn_cast(v.ty()); !sty || sty->ty() != scalar_type::index) { + throw compilation_error(loc, {&v}, status::ir_expected_index); } } @@ -454,23 +462,21 @@ cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, checked_ op(op_pos1, p1); loc(lc); - auto ot = dyn_cast(operand().ty()); - if (!ot) { - throw compilation_error(loc(), status::ir_expected_memref); - } auto rt = dyn_cast(to_ty); if (!rt) { throw compilation_error(loc(), status::ir_expected_coopmatrix); } + + auto ot = get_memref_type(loc(), operand()); if (ot->element_ty() != rt->component_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); } if (ot->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_memref_order_2); + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); } - check_index_ty(lc, pos0().ty()); - check_index_ty(lc, pos1().ty()); + check_index_ty(lc, pos0()); + check_index_ty(lc, pos1()); result(0) = value_node{to_ty, this, lc}; } @@ -486,13 +492,26 @@ cooperative_matrix_mul_add_inst::cooperative_matrix_mul_add_inst(tinytc_value_t op(op_c, c0); loc(lc); - auto at = dyn_cast(a().ty()); - auto bt = dyn_cast(b().ty()); - auto ct = dyn_cast(c().ty()); auto rt = dyn_cast(to_ty); - if (!at || !bt || !ct || !rt) { + if (!rt) { throw compilation_error(loc(), status::ir_expected_memref); } + if (rt->use() != matrix_use::acc) { + throw compilation_error(loc(), status::ir_invalid_matrix_use); + } + + auto at = get_coopmatrix_type(loc(), a()); + auto bt = get_coopmatrix_type(loc(), b()); + auto ct = get_coopmatrix_type(loc(), c()); + if (at->use() != matrix_use::a) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_matrix_use); + } + if (bt->use() != matrix_use::b) { + throw compilation_error(loc(), {&b()}, status::ir_invalid_matrix_use); + } + if (ct->use() != matrix_use::acc) { + throw compilation_error(loc(), {&c()}, status::ir_invalid_matrix_use); + } auto M = rt->rows(); auto N = rt->cols(); @@ -505,42 +524,38 @@ cooperative_matrix_mul_add_inst::cooperative_matrix_mul_add_inst(tinytc_value_t oss << "B=" << bt->rows() << "x" << bt->cols() << ", "; oss << "C=" << ct->rows() << "x" << ct->cols() << ", "; oss << "result=" << rt->rows() << "x" << rt->cols(); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } - if (at->use() != matrix_use::a && bt->use() != matrix_use::b && ct->use() != matrix_use::acc && - rt->use() != matrix_use::acc) { - throw compilation_error(loc(), status::ir_invalid_matrix_use); + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_incompatible_shapes, + oss.str()); } const auto AB_ty = compatible_type(at->component_ty(), bt->component_ty()); if (compatible_type(AB_ty, ct->component_ty()) != ct->component_ty()) { - throw compilation_error(loc(), status::ir_incompatible_scalar_types); + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_incompatible_scalar_types); } result(0) = value_node{to_ty, this, lc}; } cooperative_matrix_scale_inst::cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::cooperative_matrix_scale} { op(op_a, a0); op(op_b, b0); loc(lc); - auto at = dyn_cast(a().ty()); - if (!at) { - throw compilation_error(loc(), status::ir_expected_scalar); - } - auto bt = dyn_cast(b().ty()); - if (!bt) { - throw compilation_error(loc(), status::ir_expected_memref); + if (b().ty() != ty) { + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); } + auto at = get_scalar_type(loc(), a()); + auto bt = get_coopmatrix_type(loc(), b()); + if (at->ty() != bt->component_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), {&a(), &b()}, status::ir_scalar_mismatch); } - result(0) = value_node{b().ty(), this, lc}; + result(0) = value_node{ty, this, lc}; } cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, @@ -554,23 +569,17 @@ cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, op(op_pos1, p1); loc(lc); - auto vt = dyn_cast(val().ty()); - if (!vt) { - throw compilation_error(loc(), status::ir_expected_coopmatrix); - } - auto ot = dyn_cast(operand().ty()); - if (!ot) { - throw compilation_error(loc(), status::ir_expected_memref); - } + auto vt = get_coopmatrix_type(loc(), val()); + auto ot = get_memref_type(loc(), operand()); if (vt->component_ty() != ot->element_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), {&val(), &operand()}, status::ir_scalar_mismatch); } if (ot->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_memref_order_2); + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); } - check_index_ty(lc, pos0().ty()); - check_index_ty(lc, pos1().ty()); + check_index_ty(lc, pos0()); + check_index_ty(lc, pos1()); } expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, @@ -742,7 +751,7 @@ load_inst::load_inst(tinytc_value_t op0, array_view index_list0, : standard_inst{IK::load, static_cast(1 + index_list0.size())} { op(0, op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { - check_index_ty(lc, index_list0[i]->ty()); + check_index_ty(lc, *index_list0[i]); op(1 + i, index_list0[i]); } loc(lc); @@ -993,7 +1002,7 @@ store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, { std::size_t i = op_operand; for (auto const &val : index_list0) { - check_index_ty(lc, val->ty()); + check_index_ty(lc, *val); op(++i, val); } } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 2c4417ec..11d029f1 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -502,7 +502,8 @@ class cooperative_matrix_scale_inst : public standard_inst<2, 1, 0> { return i.type_id() == IK::cooperative_matrix_scale; } enum op_number { op_a = 0, op_b = 1 }; - cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, location const &lc = {}); + cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, tinytc_data_type_t ty, + location const &lc = {}); inline auto a() -> tinytc_value & { return op(op_a); } inline auto a() const -> tinytc_value const & { return op(op_a); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 279141e8..ed2de8f6 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -847,7 +847,7 @@ compare_inst: ; constant_inst: - CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR ARROW data_type { + CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR COLON data_type { try { $$ = inst { std::make_unique(std::complex{$re, $im}, $data_type, @constant_inst) @@ -858,7 +858,7 @@ constant_inst: YYERROR; } } - | CONSTANT FLOATING_CONSTANT ARROW data_type { + | CONSTANT FLOATING_CONSTANT COLON data_type { try { $$ = inst { std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() @@ -868,7 +868,7 @@ constant_inst: YYERROR; } } - | CONSTANT INTEGER_CONSTANT ARROW data_type { + | CONSTANT INTEGER_CONSTANT COLON data_type { try { $$ = inst { std::make_unique($INTEGER_CONSTANT, $data_type, @constant_inst).release() @@ -878,7 +878,7 @@ constant_inst: YYERROR; } } - | CONSTANT BOOLEAN_CONSTANT ARROW data_type { + | CONSTANT BOOLEAN_CONSTANT COLON data_type { try { $$ = inst { std::make_unique($BOOLEAN_CONSTANT, $data_type, @constant_inst).release() @@ -891,8 +891,7 @@ constant_inst: ; cooperative_matrix_load_inst: - COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[op_ty] ARROW data_type[result_ty] { - check_type($op, $op_ty, @op, @op_ty); + COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[result_ty] { try { $$ = inst { std::make_unique( @@ -913,10 +912,7 @@ checked: ; cooperative_matrix_mul_add_inst: - COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[a_ty] COMMA data_type[b_ty] COMMA data_type[c_ty] ARROW data_type[to_ty] { - check_type($a, $a_ty, @a, @a_ty); - check_type($b, $b_ty, @b, @b_ty); - check_type($c, $c_ty, @c, @c_ty); + COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[to_ty] { try { $$ = inst { std::make_unique(std::move($a), std::move($b), @@ -932,13 +928,11 @@ cooperative_matrix_mul_add_inst: ; cooperative_matrix_scale_inst: - COOPERATIVE_MATRIX_SCALE var[a] COMMA var[b] COLON data_type[a_ty] COMMA data_type[b_ty] { - check_type($a, $a_ty, @a, @a_ty); - check_type($b, $b_ty, @b, @b_ty); + COOPERATIVE_MATRIX_SCALE var[a] COMMA var[b] COLON data_type[ty] { try { $$ = inst { - std::make_unique(std::move($a), std::move($b), - @cooperative_matrix_scale_inst) + std::make_unique( + std::move($a), std::move($b), std::move($ty), @cooperative_matrix_scale_inst) .release() }; } catch (compilation_error const &e) { @@ -949,9 +943,7 @@ cooperative_matrix_scale_inst: ; cooperative_matrix_store_inst: - COOPERATIVE_MATRIX_STORE checked store_flag var[val] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[val_ty] COMMA data_type[op_ty] { - check_type($val, $val_ty, @val, @val_ty); - check_type($op, $op_ty, @op, @op_ty); + COOPERATIVE_MATRIX_STORE checked store_flag var[val] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR { try { $$ = inst { std::make_unique( diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index 8a542118..1cfdafef 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -55,7 +55,8 @@ auto inst_cloner::operator()(cooperative_matrix_mul_add_inst &in) -> std::unique subs(&in.a()), subs(&in.b()), subs(&in.c()), in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(cooperative_matrix_scale_inst &in) -> std::unique_ptr { - return std::make_unique(subs(&in.a()), subs(&in.b()), in.loc()); + return std::make_unique(subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(cooperative_matrix_store_inst &in) -> std::unique_ptr { return std::make_unique(in.checked(), in.flag(), subs(&in.val()), diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 51858ccb..8c9cdd6f 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -180,7 +180,7 @@ void dump_ir_pass::operator()(constant_inst const &c) { }, }, c.value()); - *os_ << " -> "; + *os_ << " : "; visit(*this, *c.result()->ty()); } @@ -198,8 +198,6 @@ void dump_ir_pass::operator()(cooperative_matrix_load_inst const &c) { *os_ << ","; dump_val(c.pos1()); *os_ << "] : "; - visit(*this, *c.operand().ty()); - *os_ << " -> "; visit(*this, *c.result(0).ty()); } @@ -212,12 +210,6 @@ void dump_ir_pass::operator()(cooperative_matrix_mul_add_inst const &c) { *os_ << ", "; dump_val(c.c()); *os_ << " : "; - visit(*this, *c.a().ty()); - *os_ << ", "; - visit(*this, *c.b().ty()); - *os_ << ", "; - visit(*this, *c.c().ty()); - *os_ << " -> "; visit(*this, *c.result(0).ty()); } @@ -228,9 +220,7 @@ void dump_ir_pass::operator()(cooperative_matrix_scale_inst const &c) { *os_ << ", "; dump_val(c.b()); *os_ << " : "; - visit(*this, *c.a().ty()); - *os_ << ", "; - visit(*this, *c.b().ty()); + visit(*this, *c.result(0).ty()); } void dump_ir_pass::operator()(cooperative_matrix_store_inst const &c) { @@ -249,10 +239,7 @@ void dump_ir_pass::operator()(cooperative_matrix_store_inst const &c) { dump_val(c.pos0()); *os_ << ","; dump_val(c.pos1()); - *os_ << "] : "; - visit(*this, *c.val().ty()); - *os_ << ", "; - visit(*this, *c.operand().ty()); + *os_ << "]"; } void dump_ir_pass::operator()(expand_inst const &e) { diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index e5ea5c93..ab5d86e6 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -3,8 +3,8 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @atomic_store(%A: memref) { - %f0 = constant 0.0 -> f64 - %i0 = constant 0 -> index + %f0 = constant 0.0 : f64 + %i0 = constant 0 : index store.atomic %f0, %A[%i0] : memref store.atomic_add %f0, %A[%i0] : memref ; CHECK-LABEL: void atomic_store({{.*}} @@ -13,8 +13,8 @@ func @atomic_store(%A: memref) { } func @atomic_store_c64(%A: memref) { - %f0 = constant [0.0, 0.0] -> c64 - %i0 = constant 0 -> index + %f0 = constant [0.0, 0.0] : c64 + %i0 = constant 0 : index store.atomic %f0, %A[%i0] : memref store.atomic_add %f0, %A[%i0] : memref ; CHECK-LABEL: void atomic_store_c64({{.*}} diff --git a/test/codegen/axpby0.ir b/test/codegen/axpby0.ir index 33bc2196..b8591f97 100644 --- a/test/codegen/axpby0.ir +++ b/test/codegen/axpby0.ir @@ -3,7 +3,7 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @axpby(%alpha: f32, %A: memref, %B: memref) { - %zero = constant 0.0 -> f32 + %zero = constant 0.0 : f32 axpby.n %alpha, %A, %zero, %B ; CHECK: 7.5-33: Incompatible tensor shapes } diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir index 41c3a9e4..5a19eff2 100644 --- a/test/codegen/axpby1.ir +++ b/test/codegen/axpby1.ir @@ -3,28 +3,28 @@ ; RUN: %tinytc-oc < %s func @axpby0(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 -> f32 + %z = constant 0.0 : f32 axpby.n %alpha, %A, %z, %B } func @axpby1(%alpha: f32, %A: memref>, %B: memref) { - %z = constant 0.0 -> f32 + %z = constant 0.0 : f32 axpby.n %alpha, %A, %z, %B } func @axpby2(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 -> f32 + %z = constant 0.0 : f32 axpby.n %alpha, %A, %z, %B } func @axpby3(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 -> f32 - %lb = constant 0 -> index - %ub = constant 5 -> index + %z = constant 0.0 : f32 + %lb = constant 0 : index + %ub = constant 5 : index for %i=%lb,%ub { %A0 = subview %A[0:48,0:48,0:4,%i] : memref %B0 = subview %B[0:48,0:48,0:4,%i] : memref - %ub1 = constant 4 -> index + %ub1 = constant 4 : index for %j=%lb,%ub1 { %A1 = subview %A0[0:48,0:48,%j] : memref %B1 = subview %B0[0:48,0:48,%j] : memref diff --git a/test/codegen/cast.ir b/test/codegen/cast.ir index 5750e72d..caa336c8 100644 --- a/test/codegen/cast.ir +++ b/test/codegen/cast.ir @@ -3,43 +3,43 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @cast_ii() { - %0 = constant 2 -> index + %0 = constant 2 : index %1 = cast %0 : i32 ; CHECK-LABEL: void cast_ii() { ; CHECK: int x1 = (int) x; } func @cast_ff() { - %0 = constant 2.0 -> f32 + %0 = constant 2.0 : f32 %1 = cast %0 : f64 ; CHECK-LABEL: void cast_ff() { ; CHECK: double x1 = (double) x; } func @cast_cc() { - %0 = constant [2.0, 0.0] -> c32 + %0 = constant [2.0, 0.0] : c32 %1 = cast %0 : c64 ; CHECK-LABEL: void cast_cc() { ; CHECK: double2 x1 = convert_double2(x); } func @cast_if() { - %0 = constant 2 -> i32 + %0 = constant 2 : i32 %1 = cast %0 : f32 ; CHECK-LABEL: void cast_if() { ; CHECK: float x1 = (float) x; } func @cast_fi() { - %0 = constant 2.0 -> f32 + %0 = constant 2.0 : f32 %1 = cast %0 : i16 ; CHECK-LABEL: void cast_fi() { ; CHECK: short x1 = (short) x; } func @cast_ic() { - %0 = constant 2 -> i8 + %0 = constant 2 : i8 %1 = cast %0 : c32 ; CHECK-LABEL: void cast_ic() { ; CHECK: float2 x1 = (float2) (x, 0); } func @cast_fc() { - %0 = constant 2.0 -> f64 + %0 = constant 2.0 : f64 %1 = cast %0 : c32 ; CHECK-LABEL: void cast_fc() { ; CHECK: float2 x1 = (float2) (x, 0); diff --git a/test/codegen/coopmatrix_basic.ir b/test/codegen/coopmatrix_basic.ir index fdd18ad0..5b135727 100644 --- a/test/codegen/coopmatrix_basic.ir +++ b/test/codegen/coopmatrix_basic.ir @@ -3,7 +3,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @coopmatrix_constant() { - %0 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix ; CHECK-LABEL: void coopmatrix_constant({{.*}} ; CHECK: double x[5]; ; CHECK-NEXT: x[0] = 0x1p+0; @@ -14,8 +14,8 @@ func @coopmatrix_constant() { } func @coopmatrix_add() { - %0 = constant 1.0 -> coopmatrix - %1 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix + %1 = constant 1.0 : coopmatrix %2 = arith.add %0, %1 : coopmatrix ; CHECK-LABEL: void coopmatrix_add({{.*}} ; CHECK: double x2[4]; @@ -26,7 +26,7 @@ func @coopmatrix_add() { } func @coopmatrix_neg() subgroup_size(16) { - %0 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix %1 = arith.neg %0 : coopmatrix ; CHECK-LABEL: void coopmatrix_neg({{.*}} ; CHECK: double x1[8]; @@ -41,7 +41,7 @@ func @coopmatrix_neg() subgroup_size(16) { } func @coopmatrix_cast() subgroup_size(16) { - %0 = constant 1 -> coopmatrix + %0 = constant 1 : coopmatrix %1 = cast %0 : coopmatrix ; CHECK-LABEL: void coopmatrix_cast({{.*}} ; CHECK: float2 x1[8]; diff --git a/test/codegen/coopmatrix_load.ir b/test/codegen/coopmatrix_load.ir index d0b51885..3e64cdfb 100644 --- a/test/codegen/coopmatrix_load.ir +++ b/test/codegen/coopmatrix_load.ir @@ -3,7 +3,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_n({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -18,7 +18,7 @@ func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_n_rows_checked({{.*}} ; CHECK: float x1[4]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -33,7 +33,7 @@ func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: ind } func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_n_cols_checked({{.*}} ; CHECK: float x1[4]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -46,7 +46,7 @@ func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: ind } func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.both_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.both_checked %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_n_checked({{.*}} ; CHECK: float x1[16]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -73,7 +73,7 @@ func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) s } func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_t({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -88,7 +88,7 @@ func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.both_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t.both_checked %A[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_a_load_t_checked({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; @@ -106,7 +106,7 @@ func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) s } func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n %B[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_n({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; @@ -121,7 +121,7 @@ func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.both_checked %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.both_checked %B[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_n_checked({{.*}} ; CHECK: float x1[16]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; @@ -148,7 +148,7 @@ func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) s } func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t %B[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_t({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; @@ -163,7 +163,7 @@ func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.both_checked %B[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t.both_checked %B[%x,%y] : coopmatrix ; CHECK-LABEL: void coopmatrix_b_load_t_checked({{.*}} ; CHECK: float x1[8]; ; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; diff --git a/test/codegen/coopmatrix_mul_add.ir b/test/codegen/coopmatrix_mul_add.ir index aec95f2e..8ea8218f 100644 --- a/test/codegen/coopmatrix_mul_add.ir +++ b/test/codegen/coopmatrix_mul_add.ir @@ -3,12 +3,10 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @coopmatrix_mul_add_ff() subgroup_size(16) { - %a = constant 1.0 -> coopmatrix - %b = constant 1.0 -> coopmatrix - %c = constant 1.0 -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant 1.0 : coopmatrix + %b = constant 1.0 : coopmatrix + %c = constant 1.0 : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_ff({{.*}} ; CHECK: float c_next[4]; ; CHECK-NEXT: c_next[0] = fma(a[0], sub_group_broadcast(b[0], 0), c[0]); @@ -22,12 +20,10 @@ func @coopmatrix_mul_add_ff() subgroup_size(16) { } func @coopmatrix_mul_add_cf() subgroup_size(16) { - %a = constant [1.0, 0.0] -> coopmatrix - %b = constant 1.0 -> coopmatrix - %c = constant [1.0, 0.0] -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant [1.0, 0.0] : coopmatrix + %b = constant 1.0 : coopmatrix + %c = constant [1.0, 0.0] : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_cf({{.*}} ; CHECK: float2 c_next[4]; ; CHECK-NEXT: c_next[0] = c[0] + a[0] * sub_group_broadcast(b[0], 0); @@ -41,12 +37,10 @@ func @coopmatrix_mul_add_cf() subgroup_size(16) { } func @coopmatrix_mul_add_fc() subgroup_size(16) { - %a = constant 1.0 -> coopmatrix - %b = constant [1.0, 0.0] -> coopmatrix - %c = constant [1.0, 0.0] -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant 1.0 : coopmatrix + %b = constant [1.0, 0.0] : coopmatrix + %c = constant [1.0, 0.0] : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_fc({{.*}} ; CHECK: float2 c_next[4]; ; CHECK-NEXT: c_next[0].x = c[0].x + a[0] * sub_group_broadcast(b[0].x, 0); @@ -68,12 +62,10 @@ func @coopmatrix_mul_add_fc() subgroup_size(16) { } func @coopmatrix_mul_add_cc() subgroup_size(16) { - %a = constant [1.0, 0.0] -> coopmatrix - %b = constant [1.0, 0.0] -> coopmatrix - %c = constant [1.0, 0.0] -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant [1.0, 0.0] : coopmatrix + %b = constant [1.0, 0.0] : coopmatrix + %c = constant [1.0, 0.0] : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: void coopmatrix_mul_add_cc({{.*}} ; CHECK: float2 c_next[4]; ; CHECK-NEXT: float2 x[4]; diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir index 5dda1e26..692f7059 100644 --- a/test/codegen/coopmatrix_store.ir +++ b/test/codegen/coopmatrix_store.ir @@ -3,8 +3,8 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_n({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; @@ -12,8 +12,8 @@ func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup } func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.rows_checked %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store.rows_checked %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_n_rows_checked({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; @@ -25,8 +25,8 @@ func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: in } func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.cols_checked %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store.cols_checked %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_n_cols_checked({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; @@ -40,8 +40,8 @@ func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: in } func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.both_checked %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store.both_checked %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_n_checked({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; @@ -57,8 +57,8 @@ func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) } func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.atomic_add %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store.atomic_add %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_atomic_add({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); @@ -66,8 +66,8 @@ func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) } func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 -> coopmatrix - cooperative_matrix_store.both_checked.atomic_add %c0, %A[%x,%y] : coopmatrix, memref + %c0 = constant 1.0 : coopmatrix + cooperative_matrix_store.both_checked.atomic_add %c0, %A[%x,%y] ; CHECK-LABEL: void coopmatrix_a_store_checked_atomic_add({{.*}} ; CHECK: global float* x1 = A + x * 1 + y * 64; ; CHECK-NEXT: long x2 = 64 - x; diff --git a/test/codegen/dope_vector_group0.ir b/test/codegen/dope_vector_group0.ir index 81886f91..bae9d6cf 100644 --- a/test/codegen/dope_vector_group0.ir +++ b/test/codegen/dope_vector_group0.ir @@ -4,7 +4,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel1(%in: group>) { ; CHECK: void kernel1(global float*global* in, global long* in_shape1, global long* in_stride2) - %c5 = constant 5 -> index + %c5 = constant 5 : index %0 = load %in[%c5] : group> ; CHECK-NEXT: long c5 = 5ll; ; CHECK-NEXT: global float* x = *(in + c5) + 0; @@ -14,7 +14,7 @@ func @kernel1(%in: group>) { func @kernel2(%in: group, offset: ?>) { ; CHECK: void kernel2(global float*global* in, global long* in_shape0, long in_offset) - %c5 = constant 5 -> index + %c5 = constant 5 : index %0 = load %in[%c5] : group, offset: ?> ; CHECK-NEXT: long c5 = 5ll; ; CHECK-NEXT: global float* x = *(in + c5) + in_offset; diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir index 188651af..c3722880 100644 --- a/test/codegen/expand.ir +++ b/test/codegen/expand.ir @@ -3,7 +3,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[1->2x8] : memref %2 = load %1[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t1( @@ -11,7 +11,7 @@ func @t1(%0: memref) { ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 512); } func @t2(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[1->2x2x2x2] : memref %2 = load %1[%z,%z,%z,%z,%z,%z] : memref ; CHECK-LABEL: void t2( @@ -19,7 +19,7 @@ func @t2(%0: memref) { ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 128 + z * 256 + z * 512); } func @t3(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t3( @@ -29,7 +29,7 @@ func @t3(%0: memref, %1: index) { ; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t4(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->2 x %1] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t4( @@ -38,7 +38,7 @@ func @t4(%0: memref, %1: index) { ; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * 64); } func @t5(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t5( @@ -48,7 +48,7 @@ func @t5(%0: memref, %1: index) { ; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t6(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t6( @@ -58,7 +58,7 @@ func @t6(%0: memref, %1: index) { ; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); } func @t7(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2 x 2] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t7( @@ -70,7 +70,7 @@ func @t7(%0: memref, %1: index, %2: index) { ; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); } func @t8(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%2 x 2 x %1] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t8( @@ -82,7 +82,7 @@ func @t8(%0: memref, %1: index, %2: index) { ; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); } func @t9(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t9( @@ -93,7 +93,7 @@ func @t9(%0: memref, %1: index, %2: index) { ; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); } func @t10(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t10( @@ -104,7 +104,7 @@ func @t10(%0: memref, %1: index, %2: index) { ; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); } func @t11(%0: memref>) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[0->4 x 8] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK-LABEL: void t11( @@ -112,7 +112,7 @@ func @t11(%0: memref>) { ; CHECK-NEXT: float x2 = *(x1 + z * 2 + z * 8 + z * 64); } func @t12(%0: memref>, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[0->%1 x 4] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t12( @@ -123,7 +123,7 @@ func @t12(%0: memref>, %1: index) { ; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * x_stride11 + z * x_stride2); } func @t13(%0: memref>, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[0->4 x %1] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t13( diff --git a/test/codegen/for.ir b/test/codegen/for.ir index d58f9060..5b5ab374 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -3,12 +3,12 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @for1() { - %lb0 = constant 0 -> index - %ub0 = constant 10 -> index + %lb0 = constant 0 : index + %ub0 = constant 10 : index for %0 = %lb0,%ub0 { } - %lb1 = constant -2 -> i16 - %ub1 = constant 2 -> i16 + %lb1 = constant -2 : i16 + %ub1 = constant 2 : i16 for %1 = %lb1,%ub1 : i16 { } ; CHECK-LABEL: void for1({{.*}} @@ -17,10 +17,10 @@ func @for1() { } func @for2(%fib: memref) { - %from = constant 2 -> i32 - %to = constant 6 -> i32 - %f0 = constant 0 -> i64 - %f1 = constant 1 -> i64 + %from = constant 2 : i32 + %to = constant 6 : i32 + %f0 = constant 0 : i64 + %f1 = constant 1 : i64 %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir index a6c4e96e..ba6c7c5f 100644 --- a/test/codegen/fuse.ir +++ b/test/codegen/fuse.ir @@ -3,13 +3,13 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = fuse %0[1,3] : memref %2 = load %1[%z,%z,%z] : memref ; CHECK: float x2 = *(x1 + z * 1 + z * 32 + z * 16384); } func @t2(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = fuse %0[1,3] : memref %2 = load %1[%z,%z,%z] : memref> ; CHECK: long x_shape1 = 16 * x_shape2 * 4; @@ -17,13 +17,13 @@ func @t2(%0: memref) { ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x_stride2); } func @t3(%0: memref>) { - %z = constant 0 -> index + %z = constant 0 : index %1 = fuse %0[1,2] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK: float x2 = *(x1 + z * 1 + z * 48 + z * 1536); } func @t4(%0: memref>) { - %z = constant 0 -> index + %z = constant 0 : index %1 = fuse %0[0,1] : memref> %2 = load %1[%z,%z] : memref> ; CHECK: long x_shape0 = 8 * x_shape1; diff --git a/test/codegen/if.ir b/test/codegen/if.ir index 177a1212..dbb44400 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -3,8 +3,8 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @if0(%0: i32) { - %c16 = constant 16 -> i32 - %c0 = constant 0 -> i32 + %c16 = constant 16 : i32 + %c0 = constant 0 : i32 %1 = cmp.lt %0, %c16 : bool %2 = cmp.ge %0, %c0 : bool %3 = arith.and %1, %2 : bool @@ -19,7 +19,7 @@ func @if0(%0: i32) { } func @if1(%0: i32) { - %c16 = constant 16 -> i32 + %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool if %1 { } else { @@ -29,7 +29,7 @@ func @if1(%0: i32) { } func @if2(%0: i32) { - %c16 = constant 16 -> i32 + %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool if %1 -> () { yield : @@ -42,7 +42,7 @@ func @if2(%0: i32) { } func @if3(%0: i32) { - %c16 = constant 16 -> i32 + %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool %x = if %1 -> (i32) { yield %0 : i32 @@ -58,19 +58,19 @@ func @if3(%0: i32) { } func @if4(%0: i32) { - %c16 = constant 16 -> i32 + %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool %x, %y = if %1 -> (i32, f32) { if %1 { } - %one = constant 1.0 -> f32 + %one = constant 1.0 : f32 yield %0, %one : i32, f32 } else { %z = if %1 -> (f32) { - %one = constant 1.0 -> f32 + %one = constant 1.0 : f32 yield %one : f32 } else { - %zero = constant 0.0 -> f32 + %zero = constant 0.0 : f32 yield %zero : f32 } yield %c16, %z : i32, f32 diff --git a/test/codegen/load.ir b/test/codegen/load.ir index 98fdeec2..33f80f6a 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -3,7 +3,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel1(%a: memref, %b: memref, %c: group>) { - %c5 = constant 5 -> index + %c5 = constant 5 : index %0 = load %a[] : memref %1 = group_id %2 = load %b[%c5, %1] : memref diff --git a/test/codegen/store.ir b/test/codegen/store.ir index 82161a12..2b09d642 100644 --- a/test/codegen/store.ir +++ b/test/codegen/store.ir @@ -3,7 +3,7 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @kernel(%a: memref, %b: memref, %c: f32) { - %c5 = constant 5 -> index + %c5 = constant 5 : index %1 = group_id store %c, %a[] : memref store %c, %b[%c5, %1] : memref diff --git a/test/codegen/type_mismatch1.ir b/test/codegen/type_mismatch1.ir index be9cdc2f..b79bf466 100644 --- a/test/codegen/type_mismatch1.ir +++ b/test/codegen/type_mismatch1.ir @@ -3,7 +3,7 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @kernel(%K0: memref, %x: index, %y: index) { - %z = constant 0 -> index + %z = constant 0 : index %0 = subview %K0[0:%x] : memref %1 = subview %0[0:%y] : memref %2 = load %1[%z] : memref diff --git a/test/codegen/work_group.ir b/test/codegen/work_group.ir index f182cdc1..eed318bd 100644 --- a/test/codegen/work_group.ir +++ b/test/codegen/work_group.ir @@ -3,13 +3,13 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1() { - %0 = constant 1.0 -> f32 + %0 = constant 1.0 : f32 %1 = work_group.reduce_add %0 : f32 ; CHECK-LABEL: void t1({{.*}} ; CHECK: float x1 = work_group_reduce_add(x); } func @t2() { - %0 = constant [1.0, 0.0] -> c32 + %0 = constant [1.0, 0.0] : c32 %1 = work_group.reduce_add %0 : c32 ; CHECK-LABEL: void t2({{.*}} ; CHECK: float2 x1 = (float2) (work_group_reduce_add(x.x), work_group_reduce_add(x.y)); diff --git a/test/opt/check-ir/cast_forbidden.ir b/test/opt/check-ir/cast_forbidden.ir index 31e785f9..013512a8 100644 --- a/test/opt/check-ir/cast_forbidden.ir +++ b/test/opt/check-ir/cast_forbidden.ir @@ -3,7 +3,7 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @cast_cf() { - %0 = constant [2.0, 1.0] -> c32 + %0 = constant [2.0, 1.0] : c32 %1 = cast %0 : i32 ; CHECK: :7.8-20: Forbidden cast } diff --git a/test/opt/check-ir/expand.ir b/test/opt/check-ir/expand.ir index 1bc54c69..7bfc61a8 100644 --- a/test/opt/check-ir/expand.ir +++ b/test/opt/check-ir/expand.ir @@ -8,67 +8,67 @@ ; CHECK: func @t1({{.*}} func @t1(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[1->2x8] : memref %2 = load %1[%z,%z,%z,%z] : memref } func @t2(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[1->2x2x2x2] : memref %2 = load %1[%z,%z,%z,%z,%z,%z] : memref } func @t3(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref } func @t4(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->2 x %1] : memref %3 = load %2[%z,%z,%z] : memref } func @t5(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref } func @t6(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref } func @t7(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2 x 2] : memref %4 = load %3[%z,%z,%z,%z] : memref } func @t8(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%2 x 2 x %1] : memref %4 = load %3[%z,%z,%z,%z] : memref } func @t9(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref } func @t10(%0: memref, %1: index, %2: index) { - %z = constant 0 -> index + %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref } func @t11(%0: memref>) { - %z = constant 0 -> index + %z = constant 0 : index %1 = expand %0[0->4 x 8] : memref> %2 = load %1[%z,%z,%z] : memref> } func @t12(%0: memref>, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[0->%1 x 4] : memref> %3 = load %2[%z,%z,%z] : memref> } func @t13(%0: memref>, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = expand %0[0->4 x %1] : memref> %3 = load %2[%z,%z,%z] : memref> } diff --git a/test/opt/check-ir/nesting0.ir b/test/opt/check-ir/nesting0.ir index 4821a7be..66e29d1f 100644 --- a/test/opt/check-ir/nesting0.ir +++ b/test/opt/check-ir/nesting0.ir @@ -3,8 +3,8 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting(%c: f32, %A: memref, %B: memref, %C: memref) { - %lb = constant 1 -> index - %ub = constant 16 -> index + %lb = constant 1 : index + %ub = constant 16 : index foreach (%i)=(%lb),(%ub) { gemm.n.n %c, %A, %B, %c, %C } diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index e49e9fbb..45d4bc11 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -3,8 +3,8 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { - %lb = constant 1 -> index - %ub = constant 16 -> index + %lb = constant 1 : index + %ub = constant 16 : index foreach (%i)=(%lb),(%ub) { foreach (%j)=(%lb),(%ub) { } diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index ac787966..52bbfe32 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -3,8 +3,8 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { - %lb = constant 1 -> index - %ub = constant 16 -> index + %lb = constant 1 : index + %ub = constant 16 : index parallel { foreach (%j)=(%lb),(%ub) { } diff --git a/test/opt/check-ir/subview.ir b/test/opt/check-ir/subview.ir index 709d8d02..5f06833e 100644 --- a/test/opt/check-ir/subview.ir +++ b/test/opt/check-ir/subview.ir @@ -8,42 +8,42 @@ ; CHECK: func @t1({{.*}} func @t1(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = subview %0[4:8,8:4] : memref %2 = load %1[%z,%z] : memref> } func @t2(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[2:4,%1] : memref %3 = load %2[%z] : memref } func @t3(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[2:4,%1:0] : memref %3 = load %2[%z] : memref } func @t4(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[2:4,%1:1] : memref %3 = load %2[%z,%z] : memref> } func @t5(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[%1:4] : memref %3 = load %2[%z] : memref } func @t6(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[%1:%1] : memref %3 = load %2[%z] : memref } func @t7(%0: memref, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[2:4, %1:%1, 6:7] : memref %3 = load %2[%z,%z,%z] : memref> } func @t8(%0: memref>, %1: index) { - %z = constant 0 -> index + %z = constant 0 : index %2 = subview %0[2:4, %1:%1, 6:7] : memref> %3 = load %2[%z,%z,%z] : memref> } diff --git a/test/opt/constant-propagation-safe.ir b/test/opt/constant-propagation-safe.ir index 1de79f15..37cd6354 100644 --- a/test/opt/constant-propagation-safe.ir +++ b/test/opt/constant-propagation-safe.ir @@ -4,7 +4,7 @@ ; RUN: %tinytc-opt -pconstant-propagation -fno-unsafe-fp-math < %s | filecheck %s func @identity_iadd(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.add %a, %c0 : i32 %1 = arith.add %c0, %a : i32 %2 = arith.add %0, %1 : i32 @@ -13,7 +13,7 @@ func @identity_iadd(%a: i32) { } func @identity_isub(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.sub %a, %c0 : i32 %1 = arith.sub %c0, %a : i32 %2 = arith.add %0, %1 : i32 @@ -22,18 +22,18 @@ func @identity_isub(%a: i32) { } func @identity_imul0(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.mul %a, %c0 : i32 %1 = arith.mul %c0, %a : i32 ; CHECK-LABEL: func @identity_imul0({{.*}} -; CHECK: %0 = constant 0 -> i32 +; CHECK: %0 = constant 0 : i32 ; CHECK-NEXT: %1 = arith.mul %a, %c0 : i32 -; CHECK-NEXT: %2 = constant 0 -> i32 +; CHECK-NEXT: %2 = constant 0 : i32 ; CHECK-NEXT: %3 = arith.mul %c0, %a : i32 } func @identity_imul1(%a: i32) { - %c1 = constant 1 -> i32 + %c1 = constant 1 : i32 %0 = arith.mul %a, %c1 : i32 %1 = arith.mul %c1, %a : i32 %2 = arith.mul %0, %1 : i32 @@ -42,7 +42,7 @@ func @identity_imul1(%a: i32) { } func @identity_idiv(%a: i32) { - %c1 = constant 1 -> i32 + %c1 = constant 1 : i32 %0 = arith.div %a, %c1 : i32 %1 = arith.div %c1, %a : i32 %2 = arith.mul %0, %1 : i32 @@ -51,48 +51,48 @@ func @identity_idiv(%a: i32) { } func @identity_irem(%a: i32) { - %c1 = constant 1 -> i32 + %c1 = constant 1 : i32 %0 = arith.rem %a, %c1 : i32 %1 = arith.rem %c1, %a : i32 %2 = arith.mul %0, %1 : i32 ; CHECK-LABEL: func @identity_irem({{.*}} -; CHECK: %0 = constant 0 -> i32 +; CHECK: %0 = constant 0 : i32 ; CHECK: %4 = arith.mul %0, %2 : i32 } func @identity_ishl(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.shl %a, %c0 : i32 %1 = arith.shl %c0, %a : i32 %2 = arith.add %0, %1 : i32 ; CHECK-LABEL: func @identity_ishl({{.*}} -; CHECK: %1 = constant 0 -> i32 +; CHECK: %1 = constant 0 : i32 ; CHECK: %3 = arith.add %a, %1 : i32 } func @identity_ishr(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.shr %a, %c0 : i32 %1 = arith.shr %c0, %a : i32 %2 = arith.add %0, %1 : i32 ; CHECK-LABEL: func @identity_ishr({{.*}} -; CHECK: %1 = constant 0 -> i32 +; CHECK: %1 = constant 0 : i32 ; CHECK: %3 = arith.add %a, %1 : i32 } func @identity_iand(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.and %a, %c0 : i32 %1 = arith.and %c0, %a : i32 %2 = arith.add %0, %1 : i32 ; CHECK-LABEL: func @identity_iand({{.*}} -; CHECK: %0 = constant 0 -> i32 -; CHECK: %2 = constant 0 -> i32 +; CHECK: %0 = constant 0 : i32 +; CHECK: %2 = constant 0 : i32 ; CHECK: %5 = arith.add %0, %2 : i32 } func @identity_ior(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.or %a, %c0 : i32 %1 = arith.or %c0, %a : i32 %2 = arith.add %0, %1 : i32 @@ -101,7 +101,7 @@ func @identity_ior(%a: i32) { } func @identity_ixor(%a: i32) { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 %0 = arith.xor %a, %c0 : i32 %1 = arith.xor %c0, %a : i32 %2 = arith.add %0, %1 : i32 @@ -110,18 +110,18 @@ func @identity_ixor(%a: i32) { } func @identity_band(%a: bool) { - %c0 = constant false -> bool + %c0 = constant false : bool %0 = arith.and %a, %c0 : bool %1 = arith.and %c0, %a : bool %2 = arith.and %0, %1 : bool ; CHECK-LABEL: func @identity_band({{.*}} -; CHECK: %0 = constant false -> bool -; CHECK: %2 = constant false -> bool +; CHECK: %0 = constant false : bool +; CHECK: %2 = constant false : bool ; CHECK: %5 = arith.and %0, %2 : bool } func @identity_bor(%a: bool) { - %c0 = constant false -> bool + %c0 = constant false : bool %0 = arith.or %a, %c0 : bool %1 = arith.or %c0, %a : bool %2 = arith.and %0, %1 : bool @@ -130,7 +130,7 @@ func @identity_bor(%a: bool) { } func @identity_bxor(%a: bool) { - %c0 = constant false -> bool + %c0 = constant false : bool %0 = arith.xor %a, %c0 : bool %1 = arith.xor %c0, %a : bool %2 = arith.and %0, %1 : bool diff --git a/test/opt/constant-propagation-unsafe.ir b/test/opt/constant-propagation-unsafe.ir index 53b4dc5c..b3c0ea7b 100644 --- a/test/opt/constant-propagation-unsafe.ir +++ b/test/opt/constant-propagation-unsafe.ir @@ -4,7 +4,7 @@ ; RUN: %tinytc-opt -pconstant-propagation -funsafe-fp-math < %s | filecheck %s func @identity_fadd(%a: f32) { - %c0 = constant 0.0 -> f32 + %c0 = constant 0.0 : f32 %0 = arith.add %a, %c0 : f32 %1 = arith.add %c0, %a : f32 %2 = arith.add %0, %1 : f32 @@ -13,7 +13,7 @@ func @identity_fadd(%a: f32) { } func @identity_fsub(%a: f32) { - %c0 = constant 0.0 -> f32 + %c0 = constant 0.0 : f32 %0 = arith.sub %a, %c0 : f32 %1 = arith.sub %c0, %a : f32 %2 = arith.add %0, %1 : f32 @@ -22,18 +22,18 @@ func @identity_fsub(%a: f32) { } func @identity_fmul0(%a: f32) { - %c0 = constant 0.0 -> f32 + %c0 = constant 0.0 : f32 %0 = arith.mul %a, %c0 : f32 %1 = arith.mul %c0, %a : f32 ; CHECK-LABEL: func @identity_fmul0({{.*}} -; CHECK: %0 = constant 0x0p+0 -> f32 +; CHECK: %0 = constant 0x0p+0 : f32 ; CHECK-NEXT: %1 = arith.mul %a, %c0 : f32 -; CHECK-NEXT: %2 = constant 0x0p+0 -> f32 +; CHECK-NEXT: %2 = constant 0x0p+0 : f32 ; CHECK-NEXT: %3 = arith.mul %c0, %a : f32 } func @identity_fmul1(%a: f32) { - %c1 = constant 1.0 -> f32 + %c1 = constant 1.0 : f32 %0 = arith.mul %a, %c1 : f32 %1 = arith.mul %c1, %a : f32 %2 = arith.mul %0, %1 : f32 @@ -42,7 +42,7 @@ func @identity_fmul1(%a: f32) { } func @identity_fdiv(%a: f32) { - %c1 = constant 1.0 -> f32 + %c1 = constant 1.0 : f32 %0 = arith.div %a, %c1 : f32 %1 = arith.div %c1, %a : f32 %2 = arith.mul %0, %1 : f32 @@ -51,7 +51,7 @@ func @identity_fdiv(%a: f32) { } func @identity_cmul1(%a: c32) { - %c1 = constant [1.0, 0.0] -> c32 + %c1 = constant [1.0, 0.0] : c32 %0 = arith.mul %a, %c1 : c32 %1 = arith.mul %c1, %a : c32 %2 = arith.mul %0, %1 : c32 diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 39a40243..daf0ff66 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -8,45 +8,45 @@ func @known_size(%a: memref, %b: index) { %2 = arith.add %0, %1 : index %3 = arith.add %2, %b : index ; CHECK-LABEL: func @known_size({{.*}} -; CHECK: %0 = constant 64 -> index +; CHECK: %0 = constant 64 : index ; CHECK-NEXT: %1 = size %a[0] : memref -; CHECK-NEXT: %2 = constant 32 -> index +; CHECK-NEXT: %2 = constant 32 : index ; CHECK-NEXT: %3 = size %a[1] : memref -; CHECK-NEXT: %4 = constant 96 -> index +; CHECK-NEXT: %4 = constant 96 : index ; CHECK-NEXT: %5 = arith.add %0, %2 : index ; CHECK-NEXT: %6 = arith.add %4, %b : index } func @known_loop_bounds() { - %one = constant 1 -> index - %lb = constant 5 -> index - %size = constant 42 -> index + %one = constant 1 : index + %lb = constant 5 : index + %size = constant 42 : index %tmp = arith.sub %size, %lb : index %ub = arith.sub %tmp, %one : index for %i=%lb,%ub { } ; CHECK-LABEL: func @known_loop_bounds({{.*}} -; CHECK-NEXT: %one = constant 1 -> index -; CHECK-NEXT: %lb = constant 5 -> index -; CHECK-NEXT: %size = constant 42 -> index -; CHECK-NEXT: %0 = constant 37 -> index +; CHECK-NEXT: %one = constant 1 : index +; CHECK-NEXT: %lb = constant 5 : index +; CHECK-NEXT: %size = constant 42 : index +; CHECK-NEXT: %0 = constant 37 : index ; CHECK-NEXT: %tmp = arith.sub %size, %lb : index -; CHECK-NEXT: %1 = constant 36 -> index +; CHECK-NEXT: %1 = constant 36 : index ; CHECK-NEXT: %ub = arith.sub %0, %one : index ; CHECK-NEXT: for %i=%lb,%1 : index { } func @known_loop_iter_args() { - %c1 = constant 1 -> index - %c5 = constant 5 -> index + %c1 = constant 1 : index + %c5 = constant 5 : index %0 = arith.add %c1, %c5 : index %2 = for %i=%c1,%c5 init(%1=%0) -> (index) { yield %1 : index } ; CHECK-LABEL: func @known_loop_iter_args({{.*}} -; CHECK-NEXT: %c1 = constant 1 -> index -; CHECK-NEXT: %c5 = constant 5 -> index -; CHECK-NEXT: %0 = constant 6 -> index +; CHECK-NEXT: %c1 = constant 1 : index +; CHECK-NEXT: %c5 = constant 5 : index +; CHECK-NEXT: %0 = constant 6 : index ; CHECK-NEXT: %1 = arith.add %c1, %c5 : index ; CHECK-NEXT: %3 = for %i=%c1,%c5 init(%2=%0) -> (index) : index { ; CHECK-NEXT: yield %2 : index @@ -54,12 +54,12 @@ func @known_loop_iter_args() { } func @known_arith() { - %0 = constant 1 -> i64 - %1 = constant 2 -> i64 - %2 = constant -2.0 -> f32 - %3 = constant [1.0, -1.0] -> c32 - %4 = constant false -> bool - %5 = constant true -> bool + %0 = constant 1 : i64 + %1 = constant 2 : i64 + %2 = constant -2.0 : f32 + %3 = constant [1.0, -1.0] : c32 + %4 = constant false : bool + %5 = constant true : bool %6 = arith.not %0 : i64 %7 = arith.add %0, %1 : i64 %8 = arith.neg %2 : f32 @@ -70,62 +70,62 @@ func @known_arith() { %13 = arith.xor %5, %5 : bool %14 = arith.not %4 : bool ; CHECK-LABEL: func @known_arith({{.*}} -; CHECK: %6 = constant -2 -> i64 +; CHECK: %6 = constant -2 : i64 ; CHECK-NEXT: %7 = arith.not %0 : i64 -; CHECK-NEXT: %8 = constant 3 -> i64 +; CHECK-NEXT: %8 = constant 3 : i64 ; CHECK-NEXT: %9 = arith.add %0, %1 : i64 -; CHECK-NEXT: %10 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %10 = constant 0x1p+1 : f32 ; CHECK-NEXT: %11 = arith.neg %2 : f32 -; CHECK-NEXT: %12 = constant [0x1p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %12 = constant [0x1p+1,-0x1p+1] : c32 ; CHECK-NEXT: %13 = arith.add %3, %3 : c32 -; CHECK-NEXT: %14 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %14 = constant 0x1p+1 : f32 ; CHECK-NEXT: %15 = arith.abs %2 : f32 -; CHECK-NEXT: %16 = constant false -> bool +; CHECK-NEXT: %16 = constant false : bool ; CHECK-NEXT: %17 = arith.and %4, %5 : bool -; CHECK-NEXT: %18 = constant true -> bool +; CHECK-NEXT: %18 = constant true : bool ; CHECK-NEXT: %19 = arith.or %4, %5 : bool -; CHECK-NEXT: %20 = constant false -> bool +; CHECK-NEXT: %20 = constant false : bool ; CHECK-NEXT: %21 = arith.xor %5, %5 : bool -; CHECK-NEXT: %22 = constant true -> bool +; CHECK-NEXT: %22 = constant true : bool ; CHECK-NEXT: %23 = arith.not %4 : bool } func @known_cast() { - %c0 = constant 32768 -> i32 - %c1 = constant [3.0, -2.0] -> c32 + %c0 = constant 32768 : i32 + %c1 = constant [3.0, -2.0] : c32 %0 = cast %c0 : i16 %1 = cast %c0 : f32 %2 = cast %c0 : c32 %3 = cast %c0 : c32 %4 = cast %c1 : c64 ; CHECK-LABEL: func @known_cast({{.*}} -; CHECK: %0 = constant -32768 -> i16 +; CHECK: %0 = constant -32768 : i16 ; CHECK-NEXT: %1 = cast %c0 : i16 -; CHECK-NEXT: %2 = constant 0x1p+15 -> f32 +; CHECK-NEXT: %2 = constant 0x1p+15 : f32 ; CHECK-NEXT: %3 = cast %c0 : f32 -; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %4 = constant [0x1p+15,0x0p+0] : c32 ; CHECK-NEXT: %5 = cast %c0 : c32 -; CHECK-NEXT: %6 = constant [0x1p+15,0x0p+0] -> c32 +; CHECK-NEXT: %6 = constant [0x1p+15,0x0p+0] : c32 ; CHECK-NEXT: %7 = cast %c0 : c32 -; CHECK-NEXT: %8 = constant [0x1.8p+1,-0x1p+1] -> c64 +; CHECK-NEXT: %8 = constant [0x1.8p+1,-0x1p+1] : c64 ; CHECK-NEXT: %9 = cast %c1 : c64 } func @known_compare() { - %0 = constant 1.0 -> f32 - %1 = constant 2.0 -> f32 + %0 = constant 1.0 : f32 + %1 = constant 2.0 : f32 %2 = cmp.eq %0, %0 : bool %3 = cmp.eq %0, %1 : bool ; CHECK-LABEL: func @known_compare({{.*}} -; CHECK: %2 = constant true -> bool +; CHECK: %2 = constant true : bool ; CHECK-NEXT: %3 = cmp.eq %0, %0 : bool -; CHECK-NEXT: %4 = constant false -> bool +; CHECK-NEXT: %4 = constant false : bool ; CHECK-NEXT: %5 = cmp.eq %0, %1 : bool } func @known_arith_complex() { - %a = constant [3.0, 2.0] -> c32 - %b = constant [-1.0, 5.0] -> c32 + %a = constant [3.0, 2.0] : c32 + %b = constant [-1.0, 5.0] : c32 %0 = arith.add %a, %b : c32 %1 = arith.sub %a, %b : c32 %2 = arith.mul %a, %b : c32 @@ -136,22 +136,22 @@ func @known_arith_complex() { %7 = arith.im %a : f32 %8 = arith.re %a : f32 ; CHECK-LABEL: func @known_arith_complex({{.*}} -; CHECK: %0 = constant [0x1p+1,0x1.cp+2] -> c32 +; CHECK: %0 = constant [0x1p+1,0x1.cp+2] : c32 ; CHECK-NEXT: %1 = arith.add %a, %b : c32 -; CHECK-NEXT: %2 = constant [0x1p+2,-0x1.8p+1] -> c32 +; CHECK-NEXT: %2 = constant [0x1p+2,-0x1.8p+1] : c32 ; CHECK-NEXT: %3 = arith.sub %a, %b : c32 -; CHECK-NEXT: %4 = constant [-0x1.ap+3,0x1.ap+3] -> c32 +; CHECK-NEXT: %4 = constant [-0x1.ap+3,0x1.ap+3] : c32 ; CHECK-NEXT: %5 = arith.mul %a, %b : c32 -; CHECK-NEXT: %6 = constant [0x1.13b13cp-2,-0x1.4ec4eep-1] -> c32 +; CHECK-NEXT: %6 = constant [0x1.13b13cp-2,-0x1.4ec4eep-1] : c32 ; CHECK-NEXT: %7 = arith.div %a, %b : c32 -; CHECK-NEXT: %8 = constant [-0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %8 = constant [-0x1.8p+1,-0x1p+1] : c32 ; CHECK-NEXT: %9 = arith.neg %a : c32 -; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] -> c32 +; CHECK-NEXT: %10 = constant [0x1.8p+1,-0x1p+1] : c32 ; CHECK-NEXT: %11 = arith.conj %a : c32 -; CHECK-NEXT: %12 = constant 0x1.cd82b4p+1 -> f32 +; CHECK-NEXT: %12 = constant 0x1.cd82b4p+1 : f32 ; CHECK-NEXT: %13 = arith.abs %a : f32 -; CHECK-NEXT: %14 = constant 0x1p+1 -> f32 +; CHECK-NEXT: %14 = constant 0x1p+1 : f32 ; CHECK-NEXT: %15 = arith.im %a : f32 -; CHECK-NEXT: %16 = constant 0x1.8p+1 -> f32 +; CHECK-NEXT: %16 = constant 0x1.8p+1 : f32 ; CHECK-NEXT: %17 = arith.re %a : f32 } diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index 6cdb62f3..832c95b3 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -3,18 +3,18 @@ ; RUN: %tinytc-opt -pdead-code-elimination < %s | filecheck %s func @dead_if(%a: memref) { - %c0 = constant false -> bool + %c0 = constant false : bool if %c0 { - %c42 = constant 42.0 -> f64 + %c42 = constant 42.0 : f64 store %c42, %a[] : memref } - %c1 = constant true -> bool + %c1 = constant true : bool if %c1 { - %c43 = constant 43.0 -> f64 + %c43 = constant 43.0 : f64 store %c43, %a[] : memref } ; CHECK-LABEL: func @dead_if({{.*}} -; CHECK-NEXT: %c1 = constant true -> bool +; CHECK-NEXT: %c1 = constant true : bool ; CHECK-NEXT: if %c1 { ; CHECK-NEXT: %c43{{.*}} ; CHECK-NEXT: store{{.*}} @@ -22,42 +22,42 @@ func @dead_if(%a: memref) { } func @dead_if_with_yield(%a: memref) { - %c0 = constant false -> bool + %c0 = constant false : bool %0 = if %c0 -> (f64) { - %c42 = constant 42.0 -> f64 + %c42 = constant 42.0 : f64 yield %c42 : f64 } else { - %c43 = constant 43.0 -> f64 + %c43 = constant 43.0 : f64 yield %c43 : f64 } store %0, %a[] : memref ; Cannot eliminate if that returns results currently ; CHECK-LABEL: func @dead_if_with_yield({{.*}} ; CHECK: %0 = if %c0 -> (f64) { -; CHECK-NEXT: %c42 = constant 0x1.5p+5 -> f64 +; CHECK-NEXT: %c42 = constant 0x1.5p+5 : f64 ; CHECK-NEXT: yield %c42 : f64 ; CHECK-NEXT: } else { -; CHECK-NEXT: %c43 = constant 0x1.58p+5 -> f64 +; CHECK-NEXT: %c43 = constant 0x1.58p+5 : f64 ; CHECK-NEXT: yield %c43 : f64 ; CHECK-NEXT: } ; CHECK-NEXT: store %0, %a[] : memref } func @dead_loop(%a: memref) { - %c2 = constant 2 -> index + %c2 = constant 2 : index for %0=%c2,%c2 { - %c42 = constant 42.0 -> f64 + %c42 = constant 42.0 : f64 store %c42, %a[] : memref } - %c5 = constant 5 -> index - %c6 = constant 6 -> index + %c5 = constant 5 : index + %c6 = constant 6 : index for %0=%c5,%c6 { - %c43 = constant 43.0 -> f64 + %c43 = constant 43.0 : f64 store %c43, %a[] : memref } ; CHECK-LABEL: func @dead_loop({{.*}} -; CHECK-NEXT: %c5 = constant 5 -> index -; CHECK-NEXT: %c6 = constant 6 -> index +; CHECK-NEXT: %c5 = constant 5 : index +; CHECK-NEXT: %c6 = constant 6 : index ; CHECK-NEXT: for %0=%c5,%c6 : index { ; CHECK-NEXT: %c43{{.*}} ; CHECK-NEXT: store{{.*}} @@ -67,7 +67,7 @@ func @dead_loop(%a: memref) { func @unused_alloca(%a: memref) { %0 = alloca : memref %1 = alloca : memref - %one = constant 1.0 -> f64 + %one = constant 1.0 : f64 axpby.n %one, %1, %one, %a ; CHECK-LABEL: func @unused_alloca({{.*}} ; CHECK-NEXT: %0 = alloca : memref diff --git a/test/opt/dump-def-use.ir b/test/opt/dump-def-use.ir index 23fe73a9..14043c90 100644 --- a/test/opt/dump-def-use.ir +++ b/test/opt/dump-def-use.ir @@ -4,22 +4,22 @@ ; RUN: %tinytc-opt -pdump-def-use < %s | filecheck %s func @foobar() { - %one = constant 1 -> index - %lb = constant 0 -> index - %ub = constant 5 -> index + %one = constant 1 : index + %lb = constant 0 : index + %ub = constant 5 : index for %i=%lb,%ub : index { %0 = arith.add %i, %one : index %1 = arith.rem %0, %one : index } ; CHECK: Def-use in foobar -; CHECK-NEXT: > %one = constant 1 -> index +; CHECK-NEXT: > %one = constant 1 : index ; CHECK-NEXT: def %one ; CHECK-NEXT: > %1 = arith.rem %0, %one : index ; CHECK-NEXT: > %0 = arith.add %i, %one : index -; CHECK-NEXT: > %lb = constant 0 -> index +; CHECK-NEXT: > %lb = constant 0 : index ; CHECK-NEXT: def %lb ; CHECK-NEXT: > for %i=%lb,%ub : index {...} -; CHECK-NEXT: > %ub = constant 5 -> index +; CHECK-NEXT: > %ub = constant 5 : index ; CHECK-NEXT: def %ub ; CHECK-NEXT: > for %i=%lb,%ub : index {...} ; CHECK-NEXT: > for %i=%lb,%ub : index {...} diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index c0cd1f58..44537e6e 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -80,7 +80,7 @@ func @war_alias(%a: f32, %b: f32, %A: memref, %C: memref) { } func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - %c42 = constant 42.0 -> f32 + %c42 = constant 42.0 : f32 %0 = cmp.gt %a, %c42 : bool if %0 { axpby.n %a, %A, %b, %B @@ -102,7 +102,7 @@ func @if(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref< } func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref, %D: memref) { - %c42 = constant 42.0 -> f32 + %c42 = constant 42.0 : f32 %0 = cmp.gt %a, %c42 : bool axpby.n %a, %B, %b, %A if %0 { @@ -126,11 +126,11 @@ func @if2(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref } func @region1() { - %one = constant 1.0 -> f32 - %zero = constant 0.0 -> f32 + %one = constant 1.0 : f32 + %zero = constant 0.0 : f32 %0 = alloca : memref - %lb = constant 0 -> index - %ub = constant 4 -> index + %lb = constant 0 : index + %ub = constant 4 : index for %i=%lb,%ub : index { %1 = alloca : memref for %k=%lb,%ub : index { @@ -156,9 +156,9 @@ func @region1() { } func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref) { - %c0 = constant 0 -> i32 - %c3 = constant 3 -> index - %c4 = constant 4 -> index + %c0 = constant 0 : i32 + %c3 = constant 3 : index + %c4 = constant 4 : index parallel { %0 = subgroup_id %1 = cmp.eq %0, %c0 : bool diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 3bc8c7bf..7087d79d 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -11,7 +11,7 @@ func @basic() { func @use1(%A: memref, %C: memref) { ; CHECK-LABEL: func @use1{{.*}} %B = alloca : memref - %one = constant 1.0 -> f32 + %one = constant 1.0 : f32 gemm.n.n %one, %A, %B, %one, %C ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B @@ -19,7 +19,7 @@ func @use1(%A: memref, %C: memref) { func @use2(%A: memref, %C: memref) { ; CHECK-LABEL: func @use2{{.*}} - %one = constant 1.0 -> f32 + %one = constant 1.0 : f32 %B = alloca : memref gemm.n.n %one, %A, %B, %one, %C %B2 = alloca : memref @@ -44,10 +44,10 @@ func @use_alias(%a: f32, %A: memref, %C: memref) { func @region1() { ; CHECK-LABEL: func @region1{{.*}} - %one = constant 1.0 -> f32 + %one = constant 1.0 : f32 %0 = alloca : memref - %lb = constant 0 -> index - %ub = constant 4 -> index + %lb = constant 0 : index + %ub = constant 4 : index for %i=%lb,%ub : index { %1 = alloca : memref for %k=%lb,%ub : index { diff --git a/test/opt/work-group-size.ir b/test/opt/work-group-size.ir index 0a048528..87f0f63c 100644 --- a/test/opt/work-group-size.ir +++ b/test/opt/work-group-size.ir @@ -10,10 +10,10 @@ func @f32_blas() { ; CHECK: func @f32_blas() subgroup_size(32) work_group_size(128,2) { %0 = alloca : memref %1 = alloca : memref - %one = constant 1.0 -> f32 - %zero = constant 0.0 -> f32 - %lb = constant 0 -> index - %ub = constant 4 -> index + %one = constant 1.0 : f32 + %zero = constant 0.0 : f32 + %lb = constant 0 : index + %ub = constant 4 : index for %i=%lb,%ub { axpby.n %one, %0, %zero, %1 } @@ -24,10 +24,10 @@ func @f64_blas() { %0 = alloca : memref %1 = alloca : memref %2 = alloca : memref - %one = constant 1.0 -> f64 - %zero = constant 0.0 -> f64 - %lb = constant 0 -> index - %ub = constant 4 -> index + %one = constant 1.0 : f64 + %zero = constant 0.0 : f64 + %lb = constant 0 : index + %ub = constant 4 : index for %i=%lb,%ub { gemm.n.n %one, %0, %1, %zero, %2 } diff --git a/test/spv/alloca.ir b/test/spv/alloca.ir index 9d0aaac0..1d4f8592 100644 --- a/test/spv/alloca.ir +++ b/test/spv/alloca.ir @@ -25,7 +25,7 @@ ; CHECK: %[[#I16_PTR]] = OpTypePointer Workgroup %[[I16]] func @alloca() { - %c0 = constant 0 -> index + %c0 = constant 0 : index %0 = alloca : memref %1 = alloca : memref %2 = alloca : memref diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 7a7113d7..3b600b5b 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -90,8 +90,8 @@ func @tcomplex(%a: c32, %b: c32) { } func @tfloatcoopmatrix() subgroup_size(16) { - %0 = constant 1.0 -> coopmatrix - %1 = constant 2.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix + %1 = constant 2.0 : coopmatrix %2 = arith.add %0, %1 : coopmatrix %3 = arith.sub %0, %1 : coopmatrix %4 = arith.mul %0, %1 : coopmatrix diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index 30f79255..d3c41d22 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -51,7 +51,7 @@ func @tcomplex(%a: c32) { } func @tfloatcoopmatrix() subgroup_size(16) { - %0 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix %2 = arith.neg %0 : coopmatrix ; CHECK: %[[#]] = OpFNegate %[[#F32]] %[[#]] ; CHECK-NEXT: %[[#]] = OpFNegate %[[#F32]] %[[#]] diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 0c0f5a42..62114064 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -40,7 +40,7 @@ func @tcomplex(%a: c32) { } func @tfloatcoopmatrix() subgroup_size(16) { - %0 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix %2 = cast %0 : coopmatrix ; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] ; CHECK-NEXT: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] diff --git a/test/spv/cooperative_matrix_load.ir b/test/spv/cooperative_matrix_load.ir index 61d64f55..8679884e 100644 --- a/test/spv/cooperative_matrix_load.ir +++ b/test/spv/cooperative_matrix_load.ir @@ -16,7 +16,7 @@ func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n %A[%x,%y] : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#T1_MR:]] = OpFunctionParameter %[[#I32_PTR]] ; CHECK-NEXT: %[[#T1_X:]] = OpFunctionParameter %[[#I64]] @@ -36,7 +36,7 @@ func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_ } func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I32_PTR]] ; CHECK-NEXT: %[[#T2_X:]] = OpFunctionParameter %[[#I64]] @@ -70,7 +70,7 @@ func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: ind } func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I32_PTR]] ; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] @@ -90,7 +90,7 @@ func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: ind } func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t %A[%x,%y] : memref -> coopmatrix + %0 = cooperative_matrix_load.t %A[%x,%y] : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#T4_M:]] = OpSConvert %[[#I64]] %[[#]] ; CHECK: %[[#T4_OFFSET:]] = OpIMul %[[#I64]] %[[#T4_M]] %[[#I64_C64]] diff --git a/test/spv/cooperative_matrix_mul_add.ir b/test/spv/cooperative_matrix_mul_add.ir index 0eb745bb..1fb06168 100644 --- a/test/spv/cooperative_matrix_mul_add.ir +++ b/test/spv/cooperative_matrix_mul_add.ir @@ -26,12 +26,10 @@ ; CHECK: %[[#I64_C3:]] = OpConstant %[[#I64]] 3 func @coopmatrix_mul_add_ff() subgroup_size(16) { - %a = constant 1.0 -> coopmatrix - %b = constant 2.0 -> coopmatrix - %c = constant 3.0 -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant 1.0 : coopmatrix + %b = constant 2.0 : coopmatrix + %c = constant 3.0 : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#FF_B0:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C0]] ; CHECK-NEXT: %[[#FF_C0:]] = OpExtInst %[[#F32]] %[[#]] fma %[[#F32_C1]] %[[#FF_B0]] %[[#F32_C3]] @@ -44,12 +42,10 @@ func @coopmatrix_mul_add_ff() subgroup_size(16) { } func @coopmatrix_mul_add_cf() subgroup_size(16) { - %a = constant [1.0, 0.0] -> coopmatrix - %b = constant 2.0 -> coopmatrix - %c = constant [3.0, 0.0] -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant [1.0, 0.0] : coopmatrix + %b = constant 2.0 : coopmatrix + %c = constant [3.0, 0.0] : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#CF_B0:]] = OpGroupBroadcast %[[#F32]] %[[#I32_C3]] %[[#F32_C2]] %[[#I32_C0]] ; CHECK-NEXT: %[[#CF_DUMMY:]] = OpUndef %[[#C32]] @@ -59,12 +55,10 @@ func @coopmatrix_mul_add_cf() subgroup_size(16) { } func @coopmatrix_mul_add_cc() subgroup_size(16) { - %a = constant [1.0, 0.0] -> coopmatrix - %b = constant [2.0, 0.0] -> coopmatrix - %c = constant [3.0, 0.0] -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant [1.0, 0.0] : coopmatrix + %b = constant [2.0, 0.0] : coopmatrix + %c = constant [3.0, 0.0] : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#CC_B0:]] = OpGroupBroadcast %[[#C32]] %[[#I32_C3]] %[[#C32_C2_0]] %[[#I32_C0]] ; CHECK-NEXT: %[[#CC_B0_RE:]] = OpCompositeExtract %[[#F32]] %[[#CC_B0]] 0 @@ -83,12 +77,10 @@ func @coopmatrix_mul_add_cc() subgroup_size(16) { } func @coopmatrix_mul_add_ii_mixed() subgroup_size(16) { - %a = constant 1 -> coopmatrix - %b = constant 2 -> coopmatrix - %c = constant 3 -> coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c - : coopmatrix, coopmatrix, - coopmatrix -> coopmatrix + %a = constant 1 : coopmatrix + %b = constant 2 : coopmatrix + %c = constant 3 : coopmatrix + %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix ; CHECK-LABEL: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#II_BC:]] = OpGroupBroadcast %[[#I32]] %[[#I32_C3]] %[[#I32_C2]] %[[#I32_C0]] ; CHECK-NEXT: %[[#II_A_I32:]] = OpSConvert %[[#I32]] %[[#I16_C1]] diff --git a/test/spv/cooperative_matrix_scale.ir b/test/spv/cooperative_matrix_scale.ir index f837cc7e..235a6f37 100644 --- a/test/spv/cooperative_matrix_scale.ir +++ b/test/spv/cooperative_matrix_scale.ir @@ -8,9 +8,9 @@ ; CHECK: %[[#F32_C13_369:]] = OpConstant %[[#F32]] 0x1.abcefap+3 func @scale() subgroup_size(16) { - %0 = constant 3.14159265358979323846 -> f32 - %1 = constant 13.36901521971920820459 -> coopmatrix - %2 = cooperative_matrix_scale %0, %1 : f32, coopmatrix + %0 = constant 3.14159265358979323846 : f32 + %1 = constant 13.36901521971920820459 : coopmatrix + %2 = cooperative_matrix_scale %0, %1 : coopmatrix ; CHECK: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] ; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] ; CHECK-NEXT: %[[#]] = OpFMul %[[#F32]] %[[#F32_CPI]] %[[#F32_C13_369]] diff --git a/test/spv/cooperative_matrix_store.ir b/test/spv/cooperative_matrix_store.ir index 0958028c..bc90ae03 100644 --- a/test/spv/cooperative_matrix_store.ir +++ b/test/spv/cooperative_matrix_store.ir @@ -13,8 +13,8 @@ func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = constant 1 -> coopmatrix - cooperative_matrix_store %0, %A[%x,%y] : coopmatrix, memref + %0 = constant 1 : coopmatrix + cooperative_matrix_store %0, %A[%x,%y] ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#T1_MR:]] = OpFunctionParameter %[[#I32_PTR]] ; CHECK-NEXT: %[[#T1_X:]] = OpFunctionParameter %[[#I64]] diff --git a/test/spv/expand.ir b/test/spv/expand.ir index e799d2fa..d1705a71 100644 --- a/test/spv/expand.ir +++ b/test/spv/expand.ir @@ -12,7 +12,7 @@ ; CHECK: %[[#I64_C5:]] = OpConstant %[[#I64]] 5 func @f1(%0: memref, %1: index) { - %c0 = constant 0 -> index + %c0 = constant 0 : index %2 = expand %0[1->4x%1x5] : memref %3 = size %2[0] : memref %4 = size %2[1] : memref diff --git a/test/spv/for.ir b/test/spv/for.ir index c7666930..11bec7c7 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -18,9 +18,9 @@ ; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s func @for1() { - %lb = constant 0 -> i16 - %ub = constant 10 -> i16 - %step = constant 2 -> i16 + %lb = constant 0 : i16 + %ub = constant 10 : i16 + %step = constant 2 : i16 for %0 = %lb,%ub,%step : i16 { } ; CHECK: %[[#]] = OpFunction {{.*}} @@ -40,10 +40,10 @@ func @for1() { } func @for2() { - %from = constant 2 -> i32 - %to = constant 6 -> i32 - %f0 = constant 0 -> i64 - %f1 = constant 1 -> i64 + %from = constant 2 : i32 + %to = constant 6 : i32 + %f0 = constant 0 : i64 + %f1 = constant 1 : i64 %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 @@ -72,9 +72,9 @@ func @for2() { } func @for3() subgroup_size(16) { - %from = constant 2 -> i16 - %to = constant 6 -> i16 - %m_init = constant 1 -> coopmatrix + %from = constant 2 : i16 + %to = constant 6 : i16 + %m_init = constant 1 : coopmatrix %m = for %n=%from,%to init(%m_iter=%m_init) -> (coopmatrix) : i16 { %m_update = arith.add %m_iter, %m_init : coopmatrix yield %m_update : coopmatrix diff --git a/test/spv/fuse.ir b/test/spv/fuse.ir index a5118136..663bb6dc 100644 --- a/test/spv/fuse.ir +++ b/test/spv/fuse.ir @@ -13,7 +13,7 @@ ; CHECK: %[[#I64_C0:]] = OpConstant %[[#I64]] 0 func @f1(%0: memref) { - %z = constant 0 -> index + %z = constant 0 : index %1 = fuse %0[1,3] : memref %2 = size %1[0] : memref> %3 = size %1[1] : memref> diff --git a/test/spv/if.ir b/test/spv/if.ir index 420bc70b..7dfc728b 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -12,7 +12,7 @@ ; CHECK: %[[#CST0:]] = OpConstant %[[#F32]] 0x0p+0 func @if0(%0: i32) { - %c42 = constant 42 -> i32 + %c42 = constant 42 : i32 %1 = cmp.lt %0, %c42 : bool if %1 { %2 = arith.neg %0 : i32 @@ -33,7 +33,7 @@ func @if0(%0: i32) { } func @if1() { - %c1 = constant true -> bool + %c1 = constant true : bool if %c1 -> (){ yield : } else { @@ -44,12 +44,12 @@ func @if1() { } func @if2(%0: i32) { - %c1 = constant true -> bool + %c1 = constant true : bool %x = if %c1 -> (i32) { %1 = if %c1 -> (i32) { yield %0 : i32 } else { - %c0 = constant 0 -> i32 + %c0 = constant 0 : i32 yield %c0 : i32 } yield %1 : i32 @@ -82,12 +82,12 @@ func @if2(%0: i32) { } func @if3() subgroup_size(16) { - %c1 = constant true -> bool + %c1 = constant true : bool %y, %x = if %c1 -> (bool,coopmatrix) { - %0 = constant 1.0 -> coopmatrix + %0 = constant 1.0 : coopmatrix yield %c1, %0 : bool, coopmatrix } else { - %1 = constant 0.0 -> coopmatrix + %1 = constant 0.0 : coopmatrix yield %c1, %1 : bool, coopmatrix } %z = arith.neg %x : coopmatrix diff --git a/test/spv/load.ir b/test/spv/load.ir index a17c3da8..70248fe4 100644 --- a/test/spv/load.ir +++ b/test/spv/load.ir @@ -13,7 +13,7 @@ ; CHECK: %[[#PTR_PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#PTR_F32]] func @l1(%0: memref) { - %2 = constant 0 -> index + %2 = constant 0 : index %3 = load %0[%2,%2] : memref ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#L1_MR:]] = OpFunctionParameter %[[#PTR_F32]] @@ -29,7 +29,7 @@ func @l1(%0: memref) { } func @l2(%0: group>, offset: ?>) { - %1 = constant 0 -> index + %1 = constant 0 : index %2 = load %0[%1] : group>, offset: ?> %3 = load %2[%1] : memref> ; CHECK: %[[#]] = OpFunction {{.*}} diff --git a/test/spv/store.ir b/test/spv/store.ir index 79d8f374..4659cbdb 100644 --- a/test/spv/store.ir +++ b/test/spv/store.ir @@ -30,8 +30,8 @@ ; CHECK: %[[#I32_C1:]] = OpConstant %[[#I32]] 1 func @si8(%0: memref, %1: memref) { - %2 = constant 0 -> index - %3 = constant -42 -> i8 + %2 = constant 0 : index + %3 = constant -42 : i8 store %3, %0[%2,%2] : memref store.atomic %3, %1[] : memref store.atomic_add %3, %1[] : memref @@ -49,7 +49,7 @@ func @si8(%0: memref, %1: memref) { } func @sf32(%0: memref) { - %1 = constant 42.0 -> f32 + %1 = constant 42.0 : f32 store.atomic %1, %0[] : memref store.atomic_add %1, %0[] : memref ; CHECK: %[[#]] = OpFunction {{.*}} @@ -59,7 +59,7 @@ func @sf32(%0: memref) { } func @sc64(%0: memref) { - %1 = constant [42.0, 1.0] -> c64 + %1 = constant [42.0, 1.0] : c64 store.atomic %1, %0[] : memref store.atomic_add %1, %0[] : memref ; CHECK: %[[#]] = OpFunction {{.*}} diff --git a/test/spv/work_group.ir b/test/spv/work_group.ir index c3748978..7f9b840b 100644 --- a/test/spv/work_group.ir +++ b/test/spv/work_group.ir @@ -11,9 +11,9 @@ ; CHECK: %[[#SCOPE:]] = OpConstant %[[#]] 2 func @twg() { - %0 = constant 1 -> i16 - %1 = constant 1.0 -> f32 - %2 = constant [1.0, 0.0] -> c64 + %0 = constant 1 : i16 + %1 = constant 1.0 : f32 + %2 = constant [1.0, 0.0] : c64 %3 = work_group.reduce_add %0 : i16 %4 = work_group.reduce_add %1 : f32 %5 = work_group.reduce_add %2 : c64 From 828f63508dbecb238c7cb0b85b811f980f499fda Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 17:22:41 +0100 Subject: [PATCH 121/297] Update expand and fuse Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 2 + include/tinytc/tinytc.h | 12 ++- include/tinytc/tinytc.hpp | 11 ++- include/tinytc/types.h | 36 +++++---- include/tinytc/types.hpp | 2 + src/error.cpp | 6 +- src/inst.cpp | 12 +-- src/node/inst_node.cpp | 131 ++++++++++++++++++++++--------- src/node/inst_node.hpp | 6 +- src/parser/parser_impl.yy | 18 +---- src/pass/clone.cpp | 9 ++- src/pass/dump_ir.cpp | 4 +- test/codegen/expand.ir | 26 +++--- test/codegen/fuse.ir | 8 +- test/opt/check-ir/expand.ir | 57 ++++---------- test/opt/insert-lifetime-stop.ir | 2 +- test/spv/expand.ir | 2 +- test/spv/fuse.ir | 8 +- 18 files changed, 198 insertions(+), 154 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 2d2fb113..ec54e716 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -1110,6 +1110,7 @@ Restrictions The memref type of the result must conform with the following rules: +#. Element type and address space must match the operand's memref type. #. **Shape:** The mode size is replaced with the expand shape. The product of the expand shape must equal the size of the expanded mode. @@ -1256,6 +1257,7 @@ Restrictions The memref type of the result must conform with the following rules: +#. Element type and address space must match the operand's memref type. #. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. .. code:: diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 0fb05610..606e8a9d 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -493,14 +493,16 @@ TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tin * @param expand_shape_size [in][optional] dimension of expand shape; must match number of entries * equal to TINYTC_DYNAMIC in static_expand_shape array; can be 0 * @param expand_shape [in][optional][range(0, expand_shape_size)] expand shape array + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create( - tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, - uint32_t static_expand_shape_size, const int64_t *static_expand_shape, - uint32_t expand_shape_size, const tinytc_value_t *expand_shape, const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t +tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, + uint32_t static_expand_shape_size, const int64_t *static_expand_shape, + uint32_t expand_shape_size, const tinytc_value_t *expand_shape, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create fuse instruction @@ -511,12 +513,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create( * @param a [in] operand * @param from [in] first mode to fuse * @param to [in] last mode to fuse + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, int64_t to, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 9dafba64..0304c54d 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1168,13 +1168,14 @@ inline inst make_axpby(transpose tA, bool atomic, value alpha, value A, value be * @param expanded_mode Expanded mode * @param static_expand_shape Static expand shape * @param expand_shape Dynamic expand shape + * @param ty Result type * @param loc Source code location * * @return Instruction */ inline inst make_expand(value a, std::int64_t expanded_mode, array_view static_expand_shape, - array_view expand_shape, location const &loc = {}) { + array_view expand_shape, data_type ty, location const &loc = {}) { tinytc_inst_t instr; auto static_len = static_expand_shape.size(); if (static_len > std::numeric_limits::max()) { @@ -1186,7 +1187,7 @@ inline inst make_expand(value a, std::int64_t expanded_mode, } const tinytc_value_t *es = reinterpret_cast(expand_shape.data()); CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a, expanded_mode, static_len, - static_expand_shape.data(), len, es, &loc), + static_expand_shape.data(), len, es, ty, &loc), loc); return inst(instr); } @@ -1197,13 +1198,15 @@ inline inst make_expand(value a, std::int64_t expanded_mode, * @param a Operand * @param from First mode to fuse * @param to Last mode to fuse + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_fuse(value a, std::int64_t from, std::int64_t to, location const &loc = {}) { +inline inst make_fuse(value a, std::int64_t from, std::int64_t to, data_type ty, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a, from, to, &loc), loc); + CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a, from, to, ty, &loc), loc); return inst(instr); } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 686671ff..6c8379a5 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -77,26 +77,28 @@ typedef enum { tinytc_status_ir_spmd_called_from_collective = 0x11d, ///< SPMD instruction from collective tinytc_status_ir_expected_local_address_space = 0x11e, ///< Expected local address space tinytc_status_ir_expected_global_address_space = 0x11f, ///< Expected global address space - tinytc_status_ir_invalid_offset = 0x120, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x121, ///< Instruction does not support int type - tinytc_status_ir_boolean_unsupported = 0x122, ///< Instruction does not support boolean type - tinytc_status_ir_complex_unsupported = 0x123, ///< Instruction does not support complex type + tinytc_status_ir_address_space_mismatch = 0x120, ///< Address space must match + tinytc_status_ir_invalid_offset = 0x121, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x122, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x123, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x124, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x124, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x125, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x126, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x127, ///< Mismatch of init values and returned values - tinytc_status_ir_invalid_matrix_use = 0x128, ///< Invalid matrix use - tinytc_status_ir_unsupported_coopmatrix_shape = 0x129, ///< Unsupported coopmatrix shape - tinytc_status_ir_incompatible_scalar_types = 0x12a, ///< Incompatible scalar types - tinytc_status_ir_constant_mismatch = 0x12b, ///< Constant mismatch - tinytc_status_ir_insufficient_alignment = 0x12c, ///< Insufficient alignment - tinytc_status_ir_must_have_yield = 0x12d, ///< Must have yield instruction + 0x125, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x126, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x127, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x128, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x129, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x12a, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x12b, ///< Incompatible scalar types + tinytc_status_ir_constant_mismatch = 0x12c, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x12d, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x12e, ///< Must have yield instruction tinytc_status_ir_yield_in_else_branch_missing = - 0x12e, ///< Must have yield instruction in else branch - tinytc_status_ir_from_to_mismatch = 0x12f, ///< size(from) != size(to) in foreach + 0x12f, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x130, ///< size(from) != size(to) in foreach tinytc_status_ir_operand_type_must_match_return_type = - 0x130, /// Operand type must match return type + 0x131, /// Operand type must match return type + tinytc_status_ir_invalid_stride = 0x132, ///< Invalid stride // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index f7c53aa0..f63c9902 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -86,6 +86,7 @@ enum class status { ir_spmd_called_from_collective = tinytc_status_ir_spmd_called_from_collective, ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, + ir_address_space_mismatch = tinytc_status_ir_address_space_mismatch, ir_invalid_offset = tinytc_status_ir_invalid_offset, ir_int_unsupported = tinytc_status_ir_int_unsupported, ir_boolean_unsupported = tinytc_status_ir_boolean_unsupported, @@ -103,6 +104,7 @@ enum class status { ir_yield_in_else_branch_missing = tinytc_status_ir_yield_in_else_branch_missing, ir_from_to_mismatch = tinytc_status_ir_from_to_mismatch, ir_operand_type_must_match_return_type = tinytc_status_ir_operand_type_must_match_return_type, + ir_invalid_stride = tinytc_status_ir_invalid_stride, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/error.cpp b/src/error.cpp index c11c6412..5368474c 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -137,7 +137,7 @@ char const *tinytc_error_string(tinytc_status_t status) { case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; case tinytc_status_ir_invalid_shape: - return "Mode size must be non-negative"; + return "Invalid shape"; case tinytc_status_ir_incompatible_shapes: return "Incompatible tensor shapes"; case tinytc_status_ir_shape_stride_mismatch: @@ -200,6 +200,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "A memref with local address space is expected"; case tinytc_status_ir_expected_global_address_space: return "A memref with global address space is expected"; + case tinytc_status_ir_address_space_mismatch: + return "Address space must match"; case tinytc_status_ir_invalid_offset: return "Offset must be non-negative or dynamic"; case tinytc_status_ir_int_unsupported: @@ -235,6 +237,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "length(from) must equal length(to) and length must be greater than 0"; case tinytc_status_ir_operand_type_must_match_return_type: return "Type of operand must match return type"; + case tinytc_status_ir_invalid_stride: + return "Invalid stride"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/inst.cpp b/src/inst.cpp index 4052f608..dbbdbc8e 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -407,7 +407,7 @@ tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a int64_t expanded_mode, uint32_t static_expand_shape_size, const int64_t *static_expand_shape, uint32_t expand_shape_size, - const tinytc_value_t *expand_shape, + const tinytc_value_t *expand_shape, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || static_expand_shape == nullptr || (expand_shape_size > 0 && expand_shape == nullptr)) { @@ -416,18 +416,20 @@ tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a return exception_to_status_code([&] { *instr = std::make_unique( a, expanded_mode, array_view{static_expand_shape, static_expand_shape_size}, - array_view{expand_shape, expand_shape_size}, get_optional(loc)) + array_view{expand_shape, expand_shape_size}, ty, get_optional(loc)) .release(); }); } tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, - int64_t to, const tinytc_location_t *loc) { + int64_t to, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *instr = std::make_unique(a, from, to, get_optional(loc)).release(); }); + return exception_to_status_code([&] { + *instr = std::make_unique(a, from, to, ty, get_optional(loc)).release(); + }); } tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 91502b42..e6c90138 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -108,6 +108,31 @@ void check_index_ty(location const &loc, tinytc_value const &v) { } } +void check_memref_shape(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, + std::int64_t oi, location const &loc) { + if (rt->shape(ri) != ot->shape(oi)) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match operand mode " << oi << " [" + << rt->shape(oi) << "!=" << ot->shape(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_shape, std::move(extra_info).str()); + } +} +void check_memref_stride(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, + std::int64_t oi, location const &loc) { + if (!is_dynamic_value(rt->stride(ri)) && rt->stride(ri) != ot->stride(oi)) { + auto extra_info = std::ostringstream{} << "Stride of mode " << ri + << " does not match operand stride " << oi << " [" + << rt->stride(oi) << "!=" << ot->stride(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_stride, std::move(extra_info).str()); + } +} + +void check_memref_mode(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, std::int64_t oi, + location const &loc) { + check_memref_shape(rt, ri, ot, oi, loc); + check_memref_stride(rt, ri, ot, oi, loc); +} + blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, bool atomic, location const &lc) : standard_inst{tid}, atomic_(atomic) { @@ -584,7 +609,8 @@ cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, array_view static_expand_shape0, - array_view expand_shape0, location const &lc) + array_view expand_shape0, tinytc_data_type_t ty, + location const &lc) : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { op(0, op0); @@ -593,10 +619,22 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, } loc(lc); + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + auto m = get_memref_type(loc(), operand()); + if (rt->element_data_ty() != m->element_data_ty()) { + throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); + } + if (rt->addrspace() != m->addrspace()) { + throw compilation_error(loc(), {&operand()}, status::ir_address_space_mismatch); + } + bool const range_ok = 0 <= expanded_mode_ && expanded_mode_ < m->dim(); if (!range_ok) { - throw compilation_error(loc(), status::ir_out_of_bounds); + throw compilation_error(loc(), {&operand()}, status::ir_out_of_bounds); } if (static_expand_shape_.size() < 2) { @@ -607,30 +645,32 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, throw compilation_error(loc(), status::ir_expand_shape_mismatch); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim() + static_expand_shape_.size() - 1); - stride.reserve(m->dim() + static_expand_shape_.size() - 1); for (std::int64_t i = 0; i < expanded_mode_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); - } - - stride.push_back(m->stride(expanded_mode_)); - shape.push_back(static_expand_shape_[0]); - for (std::size_t j = 1; j < static_expand_shape_.size(); ++j) { - stride.push_back(is_dynamic_value(stride.back()) || is_dynamic_value(shape.back()) - ? dynamic - : stride.back() * shape.back()); - shape.push_back(static_expand_shape_[j]); + check_memref_mode(rt, i, m, i, loc()); + } + auto stride = m->stride(expanded_mode_); + for (std::size_t i = 0; i < static_expand_shape_.size(); ++i) { + const auto mode = expanded_mode_ + i; + if (rt->shape(mode) != static_expand_shape()[i]) { + auto extra_info = std::ostringstream{} + << "Size of mode " << mode << " does not match static expand shape (" + << rt->shape(mode) << "!=" << static_expand_shape()[i] << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); + } + if (!is_dynamic_value(rt->stride(mode)) && rt->stride(mode) != stride) { + auto extra_info = std::ostringstream{} << "Stride of mode " << mode << " is invalid (" + << rt->stride(mode) << "!=" << stride << ")"; + throw compilation_error(loc(), status::ir_invalid_stride, std::move(extra_info).str()); + } + stride = is_dynamic_value(stride) || is_dynamic_value(rt->shape(mode)) + ? dynamic + : stride * rt->shape(mode); } for (std::int64_t i = expanded_mode_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + check_memref_mode(rt, i + static_expand_shape_.size() - 1, m, i, loc()); } - auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, this, lc}; + result(0) = value_node{ty, this, lc}; } for_inst::for_inst(tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, @@ -711,40 +751,55 @@ foreach_inst::foreach_inst(array_view from, array_view(ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + auto m = get_memref_type(loc(), operand()); + if (rt->element_data_ty() != m->element_data_ty()) { + throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); + } + if (rt->addrspace() != m->addrspace()) { + throw compilation_error(loc(), {&operand()}, status::ir_address_space_mismatch); + } + bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - std::int64_t i = 0; - for (; i < from_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + + for (std::int64_t i = 0; i < from_; ++i) { + check_memref_mode(rt, i, m, i, loc()); } + std::int64_t prod = 1; - for (; i <= to_; ++i) { + for (std::int64_t i = from_; i <= to_; ++i) { if (is_dynamic_value(m->shape(i))) { prod = dynamic; break; } prod *= m->shape(i); } - shape.push_back(prod); - stride.push_back(m->stride(from_)); - for (i = to_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + if (rt->shape(from_) != prod) { + auto extra_info = std::ostringstream{} << "Size of mode " << from_ + << " does not match shape product (" + << rt->shape(from_) << "!=" << prod << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); } - auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, this, lc}; + check_memref_stride(rt, from_, m, from_, loc()); + + for (std::int64_t i = to_ + 1; i < m->dim(); ++i) { + check_memref_mode(rt, i - to_ + from_, m, i, loc()); + } + + result(0) = value_node{ty, this, lc}; } load_inst::load_inst(tinytc_value_t op0, array_view index_list0, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 11d029f1..c8ea0257 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -542,7 +542,8 @@ class expand_inst : public standard_inst { inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } expand_inst(tinytc_value_t op, std::int64_t expanded_mode, array_view static_expand_shape, - array_view expand_shape, location const &lc = {}); + array_view expand_shape, tinytc_data_type_t ty, + location const &lc = {}); inline std::int64_t expanded_mode() const { return expanded_mode_; } inline auto static_expand_shape() const -> array_view { @@ -563,7 +564,8 @@ class expand_inst : public standard_inst { class fuse_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } - fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, location const &lc = {}); + fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, tinytc_data_type_t ty, + location const &lc = {}); inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index ed2de8f6..890da4a7 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -959,12 +959,7 @@ cooperative_matrix_store_inst: ; expand_inst: - EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] ARROW expand_shape RSQBR COLON memref_type { - if ($var->ty() != $memref_type) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] ARROW expand_shape RSQBR COLON memref_type[ty] { try { auto static_shape = std::vector{}; static_shape.reserve($expand_shape.size()); @@ -981,7 +976,7 @@ expand_inst: } $$ = inst { std::make_unique(std::move($var), $expanded_mode, std::move(static_shape), - std::move(dynamic_shape), @expand_inst) + std::move(dynamic_shape), $ty, @expand_inst) .release() }; } catch (compilation_error const &e) { @@ -1011,15 +1006,10 @@ integer_constant_or_identifier: ; fuse_inst: - FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type { - if ($var->ty() != $memref_type) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type[ty] { try { $$ = inst { - std::make_unique(std::move($var), $from, $to, @fuse_inst).release() + std::make_unique(std::move($var), $from, $to, $ty, @fuse_inst).release() }; } catch (compilation_error const &e) { report_error(ctx.cctx(), e); diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index 1cfdafef..69a5cbbb 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -64,12 +64,13 @@ auto inst_cloner::operator()(cooperative_matrix_store_inst &in) -> std::unique_p subs(&in.pos1()), in.loc()); } auto inst_cloner::operator()(expand_inst &in) -> std::unique_ptr { - return std::make_unique(subs(&in.operand()), in.expanded_mode(), - in.static_expand_shape(), - subs_value_range(in.expand_shape()), in.loc()); + return std::make_unique( + subs(&in.operand()), in.expanded_mode(), in.static_expand_shape(), + subs_value_range(in.expand_shape()), in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(fuse_inst &in) -> std::unique_ptr { - return std::make_unique(subs(&in.operand()), in.from(), in.to(), in.loc()); + return std::make_unique(subs(&in.operand()), in.from(), in.to(), in.result(0).ty(), + in.loc()); } auto inst_cloner::operator()(lifetime_stop_inst &in) -> std::unique_ptr { diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 8c9cdd6f..caa7d2c5 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -260,7 +260,7 @@ void dump_ir_pass::operator()(expand_inst const &e) { } } *os_ << "] : "; - visit(*this, *e.operand().ty()); + visit(*this, *e.result(0).ty()); } void dump_ir_pass::operator()(fuse_inst const &f) { @@ -269,7 +269,7 @@ void dump_ir_pass::operator()(fuse_inst const &f) { dump_val(f.operand()); *os_ << "[" << f.from() << "," << f.to() << "]"; *os_ << " : "; - visit(*this, *f.operand().ty()); + visit(*this, *f.result(0).ty()); } void dump_ir_pass::operator()(load_inst const &e) { diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir index c3722880..1e5a9565 100644 --- a/test/codegen/expand.ir +++ b/test/codegen/expand.ir @@ -4,7 +4,7 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { %z = constant 0 : index - %1 = expand %0[1->2x8] : memref + %1 = expand %0[1->2x8] : memref %2 = load %1[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t1( ; CHECK: global float* x1 = x; @@ -12,7 +12,7 @@ func @t1(%0: memref) { } func @t2(%0: memref) { %z = constant 0 : index - %1 = expand %0[1->2x2x2x2] : memref + %1 = expand %0[1->2x2x2x2] : memref %2 = load %1[%z,%z,%z,%z,%z,%z] : memref ; CHECK-LABEL: void t2( ; CHECK: global float* x1 = x; @@ -20,7 +20,7 @@ func @t2(%0: memref) { } func @t3(%0: memref, %1: index) { %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref + %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t3( ; CHECK: global float* x2 = x; @@ -30,7 +30,7 @@ func @t3(%0: memref, %1: index) { } func @t4(%0: memref, %1: index) { %z = constant 0 : index - %2 = expand %0[1->2 x %1] : memref + %2 = expand %0[1->2 x %1] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t4( ; CHECK: global float* x2 = x; @@ -39,7 +39,7 @@ func @t4(%0: memref, %1: index) { } func @t5(%0: memref, %1: index) { %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref + %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t5( ; CHECK: global float* x2 = x; @@ -49,7 +49,7 @@ func @t5(%0: memref, %1: index) { } func @t6(%0: memref, %1: index) { %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref + %2 = expand %0[1->%1 x 2] : memref %3 = load %2[%z,%z,%z] : memref ; CHECK-LABEL: void t6( ; CHECK: global float* x2 = x; @@ -59,7 +59,7 @@ func @t6(%0: memref, %1: index) { } func @t7(%0: memref, %1: index, %2: index) { %z = constant 0 : index - %3 = expand %0[1->%1 x %2 x 2] : memref + %3 = expand %0[1->%1 x %2 x 2] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t7( ; CHECK: global float* x3 = x; @@ -71,7 +71,7 @@ func @t7(%0: memref, %1: index, %2: index) { } func @t8(%0: memref, %1: index, %2: index) { %z = constant 0 : index - %3 = expand %0[1->%2 x 2 x %1] : memref + %3 = expand %0[1->%2 x 2 x %1] : memref %4 = load %3[%z,%z,%z,%z] : memref ; CHECK-LABEL: void t8( ; CHECK: global float* x3 = x; @@ -83,7 +83,7 @@ func @t8(%0: memref, %1: index, %2: index) { } func @t9(%0: memref, %1: index, %2: index) { %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref + %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t9( ; CHECK: global float* x3 = x; @@ -94,7 +94,7 @@ func @t9(%0: memref, %1: index, %2: index) { } func @t10(%0: memref, %1: index, %2: index) { %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref + %3 = expand %0[1->%1 x %2] : memref %4 = load %3[%z,%z,%z] : memref ; CHECK-LABEL: void t10( ; CHECK: global float* x3 = x; @@ -105,7 +105,7 @@ func @t10(%0: memref, %1: index, %2: index) { } func @t11(%0: memref>) { %z = constant 0 : index - %1 = expand %0[0->4 x 8] : memref> + %1 = expand %0[0->4 x 8] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK-LABEL: void t11( ; CHECK: global float* x1 = x; @@ -113,7 +113,7 @@ func @t11(%0: memref>) { } func @t12(%0: memref>, %1: index) { %z = constant 0 : index - %2 = expand %0[0->%1 x 4] : memref> + %2 = expand %0[0->%1 x 4] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t12( ; CHECK: global float* x2 = x; @@ -124,7 +124,7 @@ func @t12(%0: memref>, %1: index) { } func @t13(%0: memref>, %1: index) { %z = constant 0 : index - %2 = expand %0[0->4 x %1] : memref> + %2 = expand %0[0->4 x %1] : memref> %3 = load %2[%z,%z,%z] : memref> ; CHECK-LABEL: void t13( ; CHECK: global float* x2 = x; diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir index ba6c7c5f..cf080a04 100644 --- a/test/codegen/fuse.ir +++ b/test/codegen/fuse.ir @@ -4,13 +4,13 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { %z = constant 0 : index - %1 = fuse %0[1,3] : memref + %1 = fuse %0[1,3] : memref %2 = load %1[%z,%z,%z] : memref ; CHECK: float x2 = *(x1 + z * 1 + z * 32 + z * 16384); } func @t2(%0: memref) { %z = constant 0 : index - %1 = fuse %0[1,3] : memref + %1 = fuse %0[1,3] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK: long x_shape1 = 16 * x_shape2 * 4; ; CHECK-NEXT: long x_stride2 = x_stride4; @@ -18,13 +18,13 @@ func @t2(%0: memref) { } func @t3(%0: memref>) { %z = constant 0 : index - %1 = fuse %0[1,2] : memref> + %1 = fuse %0[1,2] : memref> %2 = load %1[%z,%z,%z] : memref> ; CHECK: float x2 = *(x1 + z * 1 + z * 48 + z * 1536); } func @t4(%0: memref>) { %z = constant 0 : index - %1 = fuse %0[0,1] : memref> + %1 = fuse %0[0,1] : memref> %2 = load %1[%z,%z] : memref> ; CHECK: long x_shape0 = 8 * x_shape1; ; CHECK-NEXT: long x_stride11 = x_stride2; diff --git a/test/opt/check-ir/expand.ir b/test/opt/check-ir/expand.ir index 7bfc61a8..34180db0 100644 --- a/test/opt/check-ir/expand.ir +++ b/test/opt/check-ir/expand.ir @@ -8,67 +8,44 @@ ; CHECK: func @t1({{.*}} func @t1(%0: memref) { - %z = constant 0 : index - %1 = expand %0[1->2x8] : memref - %2 = load %1[%z,%z,%z,%z] : memref + %1 = expand %0[1->2x8] : memref } func @t2(%0: memref) { - %z = constant 0 : index - %1 = expand %0[1->2x2x2x2] : memref - %2 = load %1[%z,%z,%z,%z,%z,%z] : memref + %1 = expand %0[1->2x2x2x2] : memref } func @t3(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %2 = expand %0[1->%1 x 2] : memref } func @t4(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->2 x %1] : memref - %3 = load %2[%z,%z,%z] : memref + %2 = expand %0[1->2 x %1] : memref } func @t5(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %2 = expand %0[1->%1 x 2] : memref } func @t6(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %2 = expand %0[1->%1 x 2] : memref } func @t7(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2 x 2] : memref - %4 = load %3[%z,%z,%z,%z] : memref + %3 = expand %0[1->%1 x %2 x 2] : memref } func @t8(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%2 x 2 x %1] : memref - %4 = load %3[%z,%z,%z,%z] : memref + %3 = expand %0[1->%2 x 2 x %1] : memref } func @t9(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : memref + %3 = expand %0[1->%1 x %2] : memref } func @t10(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : memref + %3 = expand %0[1->%1 x %2] : memref } func @t11(%0: memref>) { - %z = constant 0 : index - %1 = expand %0[0->4 x 8] : memref> - %2 = load %1[%z,%z,%z] : memref> + %1 = expand %0[0->4 x 8] : memref> } -func @t12(%0: memref>, %1: index) { - %z = constant 0 : index - %2 = expand %0[0->%1 x 4] : memref> - %3 = load %2[%z,%z,%z] : memref> +func @t12(%0: memref>) { + %1 = expand %0[0->4 x 8] : memref> } func @t13(%0: memref>, %1: index) { - %z = constant 0 : index - %2 = expand %0[0->4 x %1] : memref> - %3 = load %2[%z,%z,%z] : memref> + %2 = expand %0[0->%1 x 4] : memref> +} +func @t14(%0: memref>, %1: index) { + %2 = expand %0[0->4 x %1] : memref> } diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 7087d79d..e919e34a 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -35,7 +35,7 @@ func @use2(%A: memref, %C: memref) { func @use_alias(%a: f32, %A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} %B = alloca : memref - %0 = fuse %B[1,3] : memref + %0 = fuse %B[1,3] : memref %1 = subview %0[0:8,0:8] : memref gemm.n.n %a, %A, %1, %a, %C ; CHECK: gemm.n.n{{.*}} diff --git a/test/spv/expand.ir b/test/spv/expand.ir index d1705a71..634945b6 100644 --- a/test/spv/expand.ir +++ b/test/spv/expand.ir @@ -13,7 +13,7 @@ func @f1(%0: memref, %1: index) { %c0 = constant 0 : index - %2 = expand %0[1->4x%1x5] : memref + %2 = expand %0[1->4x%1x5] : memref %3 = size %2[0] : memref %4 = size %2[1] : memref %5 = size %2[2] : memref diff --git a/test/spv/fuse.ir b/test/spv/fuse.ir index 663bb6dc..75c8bf8c 100644 --- a/test/spv/fuse.ir +++ b/test/spv/fuse.ir @@ -14,10 +14,10 @@ func @f1(%0: memref) { %z = constant 0 : index - %1 = fuse %0[1,3] : memref - %2 = size %1[0] : memref> - %3 = size %1[1] : memref> - %4 = size %1[2] : memref> + %1 = fuse %0[1,3] : memref + %2 = size %1[0] : memref + %3 = size %1[1] : memref + %4 = size %1[2] : memref %5 = arith.not %2 : index %6 = arith.not %3 : index %7 = arith.not %4 : index From e9fe5a0d34203c67593753bfe5d87555e94f8b6b Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 17:31:04 +0100 Subject: [PATCH 122/297] Update for/foreach Signed-off-by: Carsten Uphoff --- src/parser/parser_impl.yy | 6 +++--- src/pass/dump_ir.cpp | 10 +++++----- test/codegen/for.ir | 4 ++-- test/opt/check-ir/nesting1.ir | 2 +- test/opt/check-ir/nesting3.ir | 2 +- test/opt/constant-propagation.ir | 4 ++-- test/opt/dead-code-elimination.ir | 2 +- test/opt/dump-def-use.ir | 8 ++++---- test/opt/insert-barrier.ir | 8 ++++---- test/opt/insert-lifetime-stop.ir | 4 ++-- test/spv/for.ir | 6 +++--- 11 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 890da4a7..9fa634de 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -587,7 +587,7 @@ ger_inst: ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] for_loop_var_type { + FOR LOCAL_IDENTIFIER[loop_var] for_loop_var_type EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] { check_type($from, $for_loop_var_type, @from, @for_loop_var_type); check_type($to, $for_loop_var_type, @to, @for_loop_var_type); if ($optional_step) { @@ -654,8 +654,8 @@ init_value: ; foreach_inst: - FOREACH LPAREN identifier_list[loop_var] RPAREN EQUALS - LPAREN value_list[from] RPAREN COMMA LPAREN value_list[to] RPAREN for_loop_var_type { + FOREACH LPAREN identifier_list[loop_var] RPAREN for_loop_var_type EQUALS + LPAREN value_list[from] RPAREN COMMA LPAREN value_list[to] RPAREN { try { location loc = @FOREACH; loc.end = @for_loop_var_type.end; diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index caa7d2c5..25c6eea6 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -323,6 +323,8 @@ void dump_ir_pass::operator()(for_inst const &in) { } *os_ << "for "; dump_val(in.loop_var()); + *os_ << ":"; + visit(*this, *in.loop_var().ty()); *os_ << "="; dump_val(in.from()); *os_ << ","; @@ -346,8 +348,6 @@ void dump_ir_pass::operator()(for_inst const &in) { [this](auto const &i) { visit(*this, *i.ty()); }); *os_ << ")"; } - *os_ << " : "; - visit(*this, *in.loop_var().ty()); *os_ << " "; dump_region(in.body()); } @@ -356,12 +356,12 @@ void dump_ir_pass::operator()(foreach_inst const &in) { *os_ << "foreach ("; do_with_infix(in.loop_vars().begin(), in.loop_vars().end(), [this](auto const &i) { dump_val(i); }); - *os_ << ")=("; + *os_ << "):"; + visit(*this, *in.loop_vars().begin()->ty()); + *os_ << "=("; do_with_infix(in.from().begin(), in.from().end(), [this](auto const &i) { dump_val(i); }); *os_ << "),("; do_with_infix(in.to().begin(), in.to().end(), [this](auto const &i) { dump_val(i); }); - *os_ << ") : "; - visit(*this, *in.loop_vars().begin()->ty()); *os_ << " "; dump_region(in.body()); } diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 5b5ab374..b544246a 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -9,7 +9,7 @@ func @for1() { } %lb1 = constant -2 : i16 %ub1 = constant 2 : i16 - for %1 = %lb1,%ub1 : i16 { + for %1:i16 = %lb1,%ub1 { } ; CHECK-LABEL: void for1({{.*}} ; CHECK: for (long x = lb0; x < ub0; ++x) @@ -21,7 +21,7 @@ func @for2(%fib: memref) { %to = constant 6 : i32 %f0 = constant 0 : i64 %f1 = constant 1 : i64 - %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 } diff --git a/test/opt/check-ir/nesting1.ir b/test/opt/check-ir/nesting1.ir index 45d4bc11..db33d0b0 100644 --- a/test/opt/check-ir/nesting1.ir +++ b/test/opt/check-ir/nesting1.ir @@ -8,6 +8,6 @@ func @illegal_nesting() { foreach (%i)=(%lb),(%ub) { foreach (%j)=(%lb),(%ub) { } -; CHECK: 9.9-32: Collective instruction must not be called from SPMD region +; CHECK: 9.9-20: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/check-ir/nesting3.ir b/test/opt/check-ir/nesting3.ir index 52bbfe32..8e98ec71 100644 --- a/test/opt/check-ir/nesting3.ir +++ b/test/opt/check-ir/nesting3.ir @@ -8,6 +8,6 @@ func @illegal_nesting() { parallel { foreach (%j)=(%lb),(%ub) { } -; CHECK: 9.9-32: Collective instruction must not be called from SPMD region +; CHECK: 9.9-20: Collective instruction must not be called from SPMD region } } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index daf0ff66..9a487878 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -33,7 +33,7 @@ func @known_loop_bounds() { ; CHECK-NEXT: %tmp = arith.sub %size, %lb : index ; CHECK-NEXT: %1 = constant 36 : index ; CHECK-NEXT: %ub = arith.sub %0, %one : index -; CHECK-NEXT: for %i=%lb,%1 : index { +; CHECK-NEXT: for %i:index=%lb,%1 { } func @known_loop_iter_args() { @@ -48,7 +48,7 @@ func @known_loop_iter_args() { ; CHECK-NEXT: %c5 = constant 5 : index ; CHECK-NEXT: %0 = constant 6 : index ; CHECK-NEXT: %1 = arith.add %c1, %c5 : index -; CHECK-NEXT: %3 = for %i=%c1,%c5 init(%2=%0) -> (index) : index { +; CHECK-NEXT: %3 = for %i:index=%c1,%c5 init(%2=%0) -> (index) { ; CHECK-NEXT: yield %2 : index ; CHECK-NEXT: } } diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index 832c95b3..f0097471 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -58,7 +58,7 @@ func @dead_loop(%a: memref) { ; CHECK-LABEL: func @dead_loop({{.*}} ; CHECK-NEXT: %c5 = constant 5 : index ; CHECK-NEXT: %c6 = constant 6 : index -; CHECK-NEXT: for %0=%c5,%c6 : index { +; CHECK-NEXT: for %0:index=%c5,%c6 { ; CHECK-NEXT: %c43{{.*}} ; CHECK-NEXT: store{{.*}} ; CHECK-NEXT: } diff --git a/test/opt/dump-def-use.ir b/test/opt/dump-def-use.ir index 14043c90..583fddd3 100644 --- a/test/opt/dump-def-use.ir +++ b/test/opt/dump-def-use.ir @@ -7,7 +7,7 @@ func @foobar() { %one = constant 1 : index %lb = constant 0 : index %ub = constant 5 : index - for %i=%lb,%ub : index { + for %i:index=%lb,%ub { %0 = arith.add %i, %one : index %1 = arith.rem %0, %one : index } @@ -18,11 +18,11 @@ func @foobar() { ; CHECK-NEXT: > %0 = arith.add %i, %one : index ; CHECK-NEXT: > %lb = constant 0 : index ; CHECK-NEXT: def %lb -; CHECK-NEXT: > for %i=%lb,%ub : index {...} +; CHECK-NEXT: > for %i:index=%lb,%ub {...} ; CHECK-NEXT: > %ub = constant 5 : index ; CHECK-NEXT: def %ub -; CHECK-NEXT: > for %i=%lb,%ub : index {...} -; CHECK-NEXT: > for %i=%lb,%ub : index {...} +; CHECK-NEXT: > for %i:index=%lb,%ub {...} +; CHECK-NEXT: > for %i:index=%lb,%ub {...} ; CHECK-NEXT: def %i ; CHECK-NEXT: > %0 = arith.add %i, %one : index ; CHECK-NEXT: > %0 = arith.add %i, %one : index diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 44537e6e..3c28c0a4 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -131,9 +131,9 @@ func @region1() { %0 = alloca : memref %lb = constant 0 : index %ub = constant 4 : index - for %i=%lb,%ub : index { + for %i:index=%lb,%ub { %1 = alloca : memref - for %k=%lb,%ub : index { + for %k:index=%lb,%ub { %2 = alloca : memref gemm.n.n %one, %0, %1, %zero, %2 axpby.n %one, %1, %zero, %0 @@ -141,9 +141,9 @@ func @region1() { axpby.n %one, %0, %zero, %1 } ; CHECK-LABEL: func @region1({{.*}} -; CHECK: for %i=%lb,%ub : index { +; CHECK: for %i:index=%lb,%ub { ; CHECK-NEXT: %1 = alloca : memref -; CHECK-NEXT: for %k=%lb,%ub : index { +; CHECK-NEXT: for %k:index=%lb,%ub { ; CHECK-NEXT: %2 = alloca : memref ; CHECK-NEXT: barrier.local ; CHECK-NEXT: gemm.n.n %one, %0, %1, %zero, %2 diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index e919e34a..3b318188 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -48,9 +48,9 @@ func @region1() { %0 = alloca : memref %lb = constant 0 : index %ub = constant 4 : index - for %i=%lb,%ub : index { + for %i:index=%lb,%ub { %1 = alloca : memref - for %k=%lb,%ub : index { + for %k:index=%lb,%ub { %2 = alloca : memref gemm.n.n %one, %0, %1, %one, %2 axpby.n %one, %0, %one, %1 diff --git a/test/spv/for.ir b/test/spv/for.ir index 11bec7c7..d7ce4aa2 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -21,7 +21,7 @@ func @for1() { %lb = constant 0 : i16 %ub = constant 10 : i16 %step = constant 2 : i16 - for %0 = %lb,%ub,%step : i16 { + for %0:i16 = %lb,%ub,%step { } ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK: OpLoopMerge %[[#MERGE_LABEL1:]] %[[#CONT_LABEL1:]] None @@ -44,7 +44,7 @@ func @for2() { %to = constant 6 : i32 %f0 = constant 0 : i64 %f1 = constant 1 : i64 - %fn_1, %fn = for %n=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) : i32 { + %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 } @@ -75,7 +75,7 @@ func @for3() subgroup_size(16) { %from = constant 2 : i16 %to = constant 6 : i16 %m_init = constant 1 : coopmatrix - %m = for %n=%from,%to init(%m_iter=%m_init) -> (coopmatrix) : i16 { + %m = for %n:i16=%from,%to init(%m_iter=%m_init) -> (coopmatrix) { %m_update = arith.add %m_iter, %m_init : coopmatrix yield %m_update : coopmatrix } From 84b5442c98269aa145561e0fdc88f17dad64bf90 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 18:18:42 +0100 Subject: [PATCH 123/297] Update load, size, and store Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 44 ++++++++--------- include/tinytc/tinytc.hpp | 11 +++-- src/codegen_tools.cpp | 6 ++- src/inst.cpp | 8 ++-- src/node/inst_node.cpp | 34 +++++++++---- src/node/inst_node.hpp | 5 +- src/parser/parser_impl.yy | 33 +++---------- src/pass/clone.cpp | 4 +- src/pass/dump_ir.cpp | 4 +- src/pass/lower_linalg.cpp | 76 ++++++++++++++++++------------ src/recipe/tall_and_skinny.cpp | 2 +- test/codegen/atomic.ir | 8 ++-- test/codegen/dope_vector_group0.ir | 4 +- test/codegen/expand.ir | 26 +++++----- test/codegen/for.ir | 2 +- test/codegen/fuse.ir | 8 ++-- test/codegen/load.ir | 8 ++-- test/codegen/size.ir | 8 ++-- test/codegen/store.ir | 4 +- test/codegen/type_mismatch0.ir | 4 +- test/codegen/type_mismatch1.ir | 6 +-- test/opt/check-ir/subview.ir | 19 +------- test/opt/constant-propagation.ir | 8 ++-- test/opt/dead-code-elimination.ir | 12 ++--- test/opt/insert-barrier.ir | 10 ++-- test/spv/alloca.ir | 8 ++-- test/spv/expand.ir | 12 ++--- test/spv/fuse.ir | 6 +-- test/spv/load.ir | 6 +-- test/spv/size.ir | 8 ++-- test/spv/store.ir | 14 +++--- test/spv/subview.ir | 4 +- 32 files changed, 206 insertions(+), 206 deletions(-) diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 606e8a9d..cf30cae2 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -227,7 +227,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tin /** * @brief Create arithmetic instruction (unary) * - * @code %value = arith. %a : type(%a) @endcode + * @code %value = arith. %a : ty @endcode * * @param instr [out] pointer to the inst object created * @param op [in] unary arithmetic operation type @@ -246,7 +246,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *inst /** * @brief Create cast instruction * - * @code %value = cast %a, %b : type(%a) -> %to_ty @endcode + * @code %value = cast %a, %b : %to_ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand @@ -262,7 +262,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tiny /** * @brief Create binary op instruction * - * @code %value = cmp. %a, %b : type(%a) ; type(%a) == type(%b) @endcode + * @code %value = cmp. %a, %b : ty ; type(%a) == type(%b) @endcode * * @param instr [out] pointer to the inst object created * @param cond [in] compare type @@ -366,7 +366,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *in /** * @brief Create cooperative matrix load instruction * - * @code %value = cooperative_matrix_load.transpose.checked %op[%p0, %p1] : type(%op) -> to_ty + * @code %value = cooperative_matrix_load.transpose.checked %op[%p0, %p1] : to_ty * @endcode * * @param instr [out] pointer to the inst object created @@ -388,7 +388,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_load_inst_create( /** * @brief Create cooperative matrix mul add instruction * - * @code cooperative_matrix_mul_add %a, %b, %c : type(%a), type(%b), type(%c) -> to_ty @endcode + * @code cooperative_matrix_mul_add %a, %b, %c : to_ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] %a @@ -406,7 +406,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create( /** * @brief Create cooperative matrix scale instruction * - * @code cooperative_matrix_scale %a, %b : type(%a), type(%b) @endcode + * @code cooperative_matrix_scale %a, %b : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] %a @@ -423,7 +423,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( /** * @brief Create cooperative matrix store instruction * - * @code cooperative_matrix_store.checked.store_flag %val, %op[%p0, %p1] : type(%val), type(%op) + * @code cooperative_matrix_store.checked.store_flag %val, %op[%p0, %p1] * @endcode * * @param instr [out] pointer to the inst object created @@ -445,7 +445,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_store_inst_create( /** * @brief Create alloca instruction * - * @code %value = alloca -> %ty @endcode + * @code %value = alloca : %ty @endcode * * @param instr [out] pointer to the inst object created * @param ty [in] type that is allocated @@ -460,7 +460,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, ti * @brief Create axpby instruction * * @code - * axpby.. %alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) + * axpby.. %alpha, %A, %beta, %B * @endcode * * @param instr [out] pointer to the inst object created @@ -483,7 +483,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tin /** * @brief Create expand instruction * - * @code %value = expand %a[%mode -> %expand_shape] : type(%a) @endcode + * @code %value = expand %a[%mode -> %expand_shape] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand @@ -507,7 +507,7 @@ tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t expand /** * @brief Create fuse instruction * - * @code %value = fuse %a[%from, %to] : type(%a) @endcode + * @code %value = fuse %a[%from, %to] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand @@ -526,13 +526,14 @@ TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tiny /** * @brief Create load instruction * - * @code %value = load %a[%index_list] : type(%a) @endcode + * @code %value = load %a[%index_list] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand * @param index_list_size [in] number of indices * @param index_list [in][range(0, index_list_size)] indices array; may be nullptr if * index_list_size is 0 + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -540,6 +541,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tiny TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t index_list_size, const tinytc_value_t *index_list, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create group_id instruction @@ -576,7 +578,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr * * @code * gemm... %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -604,7 +605,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tiny * * @code * gemv.. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -630,7 +630,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tiny * * @code * ger. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -655,7 +654,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinyt * * @code * hadamard. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -706,17 +704,19 @@ TINYTC_EXPORT tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, /** * @brief Create size instruction * - * @code %value = size %a[%mode] : type(%a) @endcode + * @code %value = size %a[%mode] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand * @param mode [in] mode for that the size is queried + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t mode, const tinytc_location_t *loc); + int64_t mode, tinytc_data_type_t ty, + const tinytc_location_t *loc); /** * @brief Create subgroup_id instruction @@ -816,7 +816,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, * @brief Create sum instruction * * @code - * sum.. %alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) + * sum.. %alpha, %A, %beta, %B * @endcode * * @param instr [out] pointer to the inst object created @@ -840,8 +840,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @brief Create for loop * * @code - * for %loop_var = %from, %to, %step - * init(initial_value_list) -> (types(initial_value_list)) : loop_var_type { } + * for %loop_var : loop_var_type = %from, %to, %step + * init(initial_value_list) -> (types(initial_value_list)) { } * ; loop_var_type == type(%from) * ; loop_var_type == type(%to) * ; loop_var_type == type(%step) @@ -870,7 +870,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt * @brief Create foreach loop * * @code - * foreach (loop_var_list) = (from_list), (to_list) : loop_var_type { } + * foreach (loop_var_list) : loop_var_type = (from_list), (to_list) { } * ; loop_var_type == type(%f) forall %f in from_list * ; loop_var_type == type(%t) forall %t in to_list * @endcode diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 0304c54d..cb80479d 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1215,18 +1215,20 @@ inline inst make_fuse(value a, std::int64_t from, std::int64_t to, data_type ty, * * @param a Operand * @param index_list Vector of indices + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_load(value a, array_view index_list, location const &loc = {}) { +inline inst make_load(value a, array_view index_list, tinytc_data_type_t ty, + location const &loc = {}) { tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } const tinytc_value_t *il = reinterpret_cast(index_list.data()); - CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, il, &loc), loc); + CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, il, ty, &loc), loc); return inst(instr); } @@ -1378,13 +1380,14 @@ inline inst make_parallel(location const &loc = {}) { * * @param a Operand * @param mode Mode + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_size(value a, std::int64_t mode, location const &loc = {}) { +inline inst make_size(value a, std::int64_t mode, data_type ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a, mode, &loc), loc); + CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a, mode, ty, &loc), loc); return inst(instr); } diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 4382887f..f8110a41 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -645,7 +645,11 @@ void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value b } bb.add(make_store(*flag, alpha_ab, C, index_list, loc)); } else { - auto c = bb.add(make_load(C, index_list, loc)); + memref_data_type *ct = dyn_cast(C->ty()); + if (ct == nullptr) { + throw compilation_error(loc, {C.get()}, status::ir_expected_scalar); + } + auto c = bb.add(make_load(C, index_list, ct->element_data_ty(), loc)); auto beta_c = mixed_precision_arithmetic(bb, arithmetic::mul, beta, c, loc); auto alpha_ab_plus_beta_c = mixed_precision_arithmetic(bb, arithmetic::add, alpha_ab, beta_c, loc); diff --git a/src/inst.cpp b/src/inst.cpp index dbbdbc8e..c4e1adf6 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -434,12 +434,12 @@ tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t index_list_size, const tinytc_value_t *index_list, - const tinytc_location_t *loc) { + tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(a, array_view{index_list, index_list_size}, + *instr = std::make_unique(a, array_view{index_list, index_list_size}, ty, get_optional(loc)) .release(); }); @@ -538,12 +538,12 @@ tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, const tinytc_l } tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - const tinytc_location_t *loc) { + tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( - [&] { *instr = std::make_unique(a, mode, get_optional(loc)).release(); }); + [&] { *instr = std::make_unique(a, mode, ty, get_optional(loc)).release(); }); } tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index e6c90138..b19c67ed 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -802,7 +802,8 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, tin result(0) = value_node{ty, this, lc}; } -load_inst::load_inst(tinytc_value_t op0, array_view index_list0, location const &lc) +load_inst::load_inst(tinytc_value_t op0, array_view index_list0, + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::load, static_cast(1 + index_list0.size())} { op(0, op0); for (std::size_t i = 0; i < index_list0.size(); ++i) { @@ -813,17 +814,24 @@ load_inst::load_inst(tinytc_value_t op0, array_view index_list0, visit(overloaded{ [&](group_data_type &g) { + if (g.ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result(0) = value_node{g.ty(), this, lc}; + result(0) = value_node{ty, this, lc}; }, [&](memref_data_type &m) { + if (m.element_data_ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - auto result_ty = scalar_data_type::get(m.context(), m.element_ty()); - result(0) = value_node{result_ty, this, lc}; + result(0) = value_node{ty, this, lc}; }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, *operand().ty()); @@ -968,7 +976,7 @@ if_inst::if_inst(tinytc_value_t condition, array_view return then().loc(lc); otherwise().loc(lc); if (!isa(*condition->ty())) { - throw compilation_error(loc(), status::ir_expected_boolean); + throw compilation_error(loc(), {condition}, status::ir_expected_boolean); } for (std::size_t i = 0; i < return_types.size(); ++i) { if (!isa(*return_types[i]) && !isa(*return_types[i]) && @@ -985,18 +993,24 @@ parallel_inst::parallel_inst(location const &lc) : standard_inst{IK::parallel} { child_region(0).kind(region_kind::spmd); } -size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, location const &lc) +size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, tinytc_data_type_t ty, + location const &lc) : standard_inst{IK::size}, mode_(mode) { op(0, op0); loc(lc); + + auto rt = dyn_cast(ty); + if (!rt || rt->ty() != scalar_type::index) { + throw compilation_error(loc(), status::ir_expected_index); + } + auto m = get_memref_type(loc(), operand()); bool const range_ok = 0 <= mode_ && mode_ < m->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - auto result_ty = scalar_data_type::get(op(0).context(), scalar_type::index); - result(0) = value_node{result_ty, this, lc}; + result(0) = value_node{ty, this, lc}; } subview_inst::subview_inst(tinytc_value_t op0, array_view static_offsets0, @@ -1067,11 +1081,11 @@ store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, auto o = get_memref_type(loc(), operand()); if (v->ty() != o->element_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + throw compilation_error(loc(), {&val(), &operand()}, status::ir_scalar_mismatch); } if (o->dim() != static_cast(index_list0.size())) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); + throw compilation_error(loc(), {&operand()}, status::ir_invalid_number_of_indices); } } diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index c8ea0257..6c24b6b3 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -579,7 +579,8 @@ class fuse_inst : public standard_inst<1, 1> { class load_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } - load_inst(tinytc_value_t op, array_view index_list, location const &lc = {}); + load_inst(tinytc_value_t op, array_view index_list, tinytc_data_type_t ty, + location const &lc = {}); inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } @@ -746,7 +747,7 @@ class parallel_inst : public standard_inst<0, 0, 1> { class size_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } - size_inst(tinytc_value_t op, std::int64_t mode, location const &lc = {}); + size_inst(tinytc_value_t op, std::int64_t mode, tinytc_data_type_t ty, location const &lc = {}); inline auto operand() -> tinytc_value & { return op(0); } inline auto operand() const -> tinytc_value const & { return op(0); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 9fa634de..f1aea607 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -188,7 +188,6 @@ %nterm constant_or_dynamic %nterm group_type %nterm group_offset -%nterm memref_or_group_type %nterm var %nterm instruction %nterm axpby_inst @@ -435,11 +434,6 @@ group_offset: | COMMA OFFSET COLON constant_or_dynamic { $$ = $constant_or_dynamic; } ; -memref_or_group_type: - memref_type - | group_type -; - var: LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } ; @@ -1019,16 +1013,11 @@ fuse_inst: ; load_inst: - LOAD var LSQBR optional_value_list RSQBR COLON memref_or_group_type { - if ($var->ty() != $memref_or_group_type) { - auto loc = @var; - loc.end = @memref_or_group_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + LOAD var LSQBR optional_value_list RSQBR COLON data_type { try { $$ = inst { std::make_unique(std::move($var), std::move($optional_value_list), - @load_inst) + std::move($data_type), @load_inst) .release() }; } catch (compilation_error const &e) { @@ -1039,12 +1028,7 @@ load_inst: ; store_inst: - STORE store_flag var[a] COMMA var[b] LSQBR optional_value_list RSQBR COLON memref_type { - if ($b->ty() != $memref_type) { - auto loc = @b; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + STORE store_flag var[a] COMMA var[b] LSQBR optional_value_list RSQBR { try { $$ = inst { std::make_unique($store_flag, std::move($a), std::move($b), @@ -1138,14 +1122,11 @@ parallel_inst: ; size_inst: - SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON memref_type { - if ($var->ty() != $memref_type) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON scalar_type { try { - $$ = inst { std::make_unique(std::move($var), $mode, @size_inst).release() }; + $$ = inst { + std::make_unique(std::move($var), $mode, $scalar_type, @size_inst).release() + }; } catch (compilation_error const &e) { report_error(ctx.cctx(), e); YYERROR; diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index 69a5cbbb..d323ea41 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -78,7 +78,7 @@ auto inst_cloner::operator()(lifetime_stop_inst &in) -> std::unique_ptr std::unique_ptr { return std::make_unique(subs(&in.operand()), subs_value_range(in.index_list()), - in.loc()); + in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(group_id_inst &in) -> std::unique_ptr { return std::make_unique(in.context(), in.loc()); @@ -136,7 +136,7 @@ auto inst_cloner::operator()(parallel_inst &in) -> std::unique_ptr } auto inst_cloner::operator()(size_inst &in) -> std::unique_ptr { - return std::make_unique(subs(&in.operand()), in.mode(), in.loc()); + return std::make_unique(subs(&in.operand()), in.mode(), in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(subgroup_id_inst &in) -> std::unique_ptr { diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 25c6eea6..c49282e3 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -280,7 +280,7 @@ void dump_ir_pass::operator()(load_inst const &e) { do_with_infix(e.index_list().begin(), e.index_list().end(), [this](auto const &i) { dump_val(i); }); *os_ << "] : "; - visit(*this, *e.operand().ty()); + visit(*this, *e.result(0).ty()); } void dump_ir_pass::operator()(group_id_inst const &g) { @@ -409,7 +409,7 @@ void dump_ir_pass::operator()(size_inst const &s) { dump_val(s.operand()); *os_ << "[" << s.mode() << "]"; *os_ << " : "; - visit(*this, *s.operand().ty()); + visit(*this, *s.result(0).ty()); } void dump_ir_pass::operator()(subgroup_id_inst const &sg) { diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 5adf7df5..13a9dcf6 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -189,6 +189,7 @@ void linalg_generator::operator()(axpby_inst &in) { auto bool_ty = get_boolean(ctx); auto index_ty = get_scalar(ctx, scalar_type::index); + auto at = get_memref_type(in.A()); auto bt = get_memref_type(in.B()); if (bt->dim() == 0) { auto parallel = make_parallel(in.loc()); @@ -203,26 +204,27 @@ void linalg_generator::operator()(axpby_inst &in) { auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0, bool_ty, in.loc())); auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1, cond0->ty())); bb.if_condition(cond, [&](region_builder &bb) { - auto a = bb.add(make_load(&in.A(), {}, in.loc())); + auto a = bb.add(make_load(&in.A(), {}, at->element_data_ty(), in.loc())); blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); }); add(std::move(parallel)); } else if (bt->dim() == 1) { auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); add_foreach( {c0.get()}, {c_shape0.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { - auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); + auto a = + bb.add(make_load(&in.A(), {&loop_vars[0]}, at->element_data_ty(), in.loc())); blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {&loop_vars[0]}, in.loc()); }, in.loc()); } else if (bt->dim() == 2) { auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); - auto c_shape1 = add(make_size(&in.B(), 1, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); + auto c_shape1 = add(make_size(&in.B(), 1, index_ty, in.loc())); add_foreach( {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { @@ -230,7 +232,7 @@ void linalg_generator::operator()(axpby_inst &in) { if (in.tA() == transpose::T) { std::swap(a_idx[0], a_idx[1]); } - auto a = bb.add(make_load(&in.A(), a_idx, in.loc())); + auto a = bb.add(make_load(&in.A(), a_idx, at->element_data_ty(), in.loc())); blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {&loop_vars[0], &loop_vars[1]}, in.loc()); }, @@ -241,13 +243,15 @@ void linalg_generator::operator()(axpby_inst &in) { void linalg_generator::operator()(ger_inst &in) { auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); - auto c_shape1 = add(make_size(&in.C(), 1, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); + auto c_shape1 = add(make_size(&in.C(), 1, index_ty, in.loc())); add_foreach( {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { - auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); - auto b = bb.add(make_load(&in.B(), {&loop_vars[1]}, in.loc())); + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), {&loop_vars[1]}, bt->element_data_ty(), in.loc())); auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), {&loop_vars[0], &loop_vars[1]}, in.loc()); @@ -266,6 +270,7 @@ void linalg_generator::operator()(gemm_inst &in) { auto ctx = compiler_context{in.alpha().context(), true}; auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); @@ -276,10 +281,10 @@ void linalg_generator::operator()(gemm_inst &in) { size(ct->element_ty()), core_cfg_.subgroup_size, core_cfg_.register_space, is_complex_type(ct->element_ty()) ? 2 : 1); - auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, in.loc())); - auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, index_ty, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, index_ty, in.loc())); auto K = instant_constant_fold_add( - bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); auto const_shape0 = get_int_constant(c_shape0); auto const_shape1 = get_int_constant(c_shape1); @@ -329,21 +334,24 @@ void linalg_generator::operator()(gemm_inst &in) { void linalg_generator::operator()(gemv_inst &in) { auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); auto ct = get_memref_type(in.C()); add_foreach( {c0.get()}, {c_shape0.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); - auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + auto K = + bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); auto c_acc = bb.for_loop( c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view p) { auto a_idx = std::array{&loop_vars[0], p[0]}; if (in.tA() == transpose::T) { std::swap(a_idx[0], a_idx[1]); } - auto a = bb.add(make_load(&in.A(), a_idx, in.loc())); - auto b = bb.add(make_load(&in.B(), {p[0]}, in.loc())); + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto a = bb.add(make_load(&in.A(), a_idx, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), {p[0]}, bt->element_data_ty(), in.loc())); auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); auto ab_c = mixed_precision_arithmetic(bb, arithmetic::add, p[1], ab, in.loc()); bb.add(make_yield({ab_c}, in.loc())); @@ -357,12 +365,14 @@ void linalg_generator::operator()(gemv_inst &in) { void linalg_generator::operator()(hadamard_inst &in) { auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.C(), 0, in.loc())); + auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); add_foreach( {c0.get()}, {c_shape0.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { - auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, in.loc())); - auto b = bb.add(make_load(&in.B(), {&loop_vars[0]}, in.loc())); + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), {&loop_vars[0]}, bt->element_data_ty(), in.loc())); auto ab = mixed_precision_arithmetic(bb, arithmetic::mul, a, b, in.loc()); blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), {&loop_vars[0]}, in.loc()); @@ -380,6 +390,7 @@ void linalg_generator::operator()(sum_inst &in) { auto i32_ty = get_scalar(ctx, scalar_type::i32); auto index_ty = get_scalar(ctx, scalar_type::index); + auto at = get_memref_type(in.A()); auto bt = get_memref_type(in.B()); if (bt->dim() == 0) { auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); @@ -392,18 +403,19 @@ void linalg_generator::operator()(sum_inst &in) { auto c_zero = bb.add(make_constant_zero(i32_ty, in.loc())); auto is_from_0 = bb.add(make_cmp(cmp_condition::eq, from1, c_zero, bool_ty, in.loc())); - auto c_trip_count = instant_constant_fold_add(bb, make_size(&in.A(), 0, in.loc())); + auto c_trip_count = + instant_constant_fold_add(bb, make_size(&in.A(), 0, index_ty, in.loc())); auto c_step = bb.add(make_constant( core_cfg_.subgroup_size * tiling_.m_tiles() * tiling_.n_tiles(), index_ty, in.loc())); auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), in.loc())); - auto acc = bb.for_loop(from_index, c_trip_count, c_step, {c_init}, index_ty, - [&](region_builder &bb, array_view args) { - auto a = bb.add(make_load(&in.A(), {args[0]}, in.loc())); - auto sum = mixed_precision_arithmetic(bb, arithmetic::add, - args[1], a, in.loc()); - bb.add(make_yield({sum}, in.loc())); - }); + auto acc = bb.for_loop( + from_index, c_trip_count, c_step, {c_init}, index_ty, + [&](region_builder &bb, array_view args) { + auto a = bb.add(make_load(&in.A(), {args[0]}, at->element_data_ty(), in.loc())); + auto sum = mixed_precision_arithmetic(bb, arithmetic::add, args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); auto sum = bb.add(make_work_group(work_group_operation::reduce_add, acc[0], in.loc())); bb.if_condition( is_from_0, @@ -414,11 +426,12 @@ void linalg_generator::operator()(sum_inst &in) { } else if (bt->dim() == 1) { auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); auto c0 = add(make_constant(0, index_ty, in.loc())); - auto c_shape0 = add(make_size(&in.B(), 0, in.loc())); + auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); add_foreach( {c0.get()}, {c_shape0.get()}, index_ty, [&](region_builder &bb, auto loop_vars) { - auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, in.loc())); + auto K = + bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); auto c_init = bb.add(make_constant_zero(bt->element_data_ty())); auto acc = bb.for_loop( c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view args) { @@ -426,7 +439,8 @@ void linalg_generator::operator()(sum_inst &in) { if (in.tA() == transpose::T) { std::swap(index_list[0], index_list[1]); } - auto a = bb.add(make_load(&in.A(), index_list, in.loc())); + auto a = + bb.add(make_load(&in.A(), index_list, at->element_data_ty(), in.loc())); auto sum = mixed_precision_arithmetic(bb, arithmetic::add, args[1], a, in.loc()); bb.add(make_yield({sum}, in.loc())); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index c4879507..b2c3d2bb 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -142,7 +142,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( static_gemm(bb); } else { - auto M_val = bb.add(make_size(C, 0, my_loc())); + auto M_val = bb.add(make_size(C, 0, index_ty, my_loc())); auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, m.get_type(), my_loc())); auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir index ab5d86e6..49f5d186 100644 --- a/test/codegen/atomic.ir +++ b/test/codegen/atomic.ir @@ -5,8 +5,8 @@ func @atomic_store(%A: memref) { %f0 = constant 0.0 : f64 %i0 = constant 0 : index - store.atomic %f0, %A[%i0] : memref - store.atomic_add %f0, %A[%i0] : memref + store.atomic %f0, %A[%i0] + store.atomic_add %f0, %A[%i0] ; CHECK-LABEL: void atomic_store({{.*}} ; CHECK: atomic_store_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); ; CHECK: atomic_fetch_add_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); @@ -15,8 +15,8 @@ func @atomic_store(%A: memref) { func @atomic_store_c64(%A: memref) { %f0 = constant [0.0, 0.0] : c64 %i0 = constant 0 : index - store.atomic %f0, %A[%i0] : memref - store.atomic_add %f0, %A[%i0] : memref + store.atomic %f0, %A[%i0] + store.atomic_add %f0, %A[%i0] ; CHECK-LABEL: void atomic_store_c64({{.*}} ; CHECK: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).x, f0, memory_order_relaxed, memory_scope_work_group); ; CHECK-NEXT: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).y, f0, memory_order_relaxed, memory_scope_work_group); diff --git a/test/codegen/dope_vector_group0.ir b/test/codegen/dope_vector_group0.ir index bae9d6cf..86bd8348 100644 --- a/test/codegen/dope_vector_group0.ir +++ b/test/codegen/dope_vector_group0.ir @@ -5,7 +5,7 @@ func @kernel1(%in: group>) { ; CHECK: void kernel1(global float*global* in, global long* in_shape1, global long* in_stride2) %c5 = constant 5 : index - %0 = load %in[%c5] : group> + %0 = load %in[%c5] : memref ; CHECK-NEXT: long c5 = 5ll; ; CHECK-NEXT: global float* x = *(in + c5) + 0; ; CHECK-NEXT: long x_shape1 = in_shape1[c5]; @@ -15,7 +15,7 @@ func @kernel1(%in: group>) { func @kernel2(%in: group, offset: ?>) { ; CHECK: void kernel2(global float*global* in, global long* in_shape0, long in_offset) %c5 = constant 5 : index - %0 = load %in[%c5] : group, offset: ?> + %0 = load %in[%c5] : memref ; CHECK-NEXT: long c5 = 5ll; ; CHECK-NEXT: global float* x = *(in + c5) + in_offset; ; CHECK-NEXT: long x_shape0 = in_shape0[c5]; diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir index 1e5a9565..e7873b92 100644 --- a/test/codegen/expand.ir +++ b/test/codegen/expand.ir @@ -5,7 +5,7 @@ func @t1(%0: memref) { %z = constant 0 : index %1 = expand %0[1->2x8] : memref - %2 = load %1[%z,%z,%z,%z] : memref + %2 = load %1[%z,%z,%z,%z] : f32 ; CHECK-LABEL: void t1( ; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 512); @@ -13,7 +13,7 @@ func @t1(%0: memref) { func @t2(%0: memref) { %z = constant 0 : index %1 = expand %0[1->2x2x2x2] : memref - %2 = load %1[%z,%z,%z,%z,%z,%z] : memref + %2 = load %1[%z,%z,%z,%z,%z,%z] : f32 ; CHECK-LABEL: void t2( ; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 128 + z * 256 + z * 512); @@ -21,7 +21,7 @@ func @t2(%0: memref) { func @t3(%0: memref, %1: index) { %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t3( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape11 = x1; @@ -31,7 +31,7 @@ func @t3(%0: memref, %1: index) { func @t4(%0: memref, %1: index) { %z = constant 0 : index %2 = expand %0[1->2 x %1] : memref - %3 = load %2[%z,%z,%z] : memref + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t4( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape2 = x1; @@ -40,7 +40,7 @@ func @t4(%0: memref, %1: index) { func @t5(%0: memref, %1: index) { %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t5( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape1 = x1; @@ -50,7 +50,7 @@ func @t5(%0: memref, %1: index) { func @t6(%0: memref, %1: index) { %z = constant 0 : index %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : memref + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t6( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape11 = x1; @@ -60,7 +60,7 @@ func @t6(%0: memref, %1: index) { func @t7(%0: memref, %1: index, %2: index) { %z = constant 0 : index %3 = expand %0[1->%1 x %2 x 2] : memref - %4 = load %3[%z,%z,%z,%z] : memref + %4 = load %3[%z,%z,%z,%z] : f32 ; CHECK-LABEL: void t7( ; CHECK: global float* x3 = x; ; CHECK-NEXT: long x_shape1 = x1; @@ -72,7 +72,7 @@ func @t7(%0: memref, %1: index, %2: index) { func @t8(%0: memref, %1: index, %2: index) { %z = constant 0 : index %3 = expand %0[1->%2 x 2 x %1] : memref - %4 = load %3[%z,%z,%z,%z] : memref + %4 = load %3[%z,%z,%z,%z] : f32 ; CHECK-LABEL: void t8( ; CHECK: global float* x3 = x; ; CHECK-NEXT: long x_shape1 = x2; @@ -84,7 +84,7 @@ func @t8(%0: memref, %1: index, %2: index) { func @t9(%0: memref, %1: index, %2: index) { %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : memref + %4 = load %3[%z,%z,%z] : f32 ; CHECK-LABEL: void t9( ; CHECK: global float* x3 = x; ; CHECK-NEXT: long x_shape11 = x1; @@ -95,7 +95,7 @@ func @t9(%0: memref, %1: index, %2: index) { func @t10(%0: memref, %1: index, %2: index) { %z = constant 0 : index %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : memref + %4 = load %3[%z,%z,%z] : f32 ; CHECK-LABEL: void t10( ; CHECK: global float* x3 = x; ; CHECK-NEXT: long x_shape1 = x1; @@ -106,7 +106,7 @@ func @t10(%0: memref, %1: index, %2: index) { func @t11(%0: memref>) { %z = constant 0 : index %1 = expand %0[0->4 x 8] : memref> - %2 = load %1[%z,%z,%z] : memref> + %2 = load %1[%z,%z,%z] : f32 ; CHECK-LABEL: void t11( ; CHECK: global float* x1 = x; ; CHECK-NEXT: float x2 = *(x1 + z * 2 + z * 8 + z * 64); @@ -114,7 +114,7 @@ func @t11(%0: memref>) { func @t12(%0: memref>, %1: index) { %z = constant 0 : index %2 = expand %0[0->%1 x 4] : memref> - %3 = load %2[%z,%z,%z] : memref> + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t12( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape01 = x1; @@ -125,7 +125,7 @@ func @t12(%0: memref>, %1: index) { func @t13(%0: memref>, %1: index) { %z = constant 0 : index %2 = expand %0[0->4 x %1] : memref> - %3 = load %2[%z,%z,%z] : memref> + %3 = load %2[%z,%z,%z] : f32 ; CHECK-LABEL: void t13( ; CHECK: global float* x2 = x; ; CHECK-NEXT: long x_shape1 = x1; diff --git a/test/codegen/for.ir b/test/codegen/for.ir index b544246a..0e6a8fd1 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -25,7 +25,7 @@ func @for2(%fib: memref) { %fn = arith.add %fn_2, %fn_1 : i64 yield %fn_1, %fn : i64, i64 } - store %fn, %fib[] : memref + store %fn, %fib[] ; CHECK-LABEL: void for2({{.*}} ; CHECK: long f0 = 0ll; ; CHECK-NEXT: long f1 = 1ll; diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir index cf080a04..72e5a9fe 100644 --- a/test/codegen/fuse.ir +++ b/test/codegen/fuse.ir @@ -5,13 +5,13 @@ func @t1(%0: memref) { %z = constant 0 : index %1 = fuse %0[1,3] : memref - %2 = load %1[%z,%z,%z] : memref + %2 = load %1[%z,%z,%z] : f32 ; CHECK: float x2 = *(x1 + z * 1 + z * 32 + z * 16384); } func @t2(%0: memref) { %z = constant 0 : index %1 = fuse %0[1,3] : memref> - %2 = load %1[%z,%z,%z] : memref> + %2 = load %1[%z,%z,%z] : f32 ; CHECK: long x_shape1 = 16 * x_shape2 * 4; ; CHECK-NEXT: long x_stride2 = x_stride4; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x_stride2); @@ -19,13 +19,13 @@ func @t2(%0: memref) { func @t3(%0: memref>) { %z = constant 0 : index %1 = fuse %0[1,2] : memref> - %2 = load %1[%z,%z,%z] : memref> + %2 = load %1[%z,%z,%z] : f32 ; CHECK: float x2 = *(x1 + z * 1 + z * 48 + z * 1536); } func @t4(%0: memref>) { %z = constant 0 : index %1 = fuse %0[0,1] : memref> - %2 = load %1[%z,%z] : memref> + %2 = load %1[%z,%z] : f32 ; CHECK: long x_shape0 = 8 * x_shape1; ; CHECK-NEXT: long x_stride11 = x_stride2; ; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * x_stride11); diff --git a/test/codegen/load.ir b/test/codegen/load.ir index 33f80f6a..632ef086 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -4,10 +4,10 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel1(%a: memref, %b: memref, %c: group>) { %c5 = constant 5 : index - %0 = load %a[] : memref + %0 = load %a[] : f32 %1 = group_id - %2 = load %b[%c5, %1] : memref - %3 = load %c[%1] : group> + %2 = load %b[%c5, %1] : f32 + %3 = load %c[%1] : memref ; CHECK: float x = *a; ; CHECK-NEXT: long x1 = get_global_id(2); ; CHECK-NEXT: float x2 = *(b + c5 * 1 + x1 * 10); @@ -16,6 +16,6 @@ func @kernel1(%a: memref, %b: memref, %c: group>) func @kernel2(%c: group, offset: 21>) { %0 = group_id - %1 = load %c[%0] : group, offset: 21> + %1 = load %c[%0] : memref ; CHECK: global float* x1 = *(c + x) + 21; } diff --git a/test/codegen/size.ir b/test/codegen/size.ir index c8c8d01d..02874f33 100644 --- a/test/codegen/size.ir +++ b/test/codegen/size.ir @@ -3,14 +3,14 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1(%0: memref) { - %1 = size %0[0] : memref - %2 = size %0[1] : memref + %1 = size %0[0] : index + %2 = size %0[1] : index ; CHECK: long x1 = 32; ; CHECK-NEXT: long x2 = 16; } func @t2(%0: memref) { - %1 = size %0[0] : memref - %2 = size %0[1] : memref + %1 = size %0[0] : index + %2 = size %0[1] : index ; CHECK: long x1 = x_shape0; ; CHECK-NEXT: long x2 = x_shape1; } diff --git a/test/codegen/store.ir b/test/codegen/store.ir index 2b09d642..b07ed04b 100644 --- a/test/codegen/store.ir +++ b/test/codegen/store.ir @@ -5,8 +5,8 @@ func @kernel(%a: memref, %b: memref, %c: f32) { %c5 = constant 5 : index %1 = group_id - store %c, %a[] : memref - store %c, %b[%c5, %1] : memref + store %c, %a[] + store %c, %b[%c5, %1] ; CHECK: *a = c; ; CHECK-NEXT: *(b + c5 * 1 + x * 10) = c; } diff --git a/test/codegen/type_mismatch0.ir b/test/codegen/type_mismatch0.ir index db6df785..7cddb539 100644 --- a/test/codegen/type_mismatch0.ir +++ b/test/codegen/type_mismatch0.ir @@ -3,6 +3,6 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @kernel(%K0: memref) { - %0 = load %K0[] : memref -; CHECK: 6.13-31: Type of SSA value does not match operand type + %0 = load %K0[] : f64 +; CHECK: 6.8-23: Type of operand must match return type } diff --git a/test/codegen/type_mismatch1.ir b/test/codegen/type_mismatch1.ir index b79bf466..230ab15f 100644 --- a/test/codegen/type_mismatch1.ir +++ b/test/codegen/type_mismatch1.ir @@ -6,7 +6,7 @@ func @kernel(%K0: memref, %x: index, %y: index) { %z = constant 0 : index %0 = subview %K0[0:%x] : memref %1 = subview %0[0:%y] : memref - %2 = load %1[%z] : memref - %3 = load %1[%z] : memref> -; CHECK: 10.13-45: Type of SSA value does not match operand type + %2 = load %1[%z] : f64 + %3 = load %1[%z] : f64 +; CHECK: 9.8-24: Type of operand must match return type } diff --git a/test/opt/check-ir/subview.ir b/test/opt/check-ir/subview.ir index 5f06833e..52e664e2 100644 --- a/test/opt/check-ir/subview.ir +++ b/test/opt/check-ir/subview.ir @@ -3,47 +3,30 @@ ; RUN: %tinytc-opt -pcheck-ir -O0 < %s | filecheck %s -; No real checks needed, just check that it does not crash, that is, -; the types put in load match those returned by expand +; No real checks needed, just check that it does not crash ; CHECK: func @t1({{.*}} func @t1(%0: memref) { - %z = constant 0 : index %1 = subview %0[4:8,8:4] : memref - %2 = load %1[%z,%z] : memref> } func @t2(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[2:4,%1] : memref - %3 = load %2[%z] : memref } func @t3(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[2:4,%1:0] : memref - %3 = load %2[%z] : memref } func @t4(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[2:4,%1:1] : memref - %3 = load %2[%z,%z] : memref> } func @t5(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[%1:4] : memref - %3 = load %2[%z] : memref } func @t6(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[%1:%1] : memref - %3 = load %2[%z] : memref } func @t7(%0: memref, %1: index) { - %z = constant 0 : index %2 = subview %0[2:4, %1:%1, 6:7] : memref - %3 = load %2[%z,%z,%z] : memref> } func @t8(%0: memref>, %1: index) { - %z = constant 0 : index %2 = subview %0[2:4, %1:%1, 6:7] : memref> - %3 = load %2[%z,%z,%z] : memref> } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 9a487878..0df6f23e 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -3,15 +3,15 @@ ; RUN: %tinytc-opt -pconstant-propagation < %s | filecheck %s func @known_size(%a: memref, %b: index) { - %0 = size %a[0] : memref - %1 = size %a[1] : memref + %0 = size %a[0] : index + %1 = size %a[1] : index %2 = arith.add %0, %1 : index %3 = arith.add %2, %b : index ; CHECK-LABEL: func @known_size({{.*}} ; CHECK: %0 = constant 64 : index -; CHECK-NEXT: %1 = size %a[0] : memref +; CHECK-NEXT: %1 = size %a[0] : index ; CHECK-NEXT: %2 = constant 32 : index -; CHECK-NEXT: %3 = size %a[1] : memref +; CHECK-NEXT: %3 = size %a[1] : index ; CHECK-NEXT: %4 = constant 96 : index ; CHECK-NEXT: %5 = arith.add %0, %2 : index ; CHECK-NEXT: %6 = arith.add %4, %b : index diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index f0097471..ca5b644e 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -6,12 +6,12 @@ func @dead_if(%a: memref) { %c0 = constant false : bool if %c0 { %c42 = constant 42.0 : f64 - store %c42, %a[] : memref + store %c42, %a[] } %c1 = constant true : bool if %c1 { %c43 = constant 43.0 : f64 - store %c43, %a[] : memref + store %c43, %a[] } ; CHECK-LABEL: func @dead_if({{.*}} ; CHECK-NEXT: %c1 = constant true : bool @@ -30,7 +30,7 @@ func @dead_if_with_yield(%a: memref) { %c43 = constant 43.0 : f64 yield %c43 : f64 } - store %0, %a[] : memref + store %0, %a[] ; Cannot eliminate if that returns results currently ; CHECK-LABEL: func @dead_if_with_yield({{.*}} ; CHECK: %0 = if %c0 -> (f64) { @@ -40,20 +40,20 @@ func @dead_if_with_yield(%a: memref) { ; CHECK-NEXT: %c43 = constant 0x1.58p+5 : f64 ; CHECK-NEXT: yield %c43 : f64 ; CHECK-NEXT: } -; CHECK-NEXT: store %0, %a[] : memref +; CHECK-NEXT: store %0, %a[] } func @dead_loop(%a: memref) { %c2 = constant 2 : index for %0=%c2,%c2 { %c42 = constant 42.0 : f64 - store %c42, %a[] : memref + store %c42, %a[] } %c5 = constant 5 : index %c6 = constant 6 : index for %0=%c5,%c6 { %c43 = constant 43.0 : f64 - store %c43, %a[] : memref + store %c43, %a[] } ; CHECK-LABEL: func @dead_loop({{.*}} ; CHECK-NEXT: %c5 = constant 5 : index diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 3c28c0a4..023b5378 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -163,20 +163,20 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %0 = subgroup_id %1 = cmp.eq %0, %c0 : bool if %1 { - %2 = load %A[%c3,%c4] : memref - store %2, %A[%c3,%c4] : memref + %2 = load %A[%c3,%c4] : f32 + store %2, %A[%c3,%c4] } } - %3 = load %A[%c3,%c4] : memref + %3 = load %A[%c3,%c4] : f32 ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { ; CHECK-NEXT: %0 = subgroup_id ; CHECK-NEXT: %1 = cmp.eq %0, %c0 : bool ; CHECK-NEXT: if %1 { -; CHECK-NEXT: %2 = load %A[%c3,%c4] : memref +; CHECK-NEXT: %2 = load %A[%c3,%c4] : f32 ; CHECK-NEXT: store %2, %A[%c3,%c4] : memref ; CHECK-NEXT: } ; CHECK-NEXT: } ; CHECK-NEXT: barrier.global -; CHECK-NEXT: %3 = load %A[%c3,%c4] : memref +; CHECK-NEXT: %3 = load %A[%c3,%c4] : f32 } diff --git a/test/spv/alloca.ir b/test/spv/alloca.ir index 1d4f8592..99ebd264 100644 --- a/test/spv/alloca.ir +++ b/test/spv/alloca.ir @@ -29,10 +29,10 @@ func @alloca() { %0 = alloca : memref %1 = alloca : memref %2 = alloca : memref - %3 = load %0[%c0] : memref - %4 = load %1[%c0,%c0] : memref - %5 = load %2[] : memref - %6 = size %1[1] : memref + %3 = load %0[%c0] : i8 + %4 = load %1[%c0,%c0] : f32 + %5 = load %2[] : i16 + %6 = size %1[1] : index %7 = arith.not %6 : index ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK: %[[#STACK_PTR_I8:]] = OpInBoundsAccessChain %[[#I8_PTR]] %[[#STACK_VAR]] %[[#I64_C0]] diff --git a/test/spv/expand.ir b/test/spv/expand.ir index 634945b6..a7417352 100644 --- a/test/spv/expand.ir +++ b/test/spv/expand.ir @@ -14,17 +14,17 @@ func @f1(%0: memref, %1: index) { %c0 = constant 0 : index %2 = expand %0[1->4x%1x5] : memref - %3 = size %2[0] : memref - %4 = size %2[1] : memref - %5 = size %2[2] : memref - %6 = size %2[3] : memref - %7 = size %2[4] : memref + %3 = size %2[0] : index + %4 = size %2[1] : index + %5 = size %2[2] : index + %6 = size %2[3] : index + %7 = size %2[4] : index %8 = arith.not %3 : index %9 = arith.not %4 : index %10 = arith.not %5 : index %11 = arith.not %6 : index %12 = arith.not %7 : index - %13 = load %2[%c0,%c0,%c0,%c0,%c0] : memref + %13 = load %2[%c0,%c0,%c0,%c0,%c0] : f32 ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#P_MR:]] = OpFunctionParameter %[[#]] ; CHECK-NEXT: %[[#P_SHAPE1:]] = OpFunctionParameter %[[#I64]] diff --git a/test/spv/fuse.ir b/test/spv/fuse.ir index 75c8bf8c..2120747a 100644 --- a/test/spv/fuse.ir +++ b/test/spv/fuse.ir @@ -15,9 +15,9 @@ func @f1(%0: memref) { %z = constant 0 : index %1 = fuse %0[1,3] : memref - %2 = size %1[0] : memref - %3 = size %1[1] : memref - %4 = size %1[2] : memref + %2 = size %1[0] : index + %3 = size %1[1] : index + %4 = size %1[2] : index %5 = arith.not %2 : index %6 = arith.not %3 : index %7 = arith.not %4 : index diff --git a/test/spv/load.ir b/test/spv/load.ir index 70248fe4..7e493443 100644 --- a/test/spv/load.ir +++ b/test/spv/load.ir @@ -14,7 +14,7 @@ func @l1(%0: memref) { %2 = constant 0 : index - %3 = load %0[%2,%2] : memref + %3 = load %0[%2,%2] : f32 ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#L1_MR:]] = OpFunctionParameter %[[#PTR_F32]] ; CHECK-NEXT: %[[#]] = OpFunctionParameter %[[#I64]] @@ -30,8 +30,8 @@ func @l1(%0: memref) { func @l2(%0: group>, offset: ?>) { %1 = constant 0 : index - %2 = load %0[%1] : group>, offset: ?> - %3 = load %2[%1] : memref> + %2 = load %0[%1] : memref> + %3 = load %2[%1] : f32 ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#L2_GROUP:]] = OpFunctionParameter %[[#PTR_PTR_F32]] ; CHECK-NEXT: %[[#L2_PTR_SHAPE:]] = OpFunctionParameter %[[#PTR_I64]] diff --git a/test/spv/size.ir b/test/spv/size.ir index a2a19a1c..b282a39e 100644 --- a/test/spv/size.ir +++ b/test/spv/size.ir @@ -8,10 +8,10 @@ ; CHECK: %[[#I64_C32:]] = OpConstant %[[#I64]] 32 func @size(%0: memref) { - %1 = size %0[0] : memref - %2 = size %0[1] : memref - %3 = size %0[2] : memref - %4 = size %0[3] : memref + %1 = size %0[0] : index + %2 = size %0[1] : index + %3 = size %0[2] : index + %4 = size %0[3] : index %5 = arith.add %1, %2 : index %6 = arith.add %3, %4 : index ; CHECK: %[[#]] = OpFunction {{.*}} diff --git a/test/spv/store.ir b/test/spv/store.ir index 4659cbdb..fa4755ce 100644 --- a/test/spv/store.ir +++ b/test/spv/store.ir @@ -32,9 +32,9 @@ func @si8(%0: memref, %1: memref) { %2 = constant 0 : index %3 = constant -42 : i8 - store %3, %0[%2,%2] : memref - store.atomic %3, %1[] : memref - store.atomic_add %3, %1[] : memref + store %3, %0[%2,%2] + store.atomic %3, %1[] + store.atomic_add %3, %1[] ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#SI8_MR1:]] = OpFunctionParameter %[[#PTR_I8]] ; CHECK-NEXT: %[[#SI8_MR2:]] = OpFunctionParameter %[[#PTR_I8]] @@ -50,8 +50,8 @@ func @si8(%0: memref, %1: memref) { func @sf32(%0: memref) { %1 = constant 42.0 : f32 - store.atomic %1, %0[] : memref - store.atomic_add %1, %0[] : memref + store.atomic %1, %0[] + store.atomic_add %1, %0[] ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#SF32_MR:]] = OpFunctionParameter %[[#PTR_F32]] ; CHECK: OpAtomicStore %[[#SF32_MR]] %[[#I32_C2]] %[[#I32_C0]] %[[#F32_C42]] @@ -60,8 +60,8 @@ func @sf32(%0: memref) { func @sc64(%0: memref) { %1 = constant [42.0, 1.0] : c64 - store.atomic %1, %0[] : memref - store.atomic_add %1, %0[] : memref + store.atomic %1, %0[] + store.atomic_add %1, %0[] ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#SC64_MR:]] = OpFunctionParameter %[[#PTR_C64]] ; CHECK: %[[#SC64_RE_PTR1:]] = OpInBoundsAccessChain %[[#PTR_F64]] %[[#SC64_MR]] %[[#I32_C0]] diff --git a/test/spv/subview.ir b/test/spv/subview.ir index 6284e0c5..1548c147 100644 --- a/test/spv/subview.ir +++ b/test/spv/subview.ir @@ -13,9 +13,9 @@ func @sv1(%K0: memref, %offset: index, %size: index) { %0 = subview %K0[4:%size, %offset] : memref %1 = subview %K0[%offset, 4:%size] : memref - %2 = size %0[0] : memref + %2 = size %0[0] : index %3 = arith.not %2 : index - %4 = size %1[0] : memref> + %4 = size %1[0] : index %5 = arith.not %4 : index ; CHECK: %[[#]] = OpFunction {{.*}} ; CHECK-NEXT: %[[#P_MR:]] = OpFunctionParameter %[[#F32_PTR]] From 903336421968c06e84b0df0db026aef5ee951c8c Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 21 Nov 2024 21:51:30 +0100 Subject: [PATCH 124/297] Update for, foreach, subview, yield Signed-off-by: Carsten Uphoff --- docs/manual/tensor-ir.rst | 10 +- examples/benchmark/main.cpp | 8 +- examples/matrix_chain/test_ader.cpp | 22 +++- examples/matrix_chain/test_volume.cpp | 31 +++-- include/tinytc/tinytc.h | 39 +++--- include/tinytc/tinytc.hpp | 72 +++++++---- include/tinytc/types.h | 90 +++++++------- include/tinytc/types.hpp | 4 + src/codegen_tools.cpp | 6 +- src/error.cpp | 9 ++ src/inst.cpp | 47 ++++---- src/node/inst_node.cpp | 164 ++++++++++++++------------ src/node/inst_node.hpp | 14 +-- src/parser/parser_impl.yy | 56 ++------- src/pass/check_ir.cpp | 4 +- src/pass/clone.cpp | 23 ++-- src/pass/dump_ir.cpp | 14 +-- src/pass/lower_foreach.cpp | 7 +- src/pass/lower_linalg.cpp | 32 ++--- src/recipe/small_gemm_batched.cpp | 12 +- src/recipe/tall_and_skinny.cpp | 20 +++- test/codegen/axpby1.ir | 8 +- test/codegen/dope_vector0.ir | 4 +- test/codegen/for.ir | 2 +- test/codegen/if.ir | 16 +-- test/codegen/type_mismatch1.ir | 2 +- test/opt/check-ir/subview.ir | 16 +-- test/opt/constant-propagation.ir | 4 +- test/opt/dead-code-elimination.ir | 8 +- test/opt/insert-barrier.ir | 2 +- test/opt/insert-lifetime-stop.ir | 2 +- test/spv/for.ir | 4 +- test/spv/if.ir | 16 +-- test/spv/subview.ir | 4 +- 34 files changed, 419 insertions(+), 353 deletions(-) diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index ec54e716..ec7d6fee 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -1206,7 +1206,7 @@ Example: %f1 = constant 1 -> i64 %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 - yield %fn_1, %fn : i64, i64 + yield (%fn_1, %fn) } ; %fn_1 contains the fourth Fibonacci number and %fn the fifth Fibonacci number @@ -1327,9 +1327,10 @@ Example: %1 = cmp.lt %0, 16 : i32 %x = if %1 -> (i32) { - yield %0 : i32 + yield (%0) } else { - yield 16 : i32 + %c16 = constant 16 : i32 + yield (%c16) } @@ -1463,6 +1464,7 @@ Restrictions The memref type of the result must conform with the following rules: +#. Element type and address space must match the operand's memref type. #. **Invariant-stride:** The stride is not changed or replaced with '?'. .. code:: @@ -1569,7 +1571,7 @@ Yield .. code:: abnf - instruction =/ "yield" [local-identifier-list] ":" [return-type-list] + instruction =/ "yield" "(" [local-identifier-list] ")" Overview ~~~~~~~~ diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index a9dbef03..511b9e4b 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -98,11 +98,11 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t auto calpha = bb.add(make_constant_one(element_ty, my_loc())); auto cbeta = bb.add(update ? make_constant_one(element_ty, my_loc()) : make_constant_zero(element_ty, my_loc())); - auto a = bb.add(make_load(params[0], {gid}, my_loc())); - auto b = bb.add(make_load(params[1], {gid}, my_loc())); - auto c = bb.add(make_load(params[2], {gid}, my_loc())); + auto a = bb.add(make_load(params[0], {gid}, element_ty, my_loc())); + auto b = bb.add(make_load(params[1], {gid}, element_ty, my_loc())); + auto c = bb.add(make_load(params[2], {gid}, element_ty, my_loc())); bb.for_loop( - from, to, index_ty, + index_ty, from, to, [&](region_builder &bb, value const &) { bb.add(make_gemm(tA, tB, atomic, calpha, a, b, cbeta, c, my_loc())); }, diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 60d24253..d69c0080 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -103,11 +103,16 @@ auto test_ader::make_optimized_kernel(bool dump) return {b.nrows(), b.ncols(), 0}; }; auto const offsets3 = array_view(gid); - auto dq = bb.add(make_subview(Q, static_offsets3, static_sizes3(dQ_[0]), offsets3)); + auto dqt = get_memref(element_ty, static_sizes3(dQ_[0])); + auto dq = + bb.add(make_subview(Q, static_offsets3, static_sizes3(dQ_[0]), offsets3, {}, dqt)); for (std::size_t d = 0; d < dim; ++d) { - A(d) = bb.add(make_subview(A(d), static_offsets3, static_sizes3(A_[d]), offsets3)); + auto At = get_memref(element_ty, static_sizes3(A_[d])); + A(d) = + bb.add(make_subview(A(d), static_offsets3, static_sizes3(A_[d]), offsets3, {}, At)); } - auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3)); + auto it = get_memref(element_ty, static_sizes3(I_opt_)); + auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3, {}, it)); bb.add(make_axpby(transpose::N, false, c1, dq, c1, i)); int denom = 1; @@ -120,16 +125,21 @@ auto test_ader::make_optimized_kernel(bool dump) auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom, cnum.get_type())); auto bn = Bd_aligned(N_ - n); auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); - auto dq_nextv = bb.add(make_subview(dq_next, static_offsets2, {bn, P_})); + auto dq_nextvt = get_memref(element_ty, {bn, P_}, {}, address_space::local); + auto dq_nextv = + bb.add(make_subview(dq_next, static_offsets2, {bn, P_}, {}, {}, dq_nextvt)); auto tmp = bb.add( make_alloca(get_memref(element_ty, {bn, P_}, {1, bn}, address_space::local))); for (std::size_t d = 0; d < dim; ++d) { - auto Kv = bb.add(make_subview(K(d), static_offsets2, {bn, Bd(N_ - n + 1)})); + auto Kvt = get_memref(element_ty, {bn, Bd(N_ - n + 1)}); + auto Kv = + bb.add(make_subview(K(d), static_offsets2, {bn, Bd(N_ - n + 1)}, {}, {}, Kvt)); bb.add(make_gemm(transpose::N, transpose::N, false, c1, Kv, dq, c0, tmp)); bb.add(make_gemm(transpose::N, transpose::N, false, c1, tmp, A(d), d > 0 ? c1 : c0, dq_nextv)); } - auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_})); + auto ivt = get_memref(element_ty, {Bd(N_ - n), P_}); + auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_}, {}, {}, ivt)); bb.add(make_axpby(transpose::N, false, cfactor, dq_next, c1, iv)); dq = dq_next; } diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index af463662..ae057c11 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -92,15 +92,28 @@ auto test_volume::make_optimized_kernel(bool dump) auto const sizeK2 = std::array{B3_aligned_, B2_}; auto tmp = bb.add( make_alloca(get_memref(element_ty, {B2_aligned_, P_}, {}, address_space::local))); - auto a0 = bb.add(make_subview(A(0), static_offsets3, static_sizes3(A_[0]), offsets3)); - auto a1 = bb.add(make_subview(A(1), static_offsets3, static_sizes3(A_[1]), offsets3)); - auto a2 = bb.add(make_subview(A(2), static_offsets3, static_sizes3(A_[2]), offsets3)); - auto k0 = bb.add(make_subview(K(0), static_offsets2, sizeK2)); - auto k1 = bb.add(make_subview(K(1), static_offsets2, sizeK2)); - auto k2 = bb.add(make_subview(K(2), static_offsets2, sizeK2)); - auto qv = bb.add(make_subview(Q, static_offsets3, {B3_aligned_, P_, 0}, offsets3)); - auto iv = bb.add(make_subview(I, static_offsets3, {B2_aligned_, P_, 0}, offsets3)); - auto tmpv = bb.add(make_subview(tmp, static_offsets2, {B2_, P_})); + + auto a0t = get_memref(element_ty, static_sizes3(A_[0])); + auto a1t = get_memref(element_ty, static_sizes3(A_[1])); + auto a2t = get_memref(element_ty, static_sizes3(A_[2])); + auto k0t = get_memref(element_ty, sizeK2); + auto k1t = get_memref(element_ty, sizeK2); + auto k2t = get_memref(element_ty, sizeK2); + auto qvt = get_memref(element_ty, {B3_aligned_, P_}); + auto ivt = get_memref(element_ty, {B2_aligned_, P_}); + auto tmpvt = get_memref(element_ty, {B2_, P_}, {}, address_space::local); + auto a0 = + bb.add(make_subview(A(0), static_offsets3, static_sizes3(A_[0]), offsets3, {}, a0t)); + auto a1 = + bb.add(make_subview(A(1), static_offsets3, static_sizes3(A_[1]), offsets3, {}, a1t)); + auto a2 = + bb.add(make_subview(A(2), static_offsets3, static_sizes3(A_[2]), offsets3, {}, a2t)); + auto k0 = bb.add(make_subview(K(0), static_offsets2, sizeK2, {}, {}, k0t)); + auto k1 = bb.add(make_subview(K(1), static_offsets2, sizeK2, {}, {}, k1t)); + auto k2 = bb.add(make_subview(K(2), static_offsets2, sizeK2, {}, {}, k2t)); + auto qv = bb.add(make_subview(Q, static_offsets3, {B3_aligned_, P_, 0}, offsets3, {}, qvt)); + auto iv = bb.add(make_subview(I, static_offsets3, {B2_aligned_, P_, 0}, offsets3, {}, ivt)); + auto tmpv = bb.add(make_subview(tmp, static_offsets2, {B2_, P_}, {}, {}, tmpvt)); auto const c0 = bb.add(make_constant_zero(element_ty)); auto const c1 = bb.add(make_constant_one(element_ty)); bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a0, c0, tmp)); diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index cf30cae2..c5506b0c 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -766,7 +766,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *in /** * @brief Create subview instruction * - * @code %value = subview %a[%offset1:%size1,...,%offsetN:%sizeN] : type(%a) @endcode + * @code %value = subview %a[%offset1:%size1,...,%offsetN:%sizeN] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand @@ -780,6 +780,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *in * offset_list_size is 0 * @param size_list_size [in] number of dynamic sizes * @param size_list [in][range(0, size_list_size)] size array; may be nullptr if size_list_size is 0 + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -788,7 +789,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, const int64_t *static_offset_list, const int64_t *static_size_list, uint32_t offset_list_size, const tinytc_value_t *offset_list, uint32_t size_list_size, const tinytc_value_t *size_list, - const tinytc_location_t *loc); + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create store instruction @@ -848,23 +849,23 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @endcode * * @param instr [out] pointer to the inst object created + * @param loop_var_type [in] type of loop variable * @param from [in] loop begion * @param to [in] loop bound * @param step [in][optional] loop step; can be nullptr - * @param init_list_size [in] length of init_value_list and return_type_list - * @param init_value_list [in][range(0, init_list_size)] array of initial values; can be - * nullptr if init_value_list is 0 - * @param loop_var_type [in] type of loop variable + * @param init_return_list_size [in] length of init_value_list and return_type_list + * @param initial_value_list [in][range(0, init_return_list_size)] array of initial values; can be + * nullptr if init_return_list_size is 0 + * @param return_type_list [in][range(0, init_return_list_size)] array of return types; can be + * nullptr if init_return_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, - tinytc_value_t to, tinytc_value_t step, - uint32_t init_list_size, - const tinytc_value_t *init_value_list, - tinytc_data_type_t loop_var_type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create( + tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, tinytc_value_t from, tinytc_value_t to, + tinytc_value_t step, uint32_t init_return_list_size, const tinytc_value_t *initial_value_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); /** * @brief Create foreach loop @@ -876,19 +877,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinyt * @endcode * * @param instr [out] pointer to the inst object created + * @param loop_var_type [in] type of loop variable * @param dim [in] length of from and to array; must be > 0 * @param from_list [in][range(1, dim)] loop begion * @param to_list [in][range(1, dim)] loop bound - * @param loop_var_type [in] type of loop variable * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, uint32_t dim, - const tinytc_value_t *from_list, - const tinytc_value_t *to_list, - tinytc_data_type_t loop_var_type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create( + tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, uint32_t dim, + const tinytc_value_t *from_list, const tinytc_value_t *to_list, const tinytc_location_t *loc); /** * @brief Create if condition @@ -915,12 +914,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc * @brief Create work group instruction * * @code - * %value = work_group work_group_op %operand : type(%operand) + * %value = work_group work_group_op %operand : ty * @endcode * * @param instr [out] pointer to the inst object created * @param operation [in] Work group operation * @param operand [in] operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -928,6 +928,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc TINYTC_EXPORT tinytc_status_t tinytc_work_group_inst_create(tinytc_inst_t *instr, tinytc_work_group_operation_t operation, tinytc_value_t operand, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index cb80479d..360a4c9b 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1443,14 +1443,14 @@ inline inst make_subgroup_size(compiler_context const &ctx, location const &loc * contains "dynamic" * @param size_list Vector of sizes; need to add dynamic sizes here if static_size_list contains * "dynamic" + * @param ty Return type * @param loc Source code location * * @return Instruction */ inline inst make_subview(value a, array_view static_offset_list, - array_view static_size_list, - array_view offset_list = {}, array_view size_list = {}, - location const &loc = {}) { + array_view static_size_list, array_view offset_list, + array_view size_list, data_type ty, location const &loc = {}) { tinytc_inst_t instr; if (static_offset_list.size() != static_size_list.size()) { throw std::invalid_argument( @@ -1472,7 +1472,7 @@ inline inst make_subview(value a, array_view static_offset_list, const tinytc_value_t *sl = reinterpret_cast(size_list.data()); CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a, static_len, static_offset_list.data(), static_size_list.data(), offset_len, ol, size_len, - sl, &loc), + sl, ty, &loc), loc); return inst(instr); } @@ -1527,24 +1527,31 @@ inline inst make_sum(transpose tA, bool atomic, value alpha, value A, value beta /** * @brief Make for loop instruction * + * @param loop_var_type Type of loop variable * @param from Loop variable start * @param to Loop variable bound * @param step Loop variable step; can be {} * @param initial_value_list Array of initial values; can be {} - * @param loop_var_type Type of loop variable + * @param return_type_list Array of return types; can be {} * @param loc Source code location * * @return Instruction */ -inline inst make_for(value from, value to, value step, array_view initial_value_list, - data_type loop_var_type, location const &loc = {}) { +inline inst make_for(data_type loop_var_type, value from, value to, value step, + array_view initial_value_list, array_view return_type_list, + location const &loc = {}) { tinytc_inst_t instr; auto len = initial_value_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("initial value list too long"); } + if (len != return_type_list.size()) { + throw std::invalid_argument( + "initial value list must have the same length as the return type list"); + } const tinytc_value_t *il = reinterpret_cast(initial_value_list.data()); - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, from, to, step, len, il, loop_var_type, &loc), + CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var_type, from, to, step, len, il, + return_type_list.data(), &loc), loc); return inst(instr); } @@ -1552,15 +1559,15 @@ inline inst make_for(value from, value to, value step, array_view initial /** * @brief Make foreach loop instruction * + * @param loop_var_type Type of loop variable * @param from_list List of loop variable start * @param to_list List of loop variable bound - * @param loop_var_type Type of loop variable * @param loc Source code location * * @return Instruction */ -inline inst make_foreach(array_view from_list, array_view to_list, - data_type loop_var_type, location const &loc = {}) { +inline inst make_foreach(data_type loop_var_type, array_view from_list, + array_view to_list, location const &loc = {}) { tinytc_inst_t instr; if (from_list.size() != to_list.size()) { @@ -1576,7 +1583,7 @@ inline inst make_foreach(array_view from_list, array_view to_list, } const tinytc_value_t *fl = reinterpret_cast(from_list.data()); const tinytc_value_t *tl = reinterpret_cast(to_list.data()); - CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, from_len, fl, tl, loop_var_type, &loc), + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, loop_var_type, from_len, fl, tl, &loc), loc); return inst(instr); } @@ -1602,12 +1609,22 @@ inline inst make_if(value condition, array_view return_type_list = {} return inst(instr); } -inline inst make_work_group(work_group_operation operation, value operand, +/** + * @brief Make work group instruction + * + * @param operation Work group operation + * @param operand Operand + * @param ty Return type + * @param loc Location + * + * @return + */ +inline inst make_work_group(work_group_operation operation, value operand, data_type ty, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC( tinytc_work_group_inst_create(&instr, static_cast(operation), - operand, &loc), + operand, ty, &loc), loc); return inst(instr); } @@ -1862,8 +1879,8 @@ class region_builder { * @param loc Source code location */ template - void for_loop(value from, value to, data_type loop_var_ty, F &&f, location const &loc = {}) { - for_loop(std::move(from), std::move(to), nullptr, std::move(loop_var_ty), + void for_loop(data_type loop_var_ty, value from, value to, F &&f, location const &loc = {}) { + for_loop(std::move(loop_var_ty), std::move(from), std::move(to), nullptr, std::forward(f), loc); } /** @@ -1872,17 +1889,17 @@ class region_builder { * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type + * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound * @param step Loop variable step - * @param loop_var_ty Type of loop variable * @param f Functor * @param loc Source code location */ template - void for_loop(value from, value to, value step, data_type loop_var_ty, F &&f, + void for_loop(data_type loop_var_ty, value from, value to, value step, F &&f, location const &loc = {}) { - auto fi = ::tinytc::make_for(from, to, step, {}, loop_var_ty, loc); + auto fi = ::tinytc::make_for(loop_var_ty, from, to, step, {}, {}, loc); auto reg = region{}; fi.get_regions(reg); auto loop_var = value{}; @@ -1901,18 +1918,21 @@ class region_builder { * The following values are the loop-carried values. * * @tparam F Functor type + * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound * @param step Loop variable step * @param initial_value_list Array of initial values; can be {} - * @param loop_var_ty Type of loop variable + * @param return_type_list Array of return types; can be {} * @param f Functor * @param loc Source code location */ template - auto for_loop(value from, value to, value step, array_view initial_value_list, - data_type loop_var_ty, F &&f, location const &loc = {}) -> std::vector { - auto fi = ::tinytc::make_for(from, to, step, initial_value_list, loop_var_ty, loc); + auto for_loop(data_type loop_var_ty, value from, value to, value step, + array_view initial_value_list, array_view return_type_list, + F &&f, location const &loc = {}) -> std::vector { + auto fi = ::tinytc::make_for(loop_var_ty, from, to, step, initial_value_list, + return_type_list, loc); auto reg = region{}; fi.get_regions(reg); auto num_params = reg.get_parameters({}); @@ -1930,16 +1950,16 @@ class region_builder { * @brief Build foreach-loop with functor f(region_builder&, array_view) -> void * * @tparam F Functor type + * @param loop_var_ty Type of loop variable * @param from Loop variable start list * @param to Loop variable bound list - * @param loop_var_ty Type of loop variable * @param f functor * @param loc Source code location */ template - void foreach (array_view from, array_view to, data_type loop_var_ty, F && f, + void foreach (data_type loop_var_ty, array_view from, array_view to, F && f, location const &loc = {}) { - auto fi = ::tinytc::make_foreach(std::move(from), std::move(to), loop_var_ty, loc); + auto fi = ::tinytc::make_foreach(loop_var_ty, std::move(from), std::move(to), loc); auto reg = region{}; fi.get_regions(reg); auto num_params = reg.get_parameters({}); diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 6c8379a5..e7da9c42 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -51,54 +51,58 @@ typedef enum { tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices tinytc_status_ir_expected_boolean = 0x106, ///< Expected a value of boolean type tinytc_status_ir_expected_scalar = 0x107, ///< Expected a value of scalar type - tinytc_status_ir_expected_index = 0x108, ///< Expected a value of index type - tinytc_status_ir_expected_coopmatrix = 0x109, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_int = 0x108, ///< Expected a value of integer type + tinytc_status_ir_expected_float = 0x109, ///< Expected a value of float type + tinytc_status_ir_expected_complex = 0x10a, ///< Expected a value of complex type + tinytc_status_ir_expected_index = 0x10b, ///< Expected a value of index type + tinytc_status_ir_expected_coopmatrix = 0x10c, ///< Expected a value of coopmatrix type tinytc_status_ir_expected_coopmatrix_or_scalar = - 0x10a, ///< Expected a value of coopmatrix or scalar type + 0x10d, ///< Expected a value of coopmatrix or scalar type tinytc_status_ir_expected_coopmatrix_scalar_or_boolean = - 0x10b, ///< Expected a value of coopmatrix, scalar type, or boolean - tinytc_status_ir_expected_memref = 0x10c, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x10d, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x10e, ///< Expected a value of memref or group type - tinytc_status_ir_expected_memref_order_0 = 0x10f, ///< Expected memref of order 0 - tinytc_status_ir_expected_memref_order_1 = 0x110, ///< Expected memref of order 1 - tinytc_status_ir_expected_memref_order_2 = 0x111, ///< Expected memref of order 2 - tinytc_status_ir_expected_memref_order_0_or_1 = 0x112, ///< Expected memref of order 0 or 1 - tinytc_status_ir_expected_memref_order_1_or_2 = 0x113, ///< Expected memref of order 1 or 2 - tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x114, ///< Expected memref of order 0, 1 or 2 - tinytc_status_ir_unexpected_yield = 0x115, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x116, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x117, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x118, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x119, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x11a, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x11b, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x11c, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x11d, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x11e, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x11f, ///< Expected global address space - tinytc_status_ir_address_space_mismatch = 0x120, ///< Address space must match - tinytc_status_ir_invalid_offset = 0x121, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x122, ///< Instruction does not support int type - tinytc_status_ir_boolean_unsupported = 0x123, ///< Instruction does not support boolean type - tinytc_status_ir_complex_unsupported = 0x124, ///< Instruction does not support complex type + 0x10e, ///< Expected a value of coopmatrix, scalar type, or boolean + tinytc_status_ir_expected_memref = 0x10f, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x110, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x111, ///< Expected a value of memref or group type + tinytc_status_ir_expected_memref_order_0 = 0x112, ///< Expected memref of order 0 + tinytc_status_ir_expected_memref_order_1 = 0x113, ///< Expected memref of order 1 + tinytc_status_ir_expected_memref_order_2 = 0x114, ///< Expected memref of order 2 + tinytc_status_ir_expected_memref_order_0_or_1 = 0x115, ///< Expected memref of order 0 or 1 + tinytc_status_ir_expected_memref_order_1_or_2 = 0x116, ///< Expected memref of order 1 or 2 + tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x117, ///< Expected memref of order 0, 1 or 2 + tinytc_status_ir_unexpected_yield = 0x118, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x119, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x11a, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x11b, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x11c, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x11d, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x11e, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x11f, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x120, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x121, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x122, ///< Expected global address space + tinytc_status_ir_address_space_mismatch = 0x123, ///< Address space must match + tinytc_status_ir_invalid_offset = 0x124, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x125, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x126, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x127, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x125, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x126, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x127, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x128, ///< Mismatch of init values and returned values - tinytc_status_ir_invalid_matrix_use = 0x129, ///< Invalid matrix use - tinytc_status_ir_unsupported_coopmatrix_shape = 0x12a, ///< Unsupported coopmatrix shape - tinytc_status_ir_incompatible_scalar_types = 0x12b, ///< Incompatible scalar types - tinytc_status_ir_constant_mismatch = 0x12c, ///< Constant mismatch - tinytc_status_ir_insufficient_alignment = 0x12d, ///< Insufficient alignment - tinytc_status_ir_must_have_yield = 0x12e, ///< Must have yield instruction + 0x128, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x129, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x12a, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x12b, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x12c, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x12d, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x12e, ///< Incompatible scalar types + tinytc_status_ir_constant_mismatch = 0x12f, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x130, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x131, ///< Must have yield instruction tinytc_status_ir_yield_in_else_branch_missing = - 0x12f, ///< Must have yield instruction in else branch - tinytc_status_ir_from_to_mismatch = 0x130, ///< size(from) != size(to) in foreach + 0x132, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x133, ///< size(from) != size(to) in foreach tinytc_status_ir_operand_type_must_match_return_type = - 0x131, /// Operand type must match return type - tinytc_status_ir_invalid_stride = 0x132, ///< Invalid stride + 0x134, /// Operand type must match return type + tinytc_status_ir_invalid_stride = 0x135, ///< Invalid stride + tinytc_status_ir_init_return_type_mismatch = 0x136, ///< Init return type mismatch // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index f63c9902..64d7a6b5 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -61,6 +61,9 @@ enum class status { ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, ir_expected_boolean = tinytc_status_ir_expected_boolean, ir_expected_scalar = tinytc_status_ir_expected_scalar, + ir_expected_int = tinytc_status_ir_expected_int, + ir_expected_float = tinytc_status_ir_expected_float, + ir_expected_complex = tinytc_status_ir_expected_complex, ir_expected_index = tinytc_status_ir_expected_index, ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, @@ -105,6 +108,7 @@ enum class status { ir_from_to_mismatch = tinytc_status_ir_from_to_mismatch, ir_operand_type_must_match_return_type = tinytc_status_ir_operand_type_must_match_return_type, ir_invalid_stride = tinytc_status_ir_invalid_stride, + ir_init_return_type_mismatch = tinytc_status_ir_init_return_type_mismatch, spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, spirv_undefined_value = tinytc_status_spirv_undefined_value, spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index f8110a41..0210ead4 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -509,7 +509,7 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_cast, ity)); auto block_end = instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, blocks, ity)); - bb.for_loop(std::move(block_start), std::move(block_end), c_sgs_tiles, ity, + bb.for_loop(ity, std::move(block_start), std::move(block_end), c_sgs_tiles, [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }); }); @@ -566,7 +566,7 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); auto step_1 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, c_tiles, ity)); - bb.for_loop(std::move(block_start_1), std::move(block_end_1), std::move(step_1), ity, + bb.for_loop(ity, std::move(block_start_1), std::move(block_end_1), std::move(step_1), [&](region_builder &bb, value block) { body(bb, block, bs_1); }); }); @@ -577,7 +577,7 @@ void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int bloc auto tmp3 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::add, tmp3, tmp2, ity)); auto step = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, c_tiles, ity)); - bb.for_loop(std::move(block_start), loop_trip_count, std::move(step), ity, + bb.for_loop(ity, std::move(block_start), loop_trip_count, std::move(step), [&](region_builder &bb, value block) { body(bb, block, bs); }); } diff --git a/src/error.cpp b/src/error.cpp index 5368474c..8fead6f1 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -150,6 +150,12 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Expected boolean type"; case tinytc_status_ir_expected_scalar: return "Expected scalar type"; + case tinytc_status_ir_expected_int: + return "Expected integer type"; + case tinytc_status_ir_expected_float: + return "Expected floating point type"; + case tinytc_status_ir_expected_complex: + return "Expected complex type"; case tinytc_status_ir_expected_index: return "Expected index type"; case tinytc_status_ir_expected_coopmatrix: @@ -239,6 +245,9 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Type of operand must match return type"; case tinytc_status_ir_invalid_stride: return "Invalid stride"; + case tinytc_status_ir_init_return_type_mismatch: + return "Type of initializer does not match return type or the number of return types is " + "not equal the number of initializers"; // SPIR-V case tinytc_status_spirv_forbidden_forward_declaration: return "Forward declaration of id is forbidden"; diff --git a/src/inst.cpp b/src/inst.cpp index c4e1adf6..5eb9f418 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -581,7 +581,7 @@ tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t stat const int64_t *static_offset_list, const int64_t *static_size_list, uint32_t offset_list_size, const tinytc_value_t *offset_list, uint32_t size_list_size, const tinytc_value_t *size_list, - const tinytc_location_t *loc) { + tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || (static_list_size > 0 && (static_offset_list == nullptr || static_size_list == nullptr)) || (offset_list_size > 0 && offset_list == nullptr) || @@ -589,12 +589,12 @@ tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t stat return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(a, array_view{static_offset_list, static_list_size}, - array_view{static_size_list, static_list_size}, - array_view{offset_list, offset_list_size}, - array_view{size_list, size_list_size}, get_optional(loc)) - .release(); + *instr = std::make_unique(a, array_view{static_offset_list, static_list_size}, + array_view{static_size_list, static_list_size}, + array_view{offset_list, offset_list_size}, + array_view{size_list, size_list_size}, ty, + get_optional(loc)) + .release(); }); } @@ -627,37 +627,38 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t }); } -tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t from, tinytc_value_t to, - tinytc_value_t step, uint32_t init_list_size, +tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, + tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + uint32_t init_return_list_size, const tinytc_value_t *initial_value_list, - tinytc_data_type_t loop_var_type, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc) { if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr || - (init_list_size != 0 && initial_value_list == nullptr)) { + (init_return_list_size != 0 && + (initial_value_list == nullptr || return_type_list == nullptr))) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(from, to, step, - array_view{initial_value_list, init_list_size}, - loop_var_type, get_optional(loc)) + *instr = std::make_unique(loop_var_type, from, to, step, + array_view{initial_value_list, init_return_list_size}, + array_view{return_type_list, init_return_list_size}, + get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, uint32_t dim, - const tinytc_value_t *from_list, +tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, + uint32_t dim, const tinytc_value_t *from_list, const tinytc_value_t *to_list, - tinytc_data_type_t loop_var_type, const tinytc_location_t *loc) { if (instr == nullptr || loop_var_type == nullptr || from_list == nullptr || to_list == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(array_view{from_list, dim}, array_view{to_list, dim}, - loop_var_type, get_optional(loc)) - .release(); + *instr = std::make_unique(loop_var_type, array_view{from_list, dim}, + array_view{to_list, dim}, get_optional(loc)) + .release(); }); } @@ -679,14 +680,14 @@ tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condi tinytc_status_t tinytc_work_group_inst_create(tinytc_inst_t *instr, tinytc_work_group_operation_t operation, - tinytc_value_t operand, + tinytc_value_t operand, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || operand == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { *instr = std::make_unique(enum_cast(operation), - operand, get_optional(loc)) + operand, ty, get_optional(loc)) .release(); }); } diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index b19c67ed..d4d88411 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -113,7 +113,7 @@ void check_memref_shape(memref_data_type *rt, std::int64_t ri, memref_data_type if (rt->shape(ri) != ot->shape(oi)) { auto extra_info = std::ostringstream{} << "Size of mode " << ri << " does not match operand mode " << oi << " [" - << rt->shape(oi) << "!=" << ot->shape(oi) << "]"; + << rt->shape(ri) << "!=" << ot->shape(oi) << "]"; throw compilation_error(loc, status::ir_invalid_shape, std::move(extra_info).str()); } } @@ -122,7 +122,7 @@ void check_memref_stride(memref_data_type *rt, std::int64_t ri, memref_data_type if (!is_dynamic_value(rt->stride(ri)) && rt->stride(ri) != ot->stride(oi)) { auto extra_info = std::ostringstream{} << "Stride of mode " << ri << " does not match operand stride " << oi << " [" - << rt->stride(oi) << "!=" << ot->stride(oi) << "]"; + << rt->stride(ri) << "!=" << ot->stride(oi) << "]"; throw compilation_error(loc, status::ir_invalid_stride, std::move(extra_info).str()); } } @@ -133,6 +133,23 @@ void check_memref_mode(memref_data_type *rt, std::int64_t ri, memref_data_type * check_memref_stride(rt, ri, ot, oi, loc); } +auto get_and_check_memref_type_addrspace(tinytc_value const &operand, tinytc_data_type_t ty, + location const &loc) + -> std::pair { + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc, status::ir_expected_memref); + } + auto ot = get_memref_type(loc, operand); + if (rt->element_data_ty() != ot->element_data_ty()) { + throw compilation_error(loc, {&operand}, status::ir_scalar_mismatch); + } + if (rt->addrspace() != ot->addrspace()) { + throw compilation_error(loc, {&operand}, status::ir_address_space_mismatch); + } + return {ot, rt}; +} + blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, bool atomic, location const &lc) : standard_inst{tid}, atomic_(atomic) { @@ -615,24 +632,14 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { op(0, op0); for (std::size_t i = 0; i < expand_shape0.size(); ++i) { + check_index_ty(loc(), *expand_shape0[i]); op(1 + i, expand_shape0[i]); } loc(lc); - auto rt = dyn_cast(ty); - if (!rt) { - throw compilation_error(loc(), status::ir_expected_memref); - } + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); - auto m = get_memref_type(loc(), operand()); - if (rt->element_data_ty() != m->element_data_ty()) { - throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); - } - if (rt->addrspace() != m->addrspace()) { - throw compilation_error(loc(), {&operand()}, status::ir_address_space_mismatch); - } - - bool const range_ok = 0 <= expanded_mode_ && expanded_mode_ < m->dim(); + bool const range_ok = 0 <= expanded_mode_ && expanded_mode_ < ot->dim(); if (!range_ok) { throw compilation_error(loc(), {&operand()}, status::ir_out_of_bounds); } @@ -646,9 +653,9 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, } for (std::int64_t i = 0; i < expanded_mode_; ++i) { - check_memref_mode(rt, i, m, i, loc()); + check_memref_mode(rt, i, ot, i, loc()); } - auto stride = m->stride(expanded_mode_); + auto stride = ot->stride(expanded_mode_); for (std::size_t i = 0; i < static_expand_shape_.size(); ++i) { const auto mode = expanded_mode_ + i; if (rt->shape(mode) != static_expand_shape()[i]) { @@ -666,16 +673,16 @@ expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, ? dynamic : stride * rt->shape(mode); } - for (std::int64_t i = expanded_mode_ + 1; i < m->dim(); ++i) { - check_memref_mode(rt, i + static_expand_shape_.size() - 1, m, i, loc()); + for (std::int64_t i = expanded_mode_ + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i + static_expand_shape_.size() - 1, ot, i, loc()); } result(0) = value_node{ty, this, lc}; } -for_inst::for_inst(tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step0, - array_view init_values, tinytc_data_type_t loop_var_type, - location const &lc) +for_inst::for_inst(tinytc_data_type_t loop_var_type, tinytc_value_t from0, tinytc_value_t to0, + tinytc_value_t step0, array_view init_values, + array_view return_types, location const &lc) : loop_inst{IK::for_loop, (step0 ? 3 : 2) + static_cast(init_values.size()), static_cast(init_values.size())} { op(op_from, from0); @@ -687,15 +694,20 @@ for_inst::for_inst(tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step body().set_num_params(1 + init_values.size()); body().set_param(0, loop_var_type, lc); body().loc(lc); - for (std::size_t i = 0; i < init_values.size(); ++i) { - body().set_param(1 + i, init_values[i]->ty(), lc); - result(i) = value_node{init_values[i]->ty(), this, lc}; + for (std::size_t i = 0; i < return_types.size(); ++i) { + if (!isa(*return_types[i]) && !isa(*return_types[i]) && + !isa(*return_types[i])) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + } + body().set_param(1 + i, return_types[i], lc); + result(i) = value_node{return_types[i], this, lc}; + } + if (init_values.size() != return_types.size()) { + throw compilation_error(loc(), status::ir_init_return_type_mismatch); } for (std::size_t i = 0; i < init_values.size(); ++i) { - if (!isa(*init_values[i]->ty()) && - !isa(*init_values[i]->ty()) && - !isa(*init_values[i]->ty())) { - throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + if (init_values[i]->ty() != return_types[i]) { + throw compilation_error(loc(), {init_values[i]}, status::ir_init_return_type_mismatch); } op(op_init() + i, init_values[i]); } @@ -704,20 +716,26 @@ for_inst::for_inst(tinytc_value_t from0, tinytc_value_t to0, tinytc_value_t step auto lvt = get_scalar_type(loc(), loop_var()); auto fromt = get_scalar_type(loc(), from()); auto tot = get_scalar_type(loc(), to()); - bool step_ok = true; + + if (!is_integer_type(lvt->ty())) { + throw compilation_error(loc(), status::ir_expected_int); + } + if (lvt->ty() != fromt->ty()) { + throw compilation_error(loc(), {&from()}, status::ir_scalar_mismatch); + } + if (lvt->ty() != tot->ty()) { + throw compilation_error(loc(), {&to()}, status::ir_scalar_mismatch); + } if (has_step()) { auto stept = get_scalar_type(loc(), step()); - step_ok = lvt->ty() == stept->ty(); - } - - if (!is_integer_type(lvt->ty()) || lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || - !step_ok) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + if (lvt->ty() != stept->ty()) { + throw compilation_error(loc(), {&step()}, status::ir_scalar_mismatch); + } } } -foreach_inst::foreach_inst(array_view from, array_view to, - tinytc_data_type_t loop_var_type, location const &lc) +foreach_inst::foreach_inst(tinytc_data_type_t loop_var_type, array_view from, + array_view to, location const &lc) : loop_inst{IK::foreach_loop, static_cast(from.size() + to.size()), std::int64_t{0}} { std::int64_t op_no = 0; @@ -757,35 +775,24 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, tin op(0, op0); loc(lc); - auto rt = dyn_cast(ty); - if (!rt) { - throw compilation_error(loc(), status::ir_expected_memref); - } - - auto m = get_memref_type(loc(), operand()); - if (rt->element_data_ty() != m->element_data_ty()) { - throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); - } - if (rt->addrspace() != m->addrspace()) { - throw compilation_error(loc(), {&operand()}, status::ir_address_space_mismatch); - } + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); - bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); + bool const range_ok = 0 <= from_ && from_ < to_ && to_ < ot->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } for (std::int64_t i = 0; i < from_; ++i) { - check_memref_mode(rt, i, m, i, loc()); + check_memref_mode(rt, i, ot, i, loc()); } std::int64_t prod = 1; for (std::int64_t i = from_; i <= to_; ++i) { - if (is_dynamic_value(m->shape(i))) { + if (is_dynamic_value(ot->shape(i))) { prod = dynamic; break; } - prod *= m->shape(i); + prod *= ot->shape(i); } if (rt->shape(from_) != prod) { auto extra_info = std::ostringstream{} << "Size of mode " << from_ @@ -793,10 +800,10 @@ fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, tin << rt->shape(from_) << "!=" << prod << ")"; throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); } - check_memref_stride(rt, from_, m, from_, loc()); + check_memref_stride(rt, from_, ot, from_, loc()); - for (std::int64_t i = to_ + 1; i < m->dim(); ++i) { - check_memref_mode(rt, i - to_ + from_, m, i, loc()); + for (std::int64_t i = to_ + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i - to_ + from_, ot, i, loc()); } result(0) = value_node{ty, this, lc}; @@ -1016,25 +1023,28 @@ size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, tinytc_data_type_t t subview_inst::subview_inst(tinytc_value_t op0, array_view static_offsets0, array_view static_sizes0, array_view offsets0, array_view sizes0, - location const &lc) + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::subview, static_cast(1 + offsets0.size() + sizes0.size())}, static_offsets_(std::move(static_offsets0)), static_sizes_(std::move(static_sizes0)) { op(0, op0); { std::size_t i = 1; for (auto const &val : offsets0) { + check_index_ty(loc(), *val); op(i++, val); } num_dyn_offsets_ = i - 1; for (auto const &val : sizes0) { + check_index_ty(loc(), *val); op(i++, val); } } loc(lc); - auto m = get_memref_type(loc(), operand()); - if (m->dim() != static_cast(static_offsets_.size()) || - m->dim() != static_cast(static_sizes_.size())) { + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + if (ot->dim() != static_cast(static_offsets_.size()) || + ot->dim() != static_cast(static_sizes_.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } if (std::count(static_offsets_.begin(), static_offsets_.end(), dynamic) != num_dyn_offsets_ || @@ -1043,24 +1053,27 @@ subview_inst::subview_inst(tinytc_value_t op0, array_view static_o throw compilation_error(loc(), status::ir_subview_mismatch); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - for (std::int64_t i = 0; i < m->dim(); ++i) { + std::int64_t ri = 0; + for (std::int64_t i = 0; i < ot->dim(); ++i) { auto offset = static_offsets_[i]; auto size = static_sizes_[i]; if ((offset < 0 && !is_dynamic_value(offset)) || (size < 0 && !is_dynamic_value(size))) { throw compilation_error(loc(), status::ir_invalid_slice); } if (size > 0 || is_dynamic_value(size)) { - shape.push_back(size); - stride.push_back(m->stride(i)); + if (rt->shape(ri) != size) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match slice size [" + << rt->shape(ri) << "!=" << size << "]"; + throw compilation_error(loc(), status::ir_invalid_shape, + std::move(extra_info).str()); + } + check_memref_stride(rt, ri, ot, i, loc()); + ++ri; } } - auto result_ty = memref_data_type::get(m->element_data_ty(), shape, stride, m->addrspace()); - result(0) = value_node{result_ty, this, lc}; + result(0) = value_node{ty, this, lc}; } store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, @@ -1115,16 +1128,21 @@ sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinyt } work_group_inst::work_group_inst(work_group_operation operation, tinytc_value_t operand0, - location const &lc) + tinytc_data_type_t ty, location const &lc) : standard_inst{IK::work_group}, operation_(operation) { loc(lc); op(0, operand0); - if (!isa(*(operand().ty()))) { + if (!isa(*ty)) { throw compilation_error(loc(), status::ir_expected_scalar); } - result(0) = value_node{operand().ty(), this, lc}; + if (operand().ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } + + result(0) = value_node{ty, this, lc}; } yield_inst::yield_inst(array_view vals, location const &lc) diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 6c24b6b3..f2aca4d4 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -657,9 +657,9 @@ class for_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } enum op_number { op_from = 0, op_to = 1, op_step = 2 }; - for_inst(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - array_view init_values, tinytc_data_type_t loop_var_type, - location const &loc = {}); + for_inst(tinytc_data_type_t loop_var_type, tinytc_value_t from, tinytc_value_t to, + tinytc_value_t step, array_view init_values, + array_view return_types, location const &loc = {}); inline auto from() -> tinytc_value & { return op(op_from); } inline auto from() const -> tinytc_value const & { return op(op_from); } @@ -688,8 +688,8 @@ class for_inst : public loop_inst { class foreach_inst : public loop_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } - foreach_inst(array_view from, array_view to, - tinytc_data_type_t loop_var_type, location const &lc = {}); + foreach_inst(tinytc_data_type_t loop_var_type, array_view from, + array_view to, location const &lc = {}); inline auto dim() const -> std::int64_t { return num_operands() / 2; } inline auto loop_vars() { return body().params(); } @@ -792,7 +792,7 @@ class subview_inst : public standard_inst { inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } subview_inst(tinytc_value_t op, array_view static_offsets, array_view static_sizes, array_view offsets, - array_view sizes, location const &lc = {}); + array_view sizes, tinytc_data_type_t ty, location const &lc = {}); inline auto static_offsets() const -> array_view { return static_offsets_; } inline auto static_sizes() const -> array_view { return static_sizes_; } @@ -847,7 +847,7 @@ class sum_inst : public blas_a2_inst { class work_group_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::work_group; } - work_group_inst(work_group_operation operation, tinytc_value_t operand, + work_group_inst(work_group_operation operation, tinytc_value_t operand, tinytc_data_type_t ty, location const &lc = {}); inline auto operation() const -> work_group_operation { return operation_; } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index f1aea607..7c227855 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -53,14 +53,6 @@ namespace tinytc { - void check_type(tinytc_value_t val, tinytc_data_type_t ty, location &loc1, location &loc2) { - if (val->ty() != ty) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - } - void report_error(compiler_context const& cctx, compilation_error const& e) { if (e.extra_info().size() > 0) { auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); @@ -581,24 +573,13 @@ ger_inst: ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] for_loop_var_type EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] { - check_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_type($to, $for_loop_var_type, @to, @for_loop_var_type); - if ($optional_step) { - check_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); - } + FOR LOCAL_IDENTIFIER[loop_var] for_loop_var_type EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] { try { auto &[lcv_id, lcv_init, lcv_type] = $lcv; - if (lcv_init.size() != lcv_type.size()) { - throw parser::syntax_error(@lcv, "Length of init value list must match scalar type list"); - } - for (std::size_t i = 0; i < lcv_init.size(); ++i) { - check_type(lcv_init[i], lcv_type[i], @lcv, @lcv); - } location loc = @FOR; - loc.end = @for_loop_var_type.end; - auto inode = std::make_unique($from, $to, $optional_step, lcv_init, - $for_loop_var_type, loc); + loc.end = @lcv.end; + auto inode = std::make_unique($for_loop_var_type, $from, $to, $optional_step, + lcv_init, lcv_type, loc); ctx.push_scope(); auto &loop_var = inode->loop_var(); ctx.val($loop_var, loop_var, @loop_var); @@ -654,7 +635,7 @@ foreach_inst: location loc = @FOREACH; loc.end = @for_loop_var_type.end; auto inode = - std::make_unique($from, $to, $for_loop_var_type, loc); + std::make_unique($for_loop_var_type, $from, $to, loc); ctx.push_scope(); auto loop_vars = inode->loop_vars().begin(); for (std::int64_t i = 0; i < inode->dim(); ++i) { @@ -731,16 +712,8 @@ sum_inst: ; yield_inst: - YIELD optional_value_list[vals] COLON optional_return_type_list[tys] { - if ($vals.size() != $tys.size()) { - location loc = @vals; - loc.end = @tys.end; - throw syntax_error(loc, "Identifier and scalar type list must have the same length"); - } - for (std::size_t i = 0; i < $vals.size(); ++i) { - check_type($vals[i], $tys[i], @vals, @tys); - } - $$ = inst{std::make_unique(std::move($vals)).release()}; + YIELD LPAREN optional_value_list[vals] RPAREN { + $$ = inst{std::make_unique(std::move($vals), @yield_inst).release()}; } ; @@ -991,7 +964,6 @@ expand_shape: integer_constant_or_identifier: var { - check_type($var, get_scalar(ctx.cctx(), scalar_type::index), @var, @var); $$ = $var; } | INTEGER_CONSTANT { @@ -1058,7 +1030,6 @@ group_size_inst: if_inst: IF var[condition] optional_returned_values { - check_type($condition, get_boolean(ctx.cctx()), @condition, @condition); try { auto loc = @IF; loc.end = @optional_returned_values.end; @@ -1149,12 +1120,7 @@ subgroup_size_inst: ; subview_inst: - SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type { - if ($var->ty() != $memref_type) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type[ty] { try { auto static_offsets = std::vector{}; auto static_sizes = std::vector{}; @@ -1182,7 +1148,7 @@ subview_inst: } $$ = inst { std::make_unique(std::move($var), std::move(static_offsets), std::move(static_sizes), - std::move(offsets), std::move(sizes), @subview_inst) + std::move(offsets), std::move(sizes), std::move($ty), @subview_inst) .release() }; } catch (compilation_error const &e) { @@ -1221,10 +1187,10 @@ slice_size: work_group_inst: WORK_GROUP WORK_GROUP_OPERATION[operation] var[a] COLON data_type[ty] { - check_type($a, $ty, @a, @ty); try { $$ = inst { - std::make_unique($operation, std::move($a), @work_group_inst) + std::make_unique($operation, std::move($a), std::move($ty), + @work_group_inst) .release() }; } catch (compilation_error const &e) { diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp index b2938f7f..29d1c240 100644 --- a/src/pass/check_ir.cpp +++ b/src/pass/check_ir.cpp @@ -21,11 +21,11 @@ void check_ir_pass::check_yield(region_node const ®, inst_node const &in, throw compilation_error(reg.loc(), yield_missing_status); } if (yield->num_operands() != in.num_results()) { - throw compilation_error(reg.loc(), status::ir_yield_mismatch); + throw compilation_error(yield->loc(), status::ir_yield_mismatch); } for (std::int64_t i = 0; i < in.num_results(); ++i) { if (yield->op(i).ty() != in.result(i).ty()) { - throw compilation_error(reg.loc(), status::ir_yield_mismatch); + throw compilation_error(yield->loc(), {&yield->op(i)}, status::ir_yield_mismatch); } } } diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index d323ea41..8f38d9ac 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -104,14 +104,18 @@ auto inst_cloner::operator()(ger_inst &in) -> std::unique_ptr { subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); } auto inst_cloner::operator()(for_inst &in) -> std::unique_ptr { - return std::make_unique( - subs(&in.from()), subs(&in.to()), in.has_step() ? subs(&in.step()) : nullptr, - subs_value_range(in.iter_init()), in.body().param(0).ty(), in.loc()); + auto return_types = std::vector(in.num_results()); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + return_types[i] = in.result(0).ty(); + } + return std::make_unique(in.body().param(0).ty(), subs(&in.from()), subs(&in.to()), + in.has_step() ? subs(&in.step()) : nullptr, + subs_value_range(in.iter_init()), return_types, in.loc()); } auto inst_cloner::operator()(foreach_inst &in) -> std::unique_ptr { - return std::make_unique(subs_value_range(in.from()), subs_value_range(in.to()), - in.body().param(0).ty(), in.loc()); + return std::make_unique(in.body().param(0).ty(), subs_value_range(in.from()), + subs_value_range(in.to()), in.loc()); } auto inst_cloner::operator()(hadamard_inst &in) -> std::unique_ptr { @@ -152,9 +156,9 @@ auto inst_cloner::operator()(subgroup_size_inst &in) -> std::unique_ptr std::unique_ptr { - return std::make_unique(subs(&in.operand()), in.static_offsets(), - in.static_sizes(), subs_value_range(in.offsets()), - subs_value_range(in.sizes()), in.loc()); + return std::make_unique( + subs(&in.operand()), in.static_offsets(), in.static_sizes(), subs_value_range(in.offsets()), + subs_value_range(in.sizes()), in.result(0).ty(), in.loc()); } auto inst_cloner::operator()(store_inst &in) -> std::unique_ptr { @@ -168,7 +172,8 @@ auto inst_cloner::operator()(sum_inst &in) -> std::unique_ptr { } auto inst_cloner::operator()(work_group_inst &in) -> std::unique_ptr { - return std::make_unique(in.operation(), subs(&in.operand()), in.loc()); + return std::make_unique(in.operation(), subs(&in.operand()), in.result(0).ty(), + in.loc()); } auto inst_cloner::operator()(yield_inst &in) -> std::unique_ptr { diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index c49282e3..1613e824 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -456,9 +456,7 @@ void dump_ir_pass::operator()(subview_inst const &s) { } *os_ << "]"; *os_ << " : "; - visit(*this, *s.operand().ty()); - *os_ << " ; -> "; - visit(*this, *s.result()->ty()); + visit(*this, *s.result(0).ty()); } void dump_ir_pass::operator()(store_inst const &e) { @@ -488,19 +486,15 @@ void dump_ir_pass::operator()(work_group_inst const &in) { *os_ << " = work_group." << to_string(in.operation()) << " "; dump_val(in.operand()); *os_ << " : "; - visit(*this, *in.operand().ty()); + visit(*this, *in.result(0).ty()); } void dump_ir_pass::operator()(yield_inst const &y) { - *os_ << "yield "; + *os_ << "yield ("; if (y.num_operands() > 0) { do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { dump_val(i); }, ", "); - *os_ << " : "; - do_with_infix( - y.op_begin(), y.op_end(), [this](auto const &i) { visit(*this, *i.ty()); }, ", "); - } else { - *os_ << ":"; } + *os_ << ")"; } void dump_ir_pass::dump_region(region_node const ®) { diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp index 7d14b0fd..3a10affc 100644 --- a/src/pass/lower_foreach.cpp +++ b/src/pass/lower_foreach.cpp @@ -66,15 +66,16 @@ auto foreach_generator::operator()(foreach_inst &in) -> inst { auto const make_inner_loop_nest = [&](region_builder &bb, value from1, value to1) { tinytc_region_t current_region = bb.get_region().get(); for (std::int64_t i = in.dim() - 1; i > 1; --i) { - auto for_i = std::make_unique( - &from[i], &to[i], nullptr, array_view{}, ity, in.loc()); + auto for_i = std::make_unique(ity, &from[i], &to[i], nullptr, + array_view{}, + array_view{}, in.loc()); cloner.set_subs(&loop_vars[i], &for_i->loop_var()); tinytc_region_t next_region = &for_i->body(); current_region->insts().push_back(for_i.release()); current_region = next_region; } region_builder{current_region}.for_loop( - from1, to1, ity, + ity, from1, to1, [&](region_builder &bb, value loop_var1) { cloner.set_subs(&loop_vars[1], loop_var1.get()); cloner.clone_region(in.body(), *bb.get_region()); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 13a9dcf6..67c0117b 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -56,7 +56,8 @@ void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomi value c_acc) -> value { auto c_step = bb.add(make_constant(k_block_size, index_ty, loc)); auto return_values = bb.for_loop( - K0, K1, c_step, {c_acc}, index_ty, [&](region_builder &bb, array_view p) { + index_ty, K0, K1, c_step, {c_acc}, {c_acc->ty()}, + [&](region_builder &bb, array_view p) { const auto k = p[0]; value pos_a[2] = {m_block, k}; @@ -163,9 +164,9 @@ class linalg_generator { auto get_memref_type(value_node const &v) const -> const memref_data_type *; template - void add_foreach(array_view from, array_view to, - data_type loop_var_ty, F &&f, location const &loc = {}) { - auto fi = std::make_unique(std::move(from), std::move(to), loop_var_ty, loc); + void add_foreach(data_type loop_var_ty, array_view from, + array_view to, F &&f, location const &loc = {}) { + auto fi = std::make_unique(loop_var_ty, std::move(from), std::move(to), loc); auto bb = region_builder{&fi->body()}; f(bb, fi->loop_vars()); add(inst{fi.release()}); @@ -213,7 +214,7 @@ void linalg_generator::operator()(axpby_inst &in) { auto c0 = add(make_constant(0, index_ty, in.loc())); auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); add_foreach( - {c0.get()}, {c_shape0.get()}, index_ty, + index_ty, {c0.get()}, {c_shape0.get()}, [&](region_builder &bb, auto loop_vars) { auto a = bb.add(make_load(&in.A(), {&loop_vars[0]}, at->element_data_ty(), in.loc())); @@ -226,7 +227,7 @@ void linalg_generator::operator()(axpby_inst &in) { auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); auto c_shape1 = add(make_size(&in.B(), 1, index_ty, in.loc())); add_foreach( - {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, + index_ty, {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, [&](region_builder &bb, auto loop_vars) { auto a_idx = std::array{&loop_vars[0], &loop_vars[1]}; if (in.tA() == transpose::T) { @@ -246,7 +247,7 @@ void linalg_generator::operator()(ger_inst &in) { auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); auto c_shape1 = add(make_size(&in.C(), 1, index_ty, in.loc())); add_foreach( - {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, index_ty, + index_ty, {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, [&](region_builder &bb, auto loop_vars) { auto at = get_memref_type(in.A()); auto bt = get_memref_type(in.B()); @@ -337,13 +338,14 @@ void linalg_generator::operator()(gemv_inst &in) { auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); auto ct = get_memref_type(in.C()); add_foreach( - {c0.get()}, {c_shape0.get()}, index_ty, + index_ty, {c0.get()}, {c_shape0.get()}, [&](region_builder &bb, auto loop_vars) { auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); auto c_acc = bb.for_loop( - c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view p) { + index_ty, c0, K, {}, {c_init}, {ct->element_data_ty()}, + [&](region_builder &bb, array_view p) { auto a_idx = std::array{&loop_vars[0], p[0]}; if (in.tA() == transpose::T) { std::swap(a_idx[0], a_idx[1]); @@ -367,7 +369,7 @@ void linalg_generator::operator()(hadamard_inst &in) { auto c0 = add(make_constant(0, index_ty, in.loc())); auto c_shape0 = add(make_size(&in.C(), 0, index_ty, in.loc())); add_foreach( - {c0.get()}, {c_shape0.get()}, index_ty, + index_ty, {c0.get()}, {c_shape0.get()}, [&](region_builder &bb, auto loop_vars) { auto at = get_memref_type(in.A()); auto bt = get_memref_type(in.B()); @@ -410,13 +412,14 @@ void linalg_generator::operator()(sum_inst &in) { auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), in.loc())); auto acc = bb.for_loop( - from_index, c_trip_count, c_step, {c_init}, index_ty, + index_ty, from_index, c_trip_count, c_step, {c_init}, {bt->element_data_ty()}, [&](region_builder &bb, array_view args) { auto a = bb.add(make_load(&in.A(), {args[0]}, at->element_data_ty(), in.loc())); auto sum = mixed_precision_arithmetic(bb, arithmetic::add, args[1], a, in.loc()); bb.add(make_yield({sum}, in.loc())); }); - auto sum = bb.add(make_work_group(work_group_operation::reduce_add, acc[0], in.loc())); + auto sum = bb.add( + make_work_group(work_group_operation::reduce_add, acc[0], acc[0]->ty(), in.loc())); bb.if_condition( is_from_0, [&](region_builder &bb) { @@ -428,13 +431,14 @@ void linalg_generator::operator()(sum_inst &in) { auto c0 = add(make_constant(0, index_ty, in.loc())); auto c_shape0 = add(make_size(&in.B(), 0, index_ty, in.loc())); add_foreach( - {c0.get()}, {c_shape0.get()}, index_ty, + index_ty, {c0.get()}, {c_shape0.get()}, [&](region_builder &bb, auto loop_vars) { auto K = bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); auto c_init = bb.add(make_constant_zero(bt->element_data_ty())); auto acc = bb.for_loop( - c0, K, {}, {c_init}, index_ty, [&](region_builder &bb, array_view args) { + index_ty, c0, K, {}, {c_init}, {bt->element_data_ty()}, + [&](region_builder &bb, array_view args) { auto index_list = std::array{&loop_vars[0], args[0]}; if (in.tA() == transpose::T) { std::swap(index_list[0], index_list[1]); diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 56dff324..2d9ec76e 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -111,12 +111,18 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto at = + get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); + auto bt = + get_memref(ty_, B_static_sizes, {1, ldB}, address_space::global, my_loc()); + auto ct = + get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, at, my_loc())); auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, bt, my_loc())); auto c = bb.add(make_subview(params[4], static_offsets, C_static_sizes, - array_view{gid}, {}, my_loc())); + array_view{gid}, {}, ct, my_loc())); auto beta = is_beta_nonzero ? params[3] : bb.add(make_constant_zero(ty_, my_loc())); bb.add(make_gemm(tA_, tB_, false, params[0], std::move(a), std::move(b), beta, std::move(c), my_loc())); diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index b2c3d2bb..1ef845e3 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -119,10 +119,14 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const static_gemm = [&](region_builder &bb) { auto const A_static_sizes = std::array{M_block_size, K}; auto const C_static_sizes = std::array{M_block_size, N}; + auto at = + get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); + auto ct = + get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); auto a = bb.add( - make_subview(A, static_offsets, A_static_sizes, offsets, {}, my_loc())); + make_subview(A, static_offsets, A_static_sizes, offsets, {}, at, my_loc())); auto c = bb.add( - make_subview(C, static_offsets, C_static_sizes, offsets, {}, my_loc())); + make_subview(C, static_offsets, C_static_sizes, offsets, {}, ct, my_loc())); bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; @@ -130,10 +134,14 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const A_static_sizes = std::array{dynamic, K}; auto const C_static_sizes = std::array{dynamic, N}; auto const sizes = array_view(dyn_block_size); - auto a = bb.add( - make_subview(A, static_offsets, A_static_sizes, offsets, sizes, my_loc())); - auto c = bb.add( - make_subview(C, static_offsets, C_static_sizes, offsets, sizes, my_loc())); + auto at = + get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); + auto ct = + get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); + auto a = bb.add(make_subview(A, static_offsets, A_static_sizes, offsets, sizes, + at, my_loc())); + auto c = bb.add(make_subview(C, static_offsets, C_static_sizes, offsets, sizes, + ct, my_loc())); bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir index 5a19eff2..0986ad5c 100644 --- a/test/codegen/axpby1.ir +++ b/test/codegen/axpby1.ir @@ -22,12 +22,12 @@ func @axpby3(%alpha: f32, %A: memref, %B: memref) %lb = constant 0 : index %ub = constant 5 : index for %i=%lb,%ub { - %A0 = subview %A[0:48,0:48,0:4,%i] : memref - %B0 = subview %B[0:48,0:48,0:4,%i] : memref + %A0 = subview %A[0:48,0:48,0:4,%i] : memref + %B0 = subview %B[0:48,0:48,0:4,%i] : memref %ub1 = constant 4 : index for %j=%lb,%ub1 { - %A1 = subview %A0[0:48,0:48,%j] : memref - %B1 = subview %B0[0:48,0:48,%j] : memref + %A1 = subview %A0[0:48,0:48,%j] : memref + %B1 = subview %B0[0:48,0:48,%j] : memref axpby.t %alpha, %A1, %z, %B1 } } diff --git a/test/codegen/dope_vector0.ir b/test/codegen/dope_vector0.ir index 71d6bcae..53cd8c33 100644 --- a/test/codegen/dope_vector0.ir +++ b/test/codegen/dope_vector0.ir @@ -3,14 +3,14 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @kernel(%K0: memref, %offset: index, %size: index) { - %0 = subview %K0[4:%size, %offset] : memref + %0 = subview %K0[4:%size, %offset] : memref ; CHECK: void kernel({{.*}} ; CHECK-NEXT: global float* x = K0 + 4ll * 1 + offset * K0_stride1; ; CHECK-NEXT: long x_shape0 = size; } func @kernel2(%K0: memref, %offset: index, %size: index) { - %0 = subview %K0[%offset, 4:%size] : memref + %0 = subview %K0[%offset, 4:%size] : memref> ; CHECK: void kernel2({{.*}} ; CHECK-NEXT: global float* x = K0 + offset * 1 + 4ll * K0_stride1; ; CHECK-NEXT: long x_shape0 = size; diff --git a/test/codegen/for.ir b/test/codegen/for.ir index 0e6a8fd1..cbf47dad 100644 --- a/test/codegen/for.ir +++ b/test/codegen/for.ir @@ -23,7 +23,7 @@ func @for2(%fib: memref) { %f1 = constant 1 : i64 %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 - yield %fn_1, %fn : i64, i64 + yield (%fn_1, %fn) } store %fn, %fib[] ; CHECK-LABEL: void for2({{.*}} diff --git a/test/codegen/if.ir b/test/codegen/if.ir index dbb44400..3d45f546 100644 --- a/test/codegen/if.ir +++ b/test/codegen/if.ir @@ -32,9 +32,9 @@ func @if2(%0: i32) { %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool if %1 -> () { - yield : + yield () } else { - yield : + yield () } ; CHECK: if (x1) { ; CHECK-NEXT: } else { @@ -45,9 +45,9 @@ func @if3(%0: i32) { %c16 = constant 16 : i32 %1 = cmp.lt %0, %c16 : bool %x = if %1 -> (i32) { - yield %0 : i32 + yield (%0) } else { - yield %c16 : i32 + yield (%c16) } ; CHECK: int x2; ; CHECK-NEXT: if (x1) { @@ -64,16 +64,16 @@ func @if4(%0: i32) { if %1 { } %one = constant 1.0 : f32 - yield %0, %one : i32, f32 + yield (%0, %one) } else { %z = if %1 -> (f32) { %one = constant 1.0 : f32 - yield %one : f32 + yield (%one) } else { %zero = constant 0.0 : f32 - yield %zero : f32 + yield (%zero) } - yield %c16, %z : i32, f32 + yield (%c16, %z) } ; CHECK: int x2; ; CHECK-NEXT: float y; diff --git a/test/codegen/type_mismatch1.ir b/test/codegen/type_mismatch1.ir index 230ab15f..c7e05a22 100644 --- a/test/codegen/type_mismatch1.ir +++ b/test/codegen/type_mismatch1.ir @@ -4,7 +4,7 @@ ; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s func @kernel(%K0: memref, %x: index, %y: index) { %z = constant 0 : index - %0 = subview %K0[0:%x] : memref + %0 = subview %K0[0:%x] : memref %1 = subview %0[0:%y] : memref %2 = load %1[%z] : f64 %3 = load %1[%z] : f64 diff --git a/test/opt/check-ir/subview.ir b/test/opt/check-ir/subview.ir index 52e664e2..c1736a64 100644 --- a/test/opt/check-ir/subview.ir +++ b/test/opt/check-ir/subview.ir @@ -7,26 +7,26 @@ ; CHECK: func @t1({{.*}} func @t1(%0: memref) { - %1 = subview %0[4:8,8:4] : memref + %1 = subview %0[4:8,8:4] : memref> } func @t2(%0: memref, %1: index) { - %2 = subview %0[2:4,%1] : memref + %2 = subview %0[2:4,%1] : memref } func @t3(%0: memref, %1: index) { - %2 = subview %0[2:4,%1:0] : memref + %2 = subview %0[2:4,%1:0] : memref } func @t4(%0: memref, %1: index) { - %2 = subview %0[2:4,%1:1] : memref + %2 = subview %0[2:4,%1:1] : memref> } func @t5(%0: memref, %1: index) { - %2 = subview %0[%1:4] : memref + %2 = subview %0[%1:4] : memref } func @t6(%0: memref, %1: index) { - %2 = subview %0[%1:%1] : memref + %2 = subview %0[%1:%1] : memref } func @t7(%0: memref, %1: index) { - %2 = subview %0[2:4, %1:%1, 6:7] : memref + %2 = subview %0[2:4, %1:%1, 6:7] : memref> } func @t8(%0: memref>, %1: index) { - %2 = subview %0[2:4, %1:%1, 6:7] : memref> + %2 = subview %0[2:4, %1:%1, 6:7] : memref> } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index 0df6f23e..fd2ff580 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -41,7 +41,7 @@ func @known_loop_iter_args() { %c5 = constant 5 : index %0 = arith.add %c1, %c5 : index %2 = for %i=%c1,%c5 init(%1=%0) -> (index) { - yield %1 : index + yield (%1) } ; CHECK-LABEL: func @known_loop_iter_args({{.*}} ; CHECK-NEXT: %c1 = constant 1 : index @@ -49,7 +49,7 @@ func @known_loop_iter_args() { ; CHECK-NEXT: %0 = constant 6 : index ; CHECK-NEXT: %1 = arith.add %c1, %c5 : index ; CHECK-NEXT: %3 = for %i:index=%c1,%c5 init(%2=%0) -> (index) { -; CHECK-NEXT: yield %2 : index +; CHECK-NEXT: yield (%2) ; CHECK-NEXT: } } diff --git a/test/opt/dead-code-elimination.ir b/test/opt/dead-code-elimination.ir index ca5b644e..6bd29473 100644 --- a/test/opt/dead-code-elimination.ir +++ b/test/opt/dead-code-elimination.ir @@ -25,20 +25,20 @@ func @dead_if_with_yield(%a: memref) { %c0 = constant false : bool %0 = if %c0 -> (f64) { %c42 = constant 42.0 : f64 - yield %c42 : f64 + yield (%c42) } else { %c43 = constant 43.0 : f64 - yield %c43 : f64 + yield (%c43) } store %0, %a[] ; Cannot eliminate if that returns results currently ; CHECK-LABEL: func @dead_if_with_yield({{.*}} ; CHECK: %0 = if %c0 -> (f64) { ; CHECK-NEXT: %c42 = constant 0x1.5p+5 : f64 -; CHECK-NEXT: yield %c42 : f64 +; CHECK-NEXT: yield (%c42) ; CHECK-NEXT: } else { ; CHECK-NEXT: %c43 = constant 0x1.58p+5 : f64 -; CHECK-NEXT: yield %c43 : f64 +; CHECK-NEXT: yield (%c43) ; CHECK-NEXT: } ; CHECK-NEXT: store %0, %a[] } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index 023b5378..cb718c0c 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -70,7 +70,7 @@ func @respect_manual_barrier(%a: f32, %b: f32, %A: memref, %B: memref, %C: memref) { %B = alloca : memref - %0 = subview %B[0:8,0:8] : memref + %0 = subview %B[0:8,0:8] : memref axpby.n %a, %B, %b, %C axpby.n %a, %A, %b, %0 ; CHECK-LABEL: func @war_alias({{.*}} diff --git a/test/opt/insert-lifetime-stop.ir b/test/opt/insert-lifetime-stop.ir index 3b318188..090de005 100644 --- a/test/opt/insert-lifetime-stop.ir +++ b/test/opt/insert-lifetime-stop.ir @@ -36,7 +36,7 @@ func @use_alias(%a: f32, %A: memref, %C: memref) { ; CHECK-LABEL: func @use_alias{{.*}} %B = alloca : memref %0 = fuse %B[1,3] : memref - %1 = subview %0[0:8,0:8] : memref + %1 = subview %0[0:8,0:8] : memref,local> gemm.n.n %a, %A, %1, %a, %C ; CHECK: gemm.n.n{{.*}} ; CHECK-NEXT: lifetime_stop %B diff --git a/test/spv/for.ir b/test/spv/for.ir index d7ce4aa2..659fe509 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -46,7 +46,7 @@ func @for2() { %f1 = constant 1 : i64 %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { %fn = arith.add %fn_2, %fn_1 : i64 - yield %fn_1, %fn : i64, i64 + yield (%fn_1, %fn) } %neg_fn = arith.neg %fn : i64 ; CHECK: %[[#]] = OpFunction {{.*}} @@ -77,7 +77,7 @@ func @for3() subgroup_size(16) { %m_init = constant 1 : coopmatrix %m = for %n:i16=%from,%to init(%m_iter=%m_init) -> (coopmatrix) { %m_update = arith.add %m_iter, %m_init : coopmatrix - yield %m_update : coopmatrix + yield (%m_update) } %neg_m = arith.neg %m : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} diff --git a/test/spv/if.ir b/test/spv/if.ir index 7dfc728b..1a4182e0 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -35,9 +35,9 @@ func @if0(%0: i32) { func @if1() { %c1 = constant true : bool if %c1 -> (){ - yield : + yield () } else { - yield : + yield () } ; Just check that it does not crash ; CHECK: %[[#]] = OpFunction {{.*}} @@ -47,15 +47,15 @@ func @if2(%0: i32) { %c1 = constant true : bool %x = if %c1 -> (i32) { %1 = if %c1 -> (i32) { - yield %0 : i32 + yield (%0) } else { %c0 = constant 0 : i32 - yield %c0 : i32 + yield (%c0) } - yield %1 : i32 + yield (%1) } else { %1 = arith.not %0 : i32 - yield %1 : i32 + yield (%1) } %y = arith.not %x : i32 ; CHECK: %[[#]] = OpFunction {{.*}} @@ -85,10 +85,10 @@ func @if3() subgroup_size(16) { %c1 = constant true : bool %y, %x = if %c1 -> (bool,coopmatrix) { %0 = constant 1.0 : coopmatrix - yield %c1, %0 : bool, coopmatrix + yield (%c1, %0) } else { %1 = constant 0.0 : coopmatrix - yield %c1, %1 : bool, coopmatrix + yield (%c1, %1) } %z = arith.neg %x : coopmatrix ; CHECK: %[[#]] = OpFunction {{.*}} diff --git a/test/spv/subview.ir b/test/spv/subview.ir index 1548c147..9d6bfd80 100644 --- a/test/spv/subview.ir +++ b/test/spv/subview.ir @@ -11,8 +11,8 @@ ; CHECK: %[[#I64_C4:]] = OpConstant %[[#I64]] 4 func @sv1(%K0: memref, %offset: index, %size: index) { - %0 = subview %K0[4:%size, %offset] : memref - %1 = subview %K0[%offset, 4:%size] : memref + %0 = subview %K0[4:%size, %offset] : memref + %1 = subview %K0[%offset, 4:%size] : memref> %2 = size %0[0] : index %3 = arith.not %2 : index %4 = size %1[0] : index From b4a7b993fa93d300003c631f4d22a3917a869159 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 08:51:31 +0100 Subject: [PATCH 125/297] Replace several instructions by builtin instruction Signed-off-by: Carsten Uphoff --- docs/api/builder_capi.rst | 63 +++++---------- docs/api/builder_capi.yaml | 9 +-- docs/api/builder_cxxapi.rst | 79 +++++++----------- docs/api/builder_cxxapi.yaml | 13 ++- docs/manual/tensor-ir.rst | 103 +++++++++--------------- examples/benchmark/main.cpp | 2 +- examples/matrix_chain/test_ader.cpp | 3 +- examples/matrix_chain/test_volume.cpp | 2 +- include/tinytc/tinytc.h | 108 +++++-------------------- include/tinytc/tinytc.hpp | 111 +++++++------------------- include/tinytc/types.h | 99 +++++++++++++---------- include/tinytc/types.hpp | 11 +++ src/error.cpp | 2 + src/inst.cpp | 87 +++++++------------- src/node/inst_node.cpp | 52 ++++++++++-- src/node/inst_node.hpp | 98 +++++------------------ src/parser/lexer.re | 14 ++-- src/parser/parser_impl.yy | 52 +++--------- src/pass/clone.cpp | 27 +------ src/pass/clone.hpp | 7 +- src/pass/convert_to_opencl.cpp | 64 +++++---------- src/pass/convert_to_opencl.hpp | 7 +- src/pass/dump_ir.cpp | 36 ++------- src/pass/dump_ir.hpp | 7 +- src/pass/lower_foreach.cpp | 6 +- src/pass/lower_linalg.cpp | 10 +-- src/recipe/small_gemm_batched.cpp | 3 +- src/recipe/tall_and_skinny.cpp | 2 +- src/spv/converter.cpp | 58 +++++++------- src/spv/converter.hpp | 7 +- test/codegen/load.ir | 4 +- test/codegen/store.ir | 2 +- test/codegen/subgroup.ir | 8 +- test/opt/check-ir/nesting2.ir | 4 +- test/opt/insert-barrier.ir | 4 +- test/spv/builtin.ir | 12 +-- 36 files changed, 424 insertions(+), 752 deletions(-) diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 8194f2f3..98aaec96 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -16,6 +16,8 @@ Common * :ref:`tinytc_arithmetic_unary_t` + * :ref:`tinytc_builtin_t` + * :ref:`tinytc_checked_flag_t` * :ref:`tinytc_cmp_condition_t` @@ -42,6 +44,8 @@ Common * :ref:`tinytc_arithmetic_unary_to_string` + * :ref:`tinytc_builtin_to_string` + * :ref:`tinytc_checked_flag_to_string` * :ref:`tinytc_cmp_condition_to_string` @@ -112,6 +116,11 @@ tinytc_arithmetic_unary_t .. doxygenenum:: tinytc_arithmetic_unary_t +tinytc_builtin_t +................ + +.. doxygenenum:: tinytc_builtin_t + tinytc_checked_flag_t ..................... @@ -173,6 +182,11 @@ tinytc_arithmetic_unary_to_string .. doxygenfunction:: tinytc_arithmetic_unary_to_string +tinytc_builtin_to_string +........................ + +.. doxygenfunction:: tinytc_builtin_to_string + tinytc_checked_flag_to_string ............................. @@ -398,6 +412,8 @@ Instruction * :ref:`tinytc_arith_unary_inst_create` + * :ref:`tinytc_builtin_inst_create` + * :ref:`tinytc_cast_inst_create` * :ref:`tinytc_cmp_inst_create` @@ -436,30 +452,18 @@ Instruction * :ref:`tinytc_ger_inst_create` - * :ref:`tinytc_group_id_inst_create` - - * :ref:`tinytc_group_size_inst_create` - * :ref:`tinytc_hadamard_inst_create` * :ref:`tinytc_if_inst_create` * :ref:`tinytc_load_inst_create` - * :ref:`tinytc_num_subgroups_inst_create` - * :ref:`tinytc_parallel_inst_create` * :ref:`tinytc_size_inst_create` * :ref:`tinytc_store_inst_create` - * :ref:`tinytc_subgroup_id_inst_create` - - * :ref:`tinytc_subgroup_local_id_inst_create` - - * :ref:`tinytc_subgroup_size_inst_create` - * :ref:`tinytc_subview_inst_create` * :ref:`tinytc_sum_inst_create` @@ -497,6 +501,11 @@ tinytc_arith_unary_inst_create .. doxygenfunction:: tinytc_arith_unary_inst_create +tinytc_builtin_inst_create +.......................... + +.. doxygenfunction:: tinytc_builtin_inst_create + tinytc_cast_inst_create ....................... @@ -592,16 +601,6 @@ tinytc_ger_inst_create .. doxygenfunction:: tinytc_ger_inst_create -tinytc_group_id_inst_create -........................... - -.. doxygenfunction:: tinytc_group_id_inst_create - -tinytc_group_size_inst_create -............................. - -.. doxygenfunction:: tinytc_group_size_inst_create - tinytc_hadamard_inst_create ........................... @@ -617,11 +616,6 @@ tinytc_load_inst_create .. doxygenfunction:: tinytc_load_inst_create -tinytc_num_subgroups_inst_create -................................ - -.. doxygenfunction:: tinytc_num_subgroups_inst_create - tinytc_parallel_inst_create ........................... @@ -637,21 +631,6 @@ tinytc_store_inst_create .. doxygenfunction:: tinytc_store_inst_create -tinytc_subgroup_id_inst_create -.............................. - -.. doxygenfunction:: tinytc_subgroup_id_inst_create - -tinytc_subgroup_local_id_inst_create -.................................... - -.. doxygenfunction:: tinytc_subgroup_local_id_inst_create - -tinytc_subgroup_size_inst_create -................................ - -.. doxygenfunction:: tinytc_subgroup_size_inst_create - tinytc_subview_inst_create .......................... diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 27a0017c..5c255ab6 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -6,6 +6,7 @@ Builder C-API: - tinytc_address_space_t - tinytc_arithmetic_t - tinytc_arithmetic_unary_t + - tinytc_builtin_t - tinytc_checked_flag_t - tinytc_cmp_condition_t - tinytc_matrix_use_t @@ -19,6 +20,7 @@ Builder C-API: - tinytc_address_space_to_string - tinytc_arithmetic_to_string - tinytc_arithmetic_unary_to_string + - tinytc_builtin_to_string - tinytc_checked_flag_to_string - tinytc_cmp_condition_to_string - tinytc_matrix_use_to_string @@ -65,6 +67,7 @@ Builder C-API: - tinytc_axpby_inst_create - tinytc_arith_inst_create - tinytc_arith_unary_inst_create + - tinytc_builtin_inst_create - tinytc_cast_inst_create - tinytc_cmp_inst_create - tinytc_constant_inst_create_boolean @@ -84,18 +87,12 @@ Builder C-API: - tinytc_gemm_inst_create - tinytc_gemv_inst_create - tinytc_ger_inst_create - - tinytc_group_id_inst_create - - tinytc_group_size_inst_create - tinytc_hadamard_inst_create - tinytc_if_inst_create - tinytc_load_inst_create - - tinytc_num_subgroups_inst_create - tinytc_parallel_inst_create - tinytc_size_inst_create - tinytc_store_inst_create - - tinytc_subgroup_id_inst_create - - tinytc_subgroup_local_id_inst_create - - tinytc_subgroup_size_inst_create - tinytc_subview_inst_create - tinytc_sum_inst_create - tinytc_work_group_inst_create diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 2187115a..1a5fe645 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -16,6 +16,8 @@ Common * :ref:`arithmetic_unary` + * :ref:`builtin` + * :ref:`cmp_condition` * :ref:`matrix_use` @@ -38,6 +40,8 @@ Common * :ref:`to_string(arithmetic_unary)` + * :ref:`to_string(builtin)` + * :ref:`to_string(checked_flag)` * :ref:`to_string(cmp_condition)` @@ -86,6 +90,11 @@ arithmetic_unary .. doxygenenum:: tinytc::arithmetic_unary +builtin +....... + +.. doxygenenum:: tinytc::builtin + cmp_condition ............. @@ -139,6 +148,11 @@ to_string(arithmetic_unary) .. doxygenfunction:: tinytc::to_string(arithmetic_unary) +to_string(builtin) +.................. + +.. doxygenfunction:: tinytc::to_string(builtin) + to_string(checked_flag) ....................... @@ -323,9 +337,11 @@ Instruction * :ref:`make_axpby` - * :ref:`make_arith(arithmetic,value,value,location const&)` + * :ref:`make_arith(arithmetic,value,value,data_type,location const&)` + + * :ref:`make_arith(arithmetic_unary,value,data_type,location const&)` - * :ref:`make_arith(arithmetic_unary,value,location const&)` + * :ref:`make_builtin` * :ref:`make_cast` @@ -367,30 +383,18 @@ Instruction * :ref:`make_ger` - * :ref:`make_group_id` - - * :ref:`make_group_size` - * :ref:`make_hadamard` * :ref:`make_if` * :ref:`make_load` - * :ref:`make_num_subgroups` - * :ref:`make_parallel` * :ref:`make_size` * :ref:`make_store` - * :ref:`make_subgroup_id` - - * :ref:`make_subgroup_local_id` - - * :ref:`make_subgroup_size` - * :ref:`make_subview` * :ref:`make_sum` @@ -416,15 +420,20 @@ make_axpby .. doxygenfunction:: tinytc::make_axpby -make_arith(arithmetic,value,value,location const&) -.................................................. +make_arith(arithmetic,value,value,data_type,location const&) +............................................................ + +.. doxygenfunction:: tinytc::make_arith(arithmetic,value,value,data_type,location const&) -.. doxygenfunction:: tinytc::make_arith(arithmetic,value,value,location const&) +make_arith(arithmetic_unary,value,data_type,location const&) +............................................................ -make_arith(arithmetic_unary,value,location const&) -.................................................. +.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value,data_type,location const&) + +make_builtin +............ -.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value,location const&) +.. doxygenfunction:: tinytc::make_builtin make_cast ......... @@ -526,16 +535,6 @@ make_ger .. doxygenfunction:: tinytc::make_ger -make_group_id -............. - -.. doxygenfunction:: tinytc::make_group_id - -make_group_size -............... - -.. doxygenfunction:: tinytc::make_group_size - make_hadamard ............. @@ -551,11 +550,6 @@ make_load .. doxygenfunction:: tinytc::make_load -make_num_subgroups -.................. - -.. doxygenfunction:: tinytc::make_num_subgroups - make_parallel ............. @@ -571,21 +565,6 @@ make_store .. doxygenfunction:: tinytc::make_store -make_subgroup_id -................ - -.. doxygenfunction:: tinytc::make_subgroup_id - -make_subgroup_local_id -...................... - -.. doxygenfunction:: tinytc::make_subgroup_local_id - -make_subgroup_size -.................. - -.. doxygenfunction:: tinytc::make_subgroup_size - make_subview ............ diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index c79be440..7e4d2b57 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -6,6 +6,7 @@ Builder C++-API: - tinytc::address_space - tinytc::arithmetic - tinytc::arithmetic_unary + - tinytc::builtin - tinytc::cmp_condition - tinytc::matrix_use - tinytc::scalar_type @@ -17,6 +18,7 @@ Builder C++-API: - tinytc::to_string(address_space) - tinytc::to_string(arithmetic) - tinytc::to_string(arithmetic_unary) + - tinytc::to_string(builtin) - tinytc::to_string(checked_flag) - tinytc::to_string(cmp_condition) - tinytc::to_string(matrix_use) @@ -54,8 +56,9 @@ Builder C++-API: function: - tinytc::make_alloca - tinytc::make_axpby - - tinytc::make_arith(arithmetic,value,value,location const&) - - tinytc::make_arith(arithmetic_unary,value,location const&) + - tinytc::make_arith(arithmetic,value,value,data_type,location const&) + - tinytc::make_arith(arithmetic_unary,value,data_type,location const&) + - tinytc::make_builtin - tinytc::make_cast - tinytc::make_cmp - tinytc::make_constant(bool,data_type,location const&) @@ -76,18 +79,12 @@ Builder C++-API: - tinytc::make_gemm - tinytc::make_gemv - tinytc::make_ger - - tinytc::make_group_id - - tinytc::make_group_size - tinytc::make_hadamard - tinytc::make_if - tinytc::make_load - - tinytc::make_num_subgroups - tinytc::make_parallel - tinytc::make_size - tinytc::make_store - - tinytc::make_subgroup_id - - tinytc::make_subgroup_local_id - - tinytc::make_subgroup_size - tinytc::make_subview - tinytc::make_sum - tinytc::make_work_group diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index ec7d6fee..9aac2c82 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -823,6 +823,32 @@ Attribute Description .local Ensure that local memory accesses become visible to the work-group. ========= ====================================================================================== +Builtin (mixed) +............... + +.. code:: abnf + + mixed-builtin-type = ".group_id" / + ".group_size" / + ".num_subgroups" / + ".subgroup_size" + value-instruction =/ "builtin" mixed-builtin-type ":" integer-type + +Overview +~~~~~~~~ + +Returns a builtin value. +The following table shows the builtins' description and the types that are returned. + +============== ===== ==================================================================== +Builtin Type Description +============== ===== ==================================================================== +.group_id index Returns the group id, an integer inbetween 0 and the group size - 1 +.group_size index Returns the group size +.num_subgroups i32 Returns the number of subgroups the work-group is divided in +.subgroup_size i32 Returns the subgroup size +============== ===== ==================================================================== + Cast .... @@ -1273,30 +1299,6 @@ The memref type of the result must conform with the following rules: fuse %0[1,2] : memref> ; %0: memref> fuse %0[0,1] : memref> ; %0: memref> -Group id -........ - -.. code:: abnf - - value-instruction =/ "group_id" ":" "index" - -Overview -~~~~~~~~ - -Returns the group id, an integer of type "index" inbetween 0 and the group size - 1. - -Group size -.......... - -.. code:: abnf - - value-instruction =/ "group_size" ":" "index" - -Overview -~~~~~~~~ - -Returns the group size, an integer of type "index". - If .. @@ -1371,18 +1373,6 @@ Examples: #. ``load %0[%1] : memref ; %0: group>`` #. ``load %0[%1] : memref ; %0: group, offset: ?>`` -Number of subgroups -................... - -.. code:: abnf - - value-instruction =/ "num_subgroups" ":" "i32" - -Overview -~~~~~~~~ - -Returns the number of subgroups the work-group is divided in; i32 integer. - Size .... @@ -1407,19 +1397,6 @@ Op.-No. Type Description 2 integer-constant mode index ======= ================ =========== -Subgroup size -............. - -.. code:: abnf - - value-instruction =/ "subgroup_size" ":" "i32" - -Overview -~~~~~~~~ - -Returns the subgroup size; i32 integer. - - Subview ....... @@ -1597,29 +1574,27 @@ Additional instructions SPMD instructions ----------------- -Subgroup id -........... +Builtin (SPMD) +.............. .. code:: abnf - value-instruction =/ "subgroup_id" ":" "i32" + spmd-builtin-type = ".subgroup_id" / + ".subgroup_local_id" + value-instruction =/ "builtin" spmd-builtin-type ":" integer-type Overview ~~~~~~~~ -Returns the subgroup id; i32 integer from 0 to num_subgroups - 1. - -Subgroup local id -................. - -.. code:: abnf - - value-instruction =/ "subgroup_local_id" ":" "i32" - -Overview -~~~~~~~~ +Returns a builtin value. +The following table shows the builtins' description and the types that are returned. -Returns the work-item id within the subgroup; i32 integer from 0 to subgroup_size - 1. +=================== ===== ================================================================================= +Builtin Type Description +=================== ===== ================================================================================= +.subgroup_id i32 Returns the subgroup id; integer from 0 to num_subgroups - 1. +.subgroup_local_id i32 Returns the work-item id within the subgroup; integer from 0 to subgroup_size - 1 +=================== ===== ================================================================================= Sample code =========== diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 511b9e4b..2aaf077d 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -92,7 +92,7 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t fn_body.get_parameters(params); auto bb = region_builder{fn_body}; - auto gid = bb.add(make_group_id(ctx, my_loc())); + auto gid = bb.add(make_builtin(builtin::group_id, index_ty, my_loc())); auto from = bb.add(make_constant_zero(index_ty, my_loc())); auto to = bb.add(make_constant(repetitions, index_ty, my_loc())); auto calpha = bb.add(make_constant_one(element_ty, my_loc())); diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index d69c0080..6bb0a471 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -97,7 +97,8 @@ auto test_ader::make_optimized_kernel(bool dump) auto bb = region_builder{fn_body}; auto const c0 = bb.add(make_constant_zero(element_ty)); auto const c1 = bb.add(make_constant_one(element_ty)); - auto const gid = bb.add(make_group_id(ctx)); + auto const gid = + bb.add(make_builtin(builtin::group_id, get_scalar(ctx, scalar_type::index))); auto const static_offsets3 = std::array{0, 0, dynamic}; auto const static_sizes3 = [](matrix_batch const &b) -> std::array { return {b.nrows(), b.ncols(), 0}; diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index ae057c11..3b24f601 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -82,7 +82,7 @@ auto test_volume::make_optimized_kernel(bool dump) I.set_name("I"); auto bb = region_builder{fn_body}; - auto gid = bb.add(make_group_id(ctx)); + auto gid = bb.add(make_builtin(builtin::group_id, get_scalar(ctx, scalar_type::index))); auto const static_offsets2 = std::array{0, 0}; auto const static_offsets3 = std::array{0, 0, dynamic}; auto const static_sizes3 = [](matrix_batch const &b) -> std::array { diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index c5506b0c..853f4134 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -192,6 +192,8 @@ TINYTC_EXPORT char const *tinytc_address_space_to_string(tinytc_address_space_t TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); //! Convert arithmetic operation type to string (unary) TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); +//! Convert builtin type to string +TINYTC_EXPORT char const *tinytc_builtin_to_string(tinytc_builtin_t b); //! Convert checked flag to string TINYTC_EXPORT char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag); //! Convert cmp condition to string @@ -243,6 +245,23 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *inst tinytc_data_type_t ty, const tinytc_location_t *loc); +/** + * @brief Create builtin instruction + * + * @code %value = builtin. : %ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param btype [in] builtin type + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_builtin_inst_create(tinytc_inst_t *instr, + tinytc_builtin_t btype, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + /** * @brief Create cast instruction * @@ -543,35 +562,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tiny const tinytc_value_t *index_list, tinytc_data_type_t ty, const tinytc_location_t *loc); -/** - * @brief Create group_id instruction - * - * @code %value = group_id @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); - -/** - * @brief Create group_size instruction - * - * @code %value = group_size @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); /** * @brief Create GEMM instruction @@ -671,21 +661,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( tinytc_inst_t *instr, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, const tinytc_location_t *loc); -/** - * @brief Create num_subgroups instruction - * - * @code %value = num_subgroups @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); - /** * @brief Create parallel region * @@ -718,51 +693,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tiny int64_t mode, tinytc_data_type_t ty, const tinytc_location_t *loc); -/** - * @brief Create subgroup_id instruction - * - * @code %value = subgroup_id @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); - -/** - * @brief Create subgroup_local_id instruction - * - * @code %value = subgroup_local_id @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); - -/** - * @brief Create subgroup_size instruction - * - * @code %value = subgroup_size @endcode - * - * @param instr [out] pointer to the inst object created - * @param ctx [in] compiler context - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc); - /** * @brief Create subview instruction * diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 360a4c9b..9c8ef7d7 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -715,6 +715,17 @@ inline char const *to_string(arithmetic_unary op) { return ::tinytc_arithmetic_unary_to_string(static_cast<::tinytc_arithmetic_unary_t>(op)); } +/** + * @brief Convert builtin type to string + * + * @param b Builtin type + * + * @return C-string + */ +inline char const *to_string(builtin b) { + return ::tinytc_builtin_to_string(static_cast<::tinytc_builtin_t>(b)); +} + /** * @brief Convert checked flag string * @@ -905,6 +916,22 @@ inline inst make_arith(arithmetic_unary op, value a, data_type ty, location cons return inst(instr); } +/** + * @brief Make builtin instruction + * + * @param btype Builtin type + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_builtin(builtin btype, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_builtin_inst_create(&instr, static_cast(btype), ty, &loc), loc); + return inst(instr); +} + /** * @brief Make cast instruction * @@ -1232,34 +1259,6 @@ inline inst make_load(value a, array_view index_list, tinytc_data_type_t return inst(instr); } -/** - * @brief Make group id instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_id(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_id_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - -/** - * @brief Make group size instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_size(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_size_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - /** * @brief Make GEMM instruction * @@ -1348,20 +1347,6 @@ inline inst make_hadamard(bool atomic, value alpha, value A, value B, value beta return inst(instr); } -/** - * @brief Make num_subgroups instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_num_subgroups(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_num_subgroups_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - /** * @brief Make parallel region * @@ -1391,48 +1376,6 @@ inline inst make_size(value a, std::int64_t mode, data_type ty, location const & return inst(instr); } -/** - * @brief Make subgroup_id instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_subgroup_id(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_id_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - -/** - * @brief Make subgroup_local_id instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_subgroup_local_id(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_local_id_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - -/** - * @brief Make subgroup_size instruction - * - * @param ctx compiler context - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_subgroup_size(compiler_context const &ctx, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_subgroup_size_inst_create(&instr, ctx.get(), &loc), loc); - return inst(instr); -} - /** * @brief Make subview instruction * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index e7da9c42..b11c2470 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -54,55 +54,56 @@ typedef enum { tinytc_status_ir_expected_int = 0x108, ///< Expected a value of integer type tinytc_status_ir_expected_float = 0x109, ///< Expected a value of float type tinytc_status_ir_expected_complex = 0x10a, ///< Expected a value of complex type - tinytc_status_ir_expected_index = 0x10b, ///< Expected a value of index type - tinytc_status_ir_expected_coopmatrix = 0x10c, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_i32 = 0x10b, ///< Expected a value of i32 type + tinytc_status_ir_expected_index = 0x10c, ///< Expected a value of index type + tinytc_status_ir_expected_coopmatrix = 0x10d, ///< Expected a value of coopmatrix type tinytc_status_ir_expected_coopmatrix_or_scalar = - 0x10d, ///< Expected a value of coopmatrix or scalar type + 0x10e, ///< Expected a value of coopmatrix or scalar type tinytc_status_ir_expected_coopmatrix_scalar_or_boolean = - 0x10e, ///< Expected a value of coopmatrix, scalar type, or boolean - tinytc_status_ir_expected_memref = 0x10f, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x110, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x111, ///< Expected a value of memref or group type - tinytc_status_ir_expected_memref_order_0 = 0x112, ///< Expected memref of order 0 - tinytc_status_ir_expected_memref_order_1 = 0x113, ///< Expected memref of order 1 - tinytc_status_ir_expected_memref_order_2 = 0x114, ///< Expected memref of order 2 - tinytc_status_ir_expected_memref_order_0_or_1 = 0x115, ///< Expected memref of order 0 or 1 - tinytc_status_ir_expected_memref_order_1_or_2 = 0x116, ///< Expected memref of order 1 or 2 - tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x117, ///< Expected memref of order 0, 1 or 2 - tinytc_status_ir_unexpected_yield = 0x118, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x119, ///< Wrong number of yielded values - tinytc_status_ir_subview_mismatch = 0x11a, ///< Mismatch in subview - tinytc_status_ir_invalid_slice = 0x11b, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x11c, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x11d, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x11e, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x11f, ///< Instruction does not support floating type - tinytc_status_ir_spmd_called_from_collective = 0x120, ///< SPMD instruction from collective - tinytc_status_ir_expected_local_address_space = 0x121, ///< Expected local address space - tinytc_status_ir_expected_global_address_space = 0x122, ///< Expected global address space - tinytc_status_ir_address_space_mismatch = 0x123, ///< Address space must match - tinytc_status_ir_invalid_offset = 0x124, ///< Invalid offset - tinytc_status_ir_int_unsupported = 0x125, ///< Instruction does not support int type - tinytc_status_ir_boolean_unsupported = 0x126, ///< Instruction does not support boolean type - tinytc_status_ir_complex_unsupported = 0x127, ///< Instruction does not support complex type + 0x10f, ///< Expected a value of coopmatrix, scalar type, or boolean + tinytc_status_ir_expected_memref = 0x110, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x111, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x112, ///< Expected a value of memref or group type + tinytc_status_ir_expected_memref_order_0 = 0x113, ///< Expected memref of order 0 + tinytc_status_ir_expected_memref_order_1 = 0x114, ///< Expected memref of order 1 + tinytc_status_ir_expected_memref_order_2 = 0x115, ///< Expected memref of order 2 + tinytc_status_ir_expected_memref_order_0_or_1 = 0x116, ///< Expected memref of order 0 or 1 + tinytc_status_ir_expected_memref_order_1_or_2 = 0x117, ///< Expected memref of order 1 or 2 + tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x118, ///< Expected memref of order 0, 1 or 2 + tinytc_status_ir_unexpected_yield = 0x119, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x11a, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x11b, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x11c, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x11d, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x11e, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x11f, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x120, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x121, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x122, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x123, ///< Expected global address space + tinytc_status_ir_address_space_mismatch = 0x124, ///< Address space must match + tinytc_status_ir_invalid_offset = 0x125, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x126, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x127, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x128, ///< Instruction does not support complex type tinytc_status_ir_coopmatrix_unsupported = - 0x128, ///< Instruction does not support coopmatrix type - tinytc_status_ir_forbidden_cast = 0x129, ///< Forbidden cast - tinytc_status_ir_invalid_beta = 0x12a, ///< Invalid beta value - tinytc_status_ir_init_return_mismatch = 0x12b, ///< Mismatch of init values and returned values - tinytc_status_ir_invalid_matrix_use = 0x12c, ///< Invalid matrix use - tinytc_status_ir_unsupported_coopmatrix_shape = 0x12d, ///< Unsupported coopmatrix shape - tinytc_status_ir_incompatible_scalar_types = 0x12e, ///< Incompatible scalar types - tinytc_status_ir_constant_mismatch = 0x12f, ///< Constant mismatch - tinytc_status_ir_insufficient_alignment = 0x130, ///< Insufficient alignment - tinytc_status_ir_must_have_yield = 0x131, ///< Must have yield instruction + 0x129, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x12a, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x12b, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x12c, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x12d, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x12e, ///< Unsupported coopmatrix shape + tinytc_status_ir_incompatible_scalar_types = 0x12f, ///< Incompatible scalar types + tinytc_status_ir_constant_mismatch = 0x130, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x131, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x132, ///< Must have yield instruction tinytc_status_ir_yield_in_else_branch_missing = - 0x132, ///< Must have yield instruction in else branch - tinytc_status_ir_from_to_mismatch = 0x133, ///< size(from) != size(to) in foreach + 0x133, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x134, ///< size(from) != size(to) in foreach tinytc_status_ir_operand_type_must_match_return_type = - 0x134, /// Operand type must match return type - tinytc_status_ir_invalid_stride = 0x135, ///< Invalid stride - tinytc_status_ir_init_return_type_mismatch = 0x136, ///< Init return type mismatch + 0x135, /// Operand type must match return type + tinytc_status_ir_invalid_stride = 0x136, ///< Invalid stride + tinytc_status_ir_init_return_type_mismatch = 0x137, ///< Init return type mismatch // SPIR-V errors tinytc_status_spirv_forbidden_forward_declaration = 0x1000, ///< Forward declaration of id is forbidden @@ -297,6 +298,16 @@ typedef enum { tinytc_arithmetic_unary_re = 5 ///< real part } tinytc_arithmetic_unary_t; +//! Builtin values +typedef enum { + tinytc_builtin_group_id = 0, ///< group id + tinytc_builtin_group_size = 1, ///< group size + tinytc_builtin_num_subgroups = 2, ///< number of subgroups + tinytc_builtin_subgroup_size = 3, ///< subgroup size + tinytc_builtin_subgroup_id = 4, ///< subgroup id + tinytc_builtin_subgroup_local_id = 5, ///< subgroup local id +} tinytc_builtin_t; + //! Compare operation typedef enum { tinytc_cmp_condition_eq = 0, ///< equals diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 64d7a6b5..9444fc50 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -64,6 +64,7 @@ enum class status { ir_expected_int = tinytc_status_ir_expected_int, ir_expected_float = tinytc_status_ir_expected_float, ir_expected_complex = tinytc_status_ir_expected_complex, + ir_expected_i32 = tinytc_status_ir_expected_i32, ir_expected_index = tinytc_status_ir_expected_index, ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, @@ -274,6 +275,16 @@ enum class arithmetic_unary { re = tinytc_arithmetic_unary_re ///< real part }; +//! Builtin values +enum class builtin { + group_id = tinytc_builtin_group_id, ///< group id + group_size = tinytc_builtin_group_size, ///< group size + num_subgroups = tinytc_builtin_num_subgroups, ///< number of subgroups + subgroup_size = tinytc_builtin_subgroup_size, ///< subgroup size + subgroup_id = tinytc_builtin_subgroup_id, ///< subgroup id + subgroup_local_id = tinytc_builtin_subgroup_local_id, ///< subgroup local id +}; + //! Compare operation enum class cmp_condition { eq = tinytc_cmp_condition_eq, ///< equals diff --git a/src/error.cpp b/src/error.cpp index 8fead6f1..5a392699 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -156,6 +156,8 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Expected floating point type"; case tinytc_status_ir_expected_complex: return "Expected complex type"; + case tinytc_status_ir_expected_i32: + return "Expected i32 type"; case tinytc_status_ir_expected_index: return "Expected index type"; case tinytc_status_ir_expected_coopmatrix: diff --git a/src/inst.cpp b/src/inst.cpp index 5eb9f418..8e2b9202 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -80,6 +80,24 @@ char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op) { return "unknown"; } +char const *tinytc_builtin_to_string(tinytc_builtin_t b) { + switch (b) { + case tinytc_builtin_group_id: + return "group_id"; + case tinytc_builtin_group_size: + return "group_size"; + case tinytc_builtin_num_subgroups: + return "num_subgroups"; + case tinytc_builtin_subgroup_size: + return "subgroup_size"; + case tinytc_builtin_subgroup_id: + return "subgroup_id"; + case tinytc_builtin_subgroup_local_id: + return "subgroup_local_id"; + } + return "unknown"; +} + char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag) { switch (flag) { case tinytc_checked_flag_none: @@ -403,6 +421,17 @@ tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_ }); } +tinytc_status_t tinytc_builtin_inst_create(tinytc_inst_t *instr, tinytc_builtin_t btype, + tinytc_data_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(enum_cast(btype), ty, get_optional(loc)) + .release(); + }); +} + tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, uint32_t static_expand_shape_size, const int64_t *static_expand_shape, @@ -445,24 +474,6 @@ tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, }); } -tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); -} - -tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); -} - tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, tinytc_transpose_t tB, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, @@ -519,16 +530,6 @@ tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t }); } -tinytc_status_t tinytc_num_subgroups_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); -} - tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; @@ -546,36 +547,6 @@ tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, [&] { *instr = std::make_unique(a, mode, ty, get_optional(loc)).release(); }); } -tinytc_status_t tinytc_subgroup_id_inst_create(tinytc_inst_t *instr, tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); -} - -tinytc_status_t tinytc_subgroup_local_id_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(ctx, get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_subgroup_size_inst_create(tinytc_inst_t *instr, - tinytc_compiler_context_t ctx, - const tinytc_location_t *loc) { - if (instr == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(ctx, get_optional(loc)).release(); }); -} - tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, const int64_t *static_offset_list, const int64_t *static_size_list, diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index d4d88411..ee973fbb 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -55,12 +55,8 @@ auto tinytc_inst::kind() const -> tinytc::inst_execution_kind { case tinytc::IK::expand: case tinytc::IK::fuse: case tinytc::IK::load: - case tinytc::IK::group_id: - case tinytc::IK::group_size: case tinytc::IK::if_: - case tinytc::IK::num_subgroups: case tinytc::IK::size: - case tinytc::IK::subgroup_size: case tinytc::IK::subview: case tinytc::IK::store: case tinytc::IK::work_group: @@ -69,9 +65,8 @@ auto tinytc_inst::kind() const -> tinytc::inst_execution_kind { case tinytc::IK::for_loop: case tinytc::IK::last_loop: return tinytc::inst_execution_kind::mixed; - case tinytc::IK::subgroup_id: - case tinytc::IK::subgroup_local_id: - return tinytc::inst_execution_kind::spmd; + case tinytc::IK::builtin: + return tinytc::dyn_cast(this)->kind(); }; throw tinytc::internal_compiler_error(); } @@ -387,6 +382,49 @@ arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0 } } +builtin_inst::builtin_inst(builtin btype, tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::builtin}, btype_{btype} { + loc(lc); + + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_scalar); + } + + switch (builtin_type()) { + case builtin::group_id: + case builtin::group_size: + if (rt->ty() != scalar_type::index) { + throw compilation_error(loc(), status::ir_expected_index); + } + break; + case builtin::num_subgroups: + case builtin::subgroup_size: + case builtin::subgroup_id: + case builtin::subgroup_local_id: + if (rt->ty() != scalar_type::i32) { + throw compilation_error(loc(), status::ir_expected_i32); + } + break; + } + + result(0) = value_node{ty, this, lc}; +} + +auto builtin_inst::kind() const -> tinytc::inst_execution_kind { + switch (builtin_type()) { + case builtin::group_id: + case builtin::group_size: + case builtin::num_subgroups: + case builtin::subgroup_size: + return tinytc::inst_execution_kind::mixed; + case builtin::subgroup_id: + case builtin::subgroup_local_id: + return tinytc::inst_execution_kind::spmd; + } + return tinytc::inst_execution_kind::spmd; +} + cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const &lc) : standard_inst{IK::cast} { op(op_a, a0); diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index f2aca4d4..b1e7d58b 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -42,6 +42,7 @@ enum class IK { arith, arith_unary, barrier, + builtin, cast, compare, constant, @@ -52,16 +53,10 @@ enum class IK { expand, fuse, load, - group_id, - group_size, lifetime_stop, if_, - num_subgroups, parallel, size, - subgroup_id, - subgroup_local_id, - subgroup_size, subview, store, work_group, @@ -85,17 +80,15 @@ enum class IK { last_loop }; using inst_nodes = - type_list; + class if_inst, class parallel_inst, class size_inst, class subview_inst, + class store_inst, class sum_inst, class work_group_inst, class yield_inst>; using result_range = iterator_range_wrapper; using const_result_range = iterator_range_wrapper; @@ -413,6 +406,19 @@ class barrier_inst : public standard_inst<0, 0> { std::int32_t fence_flags_; }; +class builtin_inst : public standard_inst<0, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::builtin; } + builtin_inst(builtin btype, tinytc_data_type_t ty, location const &lc = {}); + + inline auto builtin_type() const -> builtin { return btype_; } + + auto kind() const -> tinytc::inst_execution_kind; + + private: + builtin btype_; +}; + class cast_inst : public standard_inst<1, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } @@ -588,26 +594,6 @@ class load_inst : public standard_inst { inline auto index_list() const { return operands() | std::views::drop(1); } }; -class group_id_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_id; } - inline group_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::group_id} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), this, lc}; - } -}; - -class group_size_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::group_size; } - inline group_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::group_size} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::index), this, lc}; - } -}; - class lifetime_stop_inst : public standard_inst<1, 0> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } @@ -725,16 +711,6 @@ class if_inst : public standard_inst<1, dynamic, 2> { inline bool is_otherwise_empty() const { return otherwise().insts().empty(); } }; -class num_subgroups_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::num_subgroups; } - inline num_subgroups_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::num_subgroups} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; - } -}; - class parallel_inst : public standard_inst<0, 0, 1> { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } @@ -757,36 +733,6 @@ class size_inst : public standard_inst<1, 1> { std::int64_t mode_; }; -class subgroup_id_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_id; } - inline subgroup_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::subgroup_id} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; - } -}; - -class subgroup_local_id_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_local_id; } - inline subgroup_local_id_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::subgroup_local_id} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; - } -}; - -class subgroup_size_inst : public standard_inst<0, 1> { - public: - inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_size; } - inline subgroup_size_inst(tinytc_compiler_context_t ctx, location const &lc = {}) - : standard_inst{IK::subgroup_size} { - loc(lc); - result(0) = value_node{scalar_data_type::get(ctx, scalar_type::i32), this, lc}; - } -}; - class subview_inst : public standard_inst { public: inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 84d82230..38cb6091 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -150,6 +150,7 @@ lex: "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } "arith" { adv_loc(); return parser::make_ARITH(loc_); } "barrier" { adv_loc(); return parser::make_BARRIER(loc_); } + "builtin" { adv_loc(); return parser::make_BUILTIN(loc_); } "gemm" { adv_loc(); return parser::make_GEMM(loc_); } "gemv" { adv_loc(); return parser::make_GEMV(loc_); } "ger" { adv_loc(); return parser::make_GER(loc_); } @@ -165,17 +166,12 @@ lex: "expand" { adv_loc(); return parser::make_EXPAND(loc_); } "fuse" { adv_loc(); return parser::make_FUSE(loc_); } "load" { adv_loc(); return parser::make_LOAD(loc_); } - "group_id" { adv_loc(); return parser::make_GROUP_ID(loc_); } - "group_size" { adv_loc(); return parser::make_GROUP_SIZE(loc_); } "for" { adv_loc(); return parser::make_FOR(loc_); } "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } "if" { adv_loc(); return parser::make_IF(loc_); } - "num_subgroups" { adv_loc(); return parser::make_NUM_SUBGROUPS(loc_); } "parallel" { adv_loc(); return parser::make_PARALLEL(loc_); } "else" { adv_loc(); return parser::make_ELSE(loc_); } "size" { adv_loc(); return parser::make_SIZE(loc_); } - "subgroup_id" { adv_loc(); return parser::make_SUBGROUP_ID(loc_); } - "subgroup_local_id" { adv_loc(); return parser::make_SUBGROUP_LOCAL_ID(loc_); } "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } "store" { adv_loc(); return parser::make_STORE(loc_); } "sum" { adv_loc(); return parser::make_SUM(loc_); } @@ -202,6 +198,14 @@ lex: ".im" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::im, loc_); } ".re" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::re, loc_); } + // builtin + ".group_id" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::group_id, loc_); } + ".group_size" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::group_size, loc_); } + ".num_subgroups" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::num_subgroups, loc_); } + ".subgroup_size" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::subgroup_size, loc_); } + ".subgroup_id" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::subgroup_id, loc_); } + ".subgroup_local_id" { adv_loc(); return parser::make_BUILTIN_TYPE(builtin::subgroup_local_id, loc_); } + // comparison condition ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, loc_); } ".ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, loc_); } diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 7c227855..7452dafc 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -120,6 +120,7 @@ GER "ger" HADAMARD "hadamard" ALLOCA "alloca" + BUILTIN "builtin" CAST "cast" CMP "cmp" CONSTANT "constant" @@ -132,15 +133,10 @@ LOAD "load" FOR "for" FOREACH "foreach" - GROUP_ID "group_id" - GROUP_SIZE "group_size" IF "if" ELSE "else" - NUM_SUBGROUPS "num_subgroups" PARALLEL "parallel" SIZE "size" - SUBGROUP_ID "subgroup_id" - SUBGROUP_LOCAL_ID "subgroup_local_id" SUBVIEW "subview" STORE "store" SUM "sum" @@ -156,6 +152,7 @@ %token FLOATING_TYPE %token ARITHMETIC %token ARITHMETIC_UNARY +%token BUILTIN_TYPE %token CMP_CONDITION %token WORK_GROUP_OPERATION %token MATRIX_USE @@ -213,6 +210,7 @@ %nterm alloca_inst %nterm arith_inst %nterm arith_unary_inst +%nterm builtin_inst %nterm cast_inst %nterm compare_inst %nterm constant_inst @@ -226,14 +224,8 @@ %nterm > expand_shape %nterm fuse_inst %nterm load_inst -%nterm group_id_inst -%nterm group_size_inst -%nterm num_subgroups_inst %nterm parallel_inst %nterm size_inst -%nterm subgroup_id_inst -%nterm subgroup_local_id_inst -%nterm subgroup_size_inst %nterm store_inst %nterm store_flag %nterm subview_inst @@ -721,6 +713,7 @@ valued_inst: alloca_inst { $$ = std::move($1); } | arith_inst { $$ = std::move($1); } | arith_unary_inst { $$ = std::move($1); } + | builtin_inst { $$ = std::move($1); } | cast_inst { $$ = std::move($1); } | compare_inst { $$ = std::move($1); } | constant_inst { $$ = std::move($1); } @@ -730,15 +723,9 @@ valued_inst: | expand_inst { $$ = std::move($1); } | for_inst { $$ = std::move($1); } | fuse_inst { $$ = std::move($1); } - | group_id_inst { $$ = std::move($1); } - | group_size_inst { $$ = std::move($1); } | if_inst { $$ = std::move($1); } | load_inst { $$ = std::move($1); } - | num_subgroups_inst { $$ = std::move($1); } | size_inst { $$ = std::move($1); } - | subgroup_id_inst { $$ = std::move($1); } - | subgroup_local_id_inst { $$ = std::move($1); } - | subgroup_size_inst { $$ = std::move($1); } | subview_inst { $$ = std::move($1); } | work_group_inst { $$ = std::move($1); } ; @@ -786,6 +773,11 @@ arith_unary_inst: } ; +builtin_inst: + BUILTIN BUILTIN_TYPE COLON data_type[ty] { + $$ = inst{std::make_unique($BUILTIN_TYPE, $ty, @builtin_inst).release()}; + } +; cast_inst: CAST var[a] COLON data_type[to] { @@ -1020,14 +1012,6 @@ store_flag: | ATOMIC_ADD { $$ = store_flag::atomic_add; } ; -group_id_inst: - GROUP_ID { $$ = inst{std::make_unique(ctx.cctx().get(), @GROUP_ID).release()}; } -; - -group_size_inst: - GROUP_SIZE { $$ = inst{std::make_unique(ctx.cctx().get(), @GROUP_SIZE).release()}; } -; - if_inst: IF var[condition] optional_returned_values { try { @@ -1072,10 +1056,6 @@ return_type_list: } ; -num_subgroups_inst: - NUM_SUBGROUPS { $$ = inst{std::make_unique(ctx.cctx().get(), @NUM_SUBGROUPS).release()}; } -; - parallel_inst: PARALLEL { try { @@ -1105,20 +1085,6 @@ size_inst: } ; -subgroup_id_inst: - SUBGROUP_ID { $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_ID).release()}; } -; - -subgroup_local_id_inst: - SUBGROUP_LOCAL_ID { - $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_LOCAL_ID).release()}; - } -; - -subgroup_size_inst: - SUBGROUP_SIZE { $$ = inst{std::make_unique(ctx.cctx().get(), @SUBGROUP_SIZE).release()}; } -; - subview_inst: SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type[ty] { try { diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp index 8f38d9ac..1b0e2162 100644 --- a/src/pass/clone.cpp +++ b/src/pass/clone.cpp @@ -35,6 +35,9 @@ auto inst_cloner::operator()(arith_unary_inst &in) -> std::unique_ptr std::unique_ptr { return std::make_unique(in.fence_flags(), in.loc()); } +auto inst_cloner::operator()(builtin_inst &in) -> std::unique_ptr { + return std::make_unique(in.builtin_type(), in.result(0).ty(), in.loc()); +} auto inst_cloner::operator()(cast_inst &in) -> std::unique_ptr { return std::make_unique(subs(&in.a()), in.result(0).ty(), in.loc()); } @@ -80,14 +83,6 @@ auto inst_cloner::operator()(load_inst &in) -> std::unique_ptr { return std::make_unique(subs(&in.operand()), subs_value_range(in.index_list()), in.result(0).ty(), in.loc()); } -auto inst_cloner::operator()(group_id_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - -auto inst_cloner::operator()(group_size_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - auto inst_cloner::operator()(gemm_inst &in) -> std::unique_ptr { return std::make_unique(in.tA(), in.tB(), subs(&in.alpha()), subs(&in.A()), subs(&in.B()), subs(&in.beta()), subs(&in.C()), in.atomic(), @@ -131,10 +126,6 @@ auto inst_cloner::operator()(if_inst &in) -> std::unique_ptr { return std::make_unique(subs(&in.condition()), return_types, in.loc()); } -auto inst_cloner::operator()(num_subgroups_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - auto inst_cloner::operator()(parallel_inst &in) -> std::unique_ptr { return std::make_unique(in.loc()); } @@ -143,18 +134,6 @@ auto inst_cloner::operator()(size_inst &in) -> std::unique_ptr { return std::make_unique(subs(&in.operand()), in.mode(), in.result(0).ty(), in.loc()); } -auto inst_cloner::operator()(subgroup_id_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - -auto inst_cloner::operator()(subgroup_local_id_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - -auto inst_cloner::operator()(subgroup_size_inst &in) -> std::unique_ptr { - return std::make_unique(in.context(), in.loc()); -} - auto inst_cloner::operator()(subview_inst &in) -> std::unique_ptr { return std::make_unique( subs(&in.operand()), in.static_offsets(), in.static_sizes(), subs_value_range(in.offsets()), diff --git a/src/pass/clone.hpp b/src/pass/clone.hpp index d7855216..b70e9cd8 100644 --- a/src/pass/clone.hpp +++ b/src/pass/clone.hpp @@ -20,6 +20,7 @@ class inst_cloner { auto operator()(arith_inst &in) -> std::unique_ptr; auto operator()(arith_unary_inst &in) -> std::unique_ptr; auto operator()(barrier_inst &in) -> std::unique_ptr; + auto operator()(builtin_inst &in) -> std::unique_ptr; auto operator()(cast_inst &in) -> std::unique_ptr; auto operator()(compare_inst &in) -> std::unique_ptr; auto operator()(constant_inst &in) -> std::unique_ptr; @@ -30,8 +31,6 @@ class inst_cloner { auto operator()(expand_inst &in) -> std::unique_ptr; auto operator()(fuse_inst &in) -> std::unique_ptr; auto operator()(load_inst &in) -> std::unique_ptr; - auto operator()(group_id_inst &in) -> std::unique_ptr; - auto operator()(group_size_inst &in) -> std::unique_ptr; auto operator()(lifetime_stop_inst &in) -> std::unique_ptr; auto operator()(gemm_inst &in) -> std::unique_ptr; auto operator()(gemv_inst &in) -> std::unique_ptr; @@ -40,12 +39,8 @@ class inst_cloner { auto operator()(foreach_inst &in) -> std::unique_ptr; auto operator()(hadamard_inst &in) -> std::unique_ptr; auto operator()(if_inst &in) -> std::unique_ptr; - auto operator()(num_subgroups_inst &in) -> std::unique_ptr; auto operator()(parallel_inst &in) -> std::unique_ptr; auto operator()(size_inst &in) -> std::unique_ptr; - auto operator()(subgroup_id_inst &in) -> std::unique_ptr; - auto operator()(subgroup_local_id_inst &in) -> std::unique_ptr; - auto operator()(subgroup_size_inst &in) -> std::unique_ptr; auto operator()(subview_inst &in) -> std::unique_ptr; auto operator()(store_inst &in) -> std::unique_ptr; auto operator()(sum_inst &in) -> std::unique_ptr; diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp index 953458a2..ee85b01a 100644 --- a/src/pass/convert_to_opencl.cpp +++ b/src/pass/convert_to_opencl.cpp @@ -454,6 +454,28 @@ std::vector convert_to_opencl_pass::operator()(arith_unary_inst cons throw compilation_error(a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } +std::vector convert_to_opencl_pass::operator()(builtin_inst const &in) { + auto lhs = declare(*in.result()); + auto rhs = [&]() -> clir::expr { + switch (in.builtin_type()) { + case builtin::group_id: + return clir::get_global_id(2); + case builtin::group_size: + return clir::get_global_size(2); + case builtin::num_subgroups: + return clir::get_num_sub_groups(); + case builtin::subgroup_size: + return clir::get_sub_group_size(); + case builtin::subgroup_id: + return clir::get_sub_group_id(); + case builtin::subgroup_local_id: + return clir::get_sub_group_local_id(); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + return {declaration_assignment(visit(*this, *in.result(0).ty()), std::move(lhs), rhs())}; +} + std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { auto const make = [](clir::expr a, scalar_type aty, scalar_type rty) -> clir::expr { if (is_complex_type(aty) && is_complex_type(rty)) { @@ -1024,20 +1046,6 @@ std::vector convert_to_opencl_pass::operator()(load_inst const &e) { return clinst; } -std::vector convert_to_opencl_pass::operator()(group_id_inst const &g) { - auto rhs = clir::get_global_id(2); - auto lhs = declare(*g.result()); - return { - declaration_assignment(visit(*this, *g.result()->ty()), std::move(lhs), std::move(rhs))}; -} - -std::vector convert_to_opencl_pass::operator()(group_size_inst const &g) { - auto rhs = clir::get_global_size(2); - auto lhs = declare(*g.result()); - return { - declaration_assignment(visit(*this, *g.result()->ty()), std::move(lhs), std::move(rhs))}; -} - std::vector convert_to_opencl_pass::operator()(lifetime_stop_inst const &) { return {}; } @@ -1325,13 +1333,6 @@ std::vector convert_to_opencl_pass::operator()(if_inst const &in) { return clinst; } -std::vector convert_to_opencl_pass::operator()(num_subgroups_inst const &sg) { - auto rhs = clir::get_num_sub_groups(); - auto lhs = declare(*sg.result()); - return { - declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; -} - std::vector convert_to_opencl_pass::operator()(parallel_inst const &p) { return {run_on_region(p.body())}; } @@ -1344,27 +1345,6 @@ std::vector convert_to_opencl_pass::operator()(size_inst const &s) { dv.shape(s.mode()))}; } -std::vector convert_to_opencl_pass::operator()(subgroup_id_inst const &sg) { - auto rhs = clir::get_sub_group_id(); - auto lhs = declare(*sg.result()); - return { - declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; -} - -std::vector convert_to_opencl_pass::operator()(subgroup_local_id_inst const &sg) { - auto rhs = clir::get_sub_group_local_id(); - auto lhs = declare(*sg.result()); - return { - declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; -} - -std::vector convert_to_opencl_pass::operator()(subgroup_size_inst const &sg) { - auto rhs = clir::get_sub_group_size(); - auto lhs = declare(*sg.result()); - return { - declaration_assignment(visit(*this, *sg.result()->ty()), std::move(lhs), std::move(rhs))}; -} - std::vector convert_to_opencl_pass::operator()(subview_inst const &s) { auto result_var = declare(*s.result()); auto t = get_memref_type(s.operand()); diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp index 7ac30813..599a1636 100644 --- a/src/pass/convert_to_opencl.hpp +++ b/src/pass/convert_to_opencl.hpp @@ -71,6 +71,7 @@ class convert_to_opencl_pass { std::vector operator()(alloca_inst const &a); std::vector operator()(axpby_inst const &a); std::vector operator()(barrier_inst const &b); + std::vector operator()(builtin_inst const &b); std::vector operator()(arith_inst const &a); std::vector operator()(arith_unary_inst const &a); std::vector operator()(cast_inst const &c); @@ -83,8 +84,6 @@ class convert_to_opencl_pass { std::vector operator()(expand_inst const &e); std::vector operator()(fuse_inst const &f); std::vector operator()(load_inst const &e); - std::vector operator()(group_id_inst const &g); - std::vector operator()(group_size_inst const &g); std::vector operator()(lifetime_stop_inst const &l); std::vector operator()(gemm_inst const &g); std::vector operator()(gemv_inst const &g); @@ -93,12 +92,8 @@ class convert_to_opencl_pass { std::vector operator()(foreach_inst const &in); std::vector operator()(hadamard_inst const &g); std::vector operator()(if_inst const &in); - std::vector operator()(num_subgroups_inst const &sg); std::vector operator()(parallel_inst const &p); std::vector operator()(size_inst const &s); - std::vector operator()(subgroup_id_inst const &sg); - std::vector operator()(subgroup_local_id_inst const &sg); - std::vector operator()(subgroup_size_inst const &sg); std::vector operator()(subview_inst const &s); std::vector operator()(store_inst const &s); std::vector operator()(sum_inst const &s); diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp index 1613e824..38d1c3b4 100644 --- a/src/pass/dump_ir.cpp +++ b/src/pass/dump_ir.cpp @@ -138,6 +138,12 @@ void dump_ir_pass::operator()(barrier_inst const &b) { } } +void dump_ir_pass::operator()(builtin_inst const &in) { + dump_val(in.result(0)); + *os_ << " = builtin." << to_string(in.builtin_type()) << " : "; + visit(*this, *in.result(0).ty()); +} + void dump_ir_pass::operator()(cast_inst const &c) { dump_val(c.result(0)); *os_ << " = cast "; @@ -283,16 +289,6 @@ void dump_ir_pass::operator()(load_inst const &e) { visit(*this, *e.result(0).ty()); } -void dump_ir_pass::operator()(group_id_inst const &g) { - dump_val(g.result(0)); - *os_ << " = group_id"; -} - -void dump_ir_pass::operator()(group_size_inst const &g) { - dump_val(g.result(0)); - *os_ << " = group_size"; -} - void dump_ir_pass::operator()(lifetime_stop_inst const &l) { *os_ << "lifetime_stop "; dump_val(l.object()); @@ -393,11 +389,6 @@ void dump_ir_pass::operator()(if_inst const &in) { } } -void dump_ir_pass::operator()(num_subgroups_inst const &sg) { - dump_val(sg.result(0)); - *os_ << " = num_subgroups"; -} - void dump_ir_pass::operator()(parallel_inst const &p) { *os_ << "parallel "; dump_region(p.body()); @@ -412,21 +403,6 @@ void dump_ir_pass::operator()(size_inst const &s) { visit(*this, *s.result(0).ty()); } -void dump_ir_pass::operator()(subgroup_id_inst const &sg) { - dump_val(sg.result(0)); - *os_ << " = subgroup_id"; -} - -void dump_ir_pass::operator()(subgroup_local_id_inst const &sg) { - dump_val(sg.result(0)); - *os_ << " = subgroup_local_id"; -} - -void dump_ir_pass::operator()(subgroup_size_inst const &sg) { - dump_val(sg.result(0)); - *os_ << " = subgroup_size"; -} - void dump_ir_pass::operator()(subview_inst const &s) { dump_val(s.result(0)); *os_ << " = subview "; diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp index 627059ca..89051324 100644 --- a/src/pass/dump_ir.hpp +++ b/src/pass/dump_ir.hpp @@ -35,6 +35,7 @@ class dump_ir_pass { void operator()(arith_inst const &a); void operator()(arith_unary_inst const &a); void operator()(barrier_inst const &b); + void operator()(builtin_inst const &in); void operator()(cast_inst const &c); void operator()(compare_inst const &c); void operator()(constant_inst const &c); @@ -45,8 +46,6 @@ class dump_ir_pass { void operator()(expand_inst const &e); void operator()(fuse_inst const &f); void operator()(load_inst const &e); - void operator()(group_id_inst const &g); - void operator()(group_size_inst const &g); void operator()(lifetime_stop_inst const &l); void operator()(gemm_inst const &g); void operator()(gemv_inst const &g); @@ -55,12 +54,8 @@ class dump_ir_pass { void operator()(foreach_inst const &p); void operator()(hadamard_inst const &g); void operator()(if_inst const &in); - void operator()(num_subgroups_inst const &sg); void operator()(parallel_inst const &p); void operator()(size_inst const &s); - void operator()(subgroup_id_inst const &sg); - void operator()(subgroup_local_id_inst const &sg); - void operator()(subgroup_size_inst const &sg); void operator()(subview_inst const &s); void operator()(store_inst const &s); void operator()(sum_inst const &s); diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp index 3a10affc..11b40563 100644 --- a/src/pass/lower_foreach.cpp +++ b/src/pass/lower_foreach.cpp @@ -20,7 +20,8 @@ void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, auto ity = from->ty(); auto ctx = compiler_context{sg_id->context(), true}; auto bool_ty = get_boolean(ctx); - auto sg_lid_i32 = bb.add(make_subgroup_local_id(ctx, loc)); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto sg_lid_i32 = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, loc)); auto sg_lid = bb.add(make_cast(sg_lid_i32, ity, loc)); auto size = bb.add(make_arith(arithmetic::sub, to, from, ity, loc)); auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity, loc)); @@ -54,7 +55,8 @@ auto foreach_generator::operator()(foreach_inst &in) -> inst { tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; - auto sg_id = bb.add(make_subgroup_id(compiler_context{in.context(), true}, in.loc())); + auto i32_ty = scalar_data_type::get(in.context(), scalar_type::i32); + auto sg_id = bb.add(make_builtin(builtin::subgroup_id, i32_ty, in.loc())); auto cloner = inst_cloner{}; auto loop_vars = in.loop_vars().begin(); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index 67c0117b..eb2323f7 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -197,9 +197,9 @@ void linalg_generator::operator()(axpby_inst &in) { tinytc_region_t body = ¶llel->child_region(0); auto bb = region_builder{body}; - auto sg_id = bb.add(make_subgroup_id(ctx, in.loc())); - auto sg_lid = bb.add(make_subgroup_local_id(ctx, in.loc())); auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto sg_id = bb.add(make_builtin(builtin::subgroup_id, i32_ty, in.loc())); + auto sg_lid = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, in.loc())); auto c0 = bb.add(make_constant(0, i32_ty)); auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0, bool_ty, in.loc())); auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0, bool_ty, in.loc())); @@ -273,7 +273,7 @@ void linalg_generator::operator()(gemm_inst &in) { auto i32_ty = get_scalar(ctx, scalar_type::i32); auto index_ty = get_scalar(ctx, scalar_type::index); - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); + auto sgid = bb.add(make_builtin(builtin::subgroup_id, i32_ty, in.loc())); auto c_m_tiles = bb.add(make_constant(tiling_.m_tiles(), i32_ty, in.loc())); auto sg_n = bb.add(make_arith(arithmetic::div, sgid, c_m_tiles, i32_ty, in.loc())); auto sg_m = bb.add(make_arith(arithmetic::rem, sgid, c_m_tiles, i32_ty, in.loc())); @@ -396,8 +396,8 @@ void linalg_generator::operator()(sum_inst &in) { auto bt = get_memref_type(in.B()); if (bt->dim() == 0) { auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); - auto sgid = bb.add(make_subgroup_id(ctx, in.loc())); - auto m = bb.add(make_subgroup_local_id(ctx, in.loc())); + auto sgid = bb.add(make_builtin(builtin::subgroup_id, i32_ty, in.loc())); + auto m = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, in.loc())); auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, i32_ty, in.loc())); auto from1 = bb.add(make_arith(arithmetic::add, from0, m, i32_ty, in.loc())); auto from_index = bb.add(make_cast(from1, index_ty, in.loc())); diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 2d9ec76e..cc1bc3d9 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -83,6 +83,7 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( return exception_to_status_code( [&] { auto const ty_ = get_scalar(ctx_, enum_cast(ty)); + auto const index_ty = get_scalar(ctx_, scalar_type::index); auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); @@ -110,7 +111,7 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; - auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto gid = bb.add(make_builtin(builtin::group_id, index_ty, my_loc())); auto at = get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); auto bt = diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 1ef845e3..1975a442 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -108,7 +108,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( auto const body = [&](region_builder &bb, value alpha, value A, value B, bool is_beta_nonzero, value beta_arg, value C) { auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); - auto gid = bb.add(make_group_id(ctx_, my_loc())); + auto gid = bb.add(make_builtin(builtin::group_id, index_ty, my_loc())); auto m = bb.add( make_arith(arithmetic::mul, gid, c_M_block_size, gid.get_type(), my_loc())); auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant_zero(ty_, my_loc())); diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index e41ea04a..b3676e91 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -787,6 +787,37 @@ void inst_converter::operator()(barrier_inst const &in) { mod_->add(scope, scope, memory_semantics); } +void inst_converter::operator()(builtin_inst const &in) { + switch (in.builtin_type()) { + case builtin::group_id: { + auto gid = load_builtin(BuiltIn::GlobalInvocationId); + auto index_ty = unique_.spv_ty(scalar_type::index); + declare(in.result(0), + mod_->add(index_ty, gid, std::vector{2})); + break; + } + case builtin::group_size: { + auto gs = load_builtin(BuiltIn::GlobalSize); + auto index_ty = unique_.spv_ty(scalar_type::index); + declare(in.result(0), + mod_->add(index_ty, gs, std::vector{2})); + break; + } + case builtin::num_subgroups: + declare(in.result(0), load_builtin(BuiltIn::NumSubgroups)); + break; + case builtin::subgroup_size: + declare(in.result(0), load_builtin(BuiltIn::SubgroupSize)); + break; + case builtin::subgroup_id: + declare(in.result(0), load_builtin(BuiltIn::SubgroupId)); + break; + case builtin::subgroup_local_id: + declare(in.result(0), load_builtin(BuiltIn::SubgroupLocalInvocationId)); + break; + } +} + void inst_converter::operator()(cast_inst const &in) { auto spv_to_ty = unique_.spv_ty(in.result(0).ty()); @@ -1478,19 +1509,6 @@ void inst_converter::operator()(fuse_inst const &in) { } } -void inst_converter::operator()(group_id_inst const &in) { - auto gid = load_builtin(BuiltIn::GlobalInvocationId); - auto index_ty = unique_.spv_ty(scalar_type::index); - declare(in.result(0), - mod_->add(index_ty, gid, std::vector{2})); -} -void inst_converter::operator()(group_size_inst const &in) { - auto gs = load_builtin(BuiltIn::GlobalSize); - auto index_ty = unique_.spv_ty(scalar_type::index); - declare(in.result(0), - mod_->add(index_ty, gs, std::vector{2})); -} - void inst_converter::operator()(if_inst const &in) { const std::int64_t num_results = num_yielded_vals(in.result_begin(), in.result_end()); @@ -1598,10 +1616,6 @@ void inst_converter::operator()(load_inst const &in) { } } -void inst_converter::operator()(num_subgroups_inst const &in) { - declare(in.result(0), load_builtin(BuiltIn::NumSubgroups)); -} - void inst_converter::operator()(parallel_inst const &in) { run_on_region(in.body()); } void inst_converter::operator()(size_inst const &in) { @@ -1642,16 +1656,6 @@ void inst_converter::operator()(store_inst const &in) { } } -void inst_converter::operator()(subgroup_id_inst const &in) { - declare(in.result(0), load_builtin(BuiltIn::SubgroupId)); -} -void inst_converter::operator()(subgroup_local_id_inst const &in) { - declare(in.result(0), load_builtin(BuiltIn::SubgroupLocalInvocationId)); -} -void inst_converter::operator()(subgroup_size_inst const &in) { - declare(in.result(0), load_builtin(BuiltIn::SubgroupSize)); -} - void inst_converter::operator()(subview_inst const &in) { auto spv_index_ty = unique_.spv_ty(scalar_type::index); auto spv_result_ty = unique_.spv_ty(in.result(0).ty()); diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp index 4aeefca3..13e5ff8c 100644 --- a/src/spv/converter.hpp +++ b/src/spv/converter.hpp @@ -70,6 +70,7 @@ class inst_converter { void operator()(arith_inst const &in); void operator()(arith_unary_inst const &in); void operator()(barrier_inst const &in); + void operator()(builtin_inst const &in); void operator()(cast_inst const &in); void operator()(compare_inst const &in); void operator()(constant_inst const &in); @@ -80,18 +81,12 @@ class inst_converter { void operator()(expand_inst const &in); void operator()(for_inst const &in); void operator()(fuse_inst const &in); - void operator()(group_id_inst const &in); - void operator()(group_size_inst const &in); void operator()(if_inst const &in); void operator()(lifetime_stop_inst const &in); void operator()(load_inst const &in); - void operator()(num_subgroups_inst const &in); void operator()(parallel_inst const &in); void operator()(size_inst const &in); void operator()(store_inst const &in); - void operator()(subgroup_id_inst const &in); - void operator()(subgroup_local_id_inst const &in); - void operator()(subgroup_size_inst const &in); void operator()(subview_inst const &in); void operator()(work_group_inst const &in); void operator()(yield_inst const &in); diff --git a/test/codegen/load.ir b/test/codegen/load.ir index 632ef086..f485201a 100644 --- a/test/codegen/load.ir +++ b/test/codegen/load.ir @@ -5,7 +5,7 @@ func @kernel1(%a: memref, %b: memref, %c: group>) { %c5 = constant 5 : index %0 = load %a[] : f32 - %1 = group_id + %1 = builtin.group_id : index %2 = load %b[%c5, %1] : f32 %3 = load %c[%1] : memref ; CHECK: float x = *a; @@ -15,7 +15,7 @@ func @kernel1(%a: memref, %b: memref, %c: group>) } func @kernel2(%c: group, offset: 21>) { - %0 = group_id + %0 = builtin.group_id : index %1 = load %c[%0] : memref ; CHECK: global float* x1 = *(c + x) + 21; } diff --git a/test/codegen/store.ir b/test/codegen/store.ir index b07ed04b..358b1e4c 100644 --- a/test/codegen/store.ir +++ b/test/codegen/store.ir @@ -4,7 +4,7 @@ ; RUN: %tinytc-oc < %s | filecheck %s func @kernel(%a: memref, %b: memref, %c: f32) { %c5 = constant 5 : index - %1 = group_id + %1 = builtin.group_id : index store %c, %a[] store %c, %b[%c5, %1] ; CHECK: *a = c; diff --git a/test/codegen/subgroup.ir b/test/codegen/subgroup.ir index d6094a0a..0881b43e 100644 --- a/test/codegen/subgroup.ir +++ b/test/codegen/subgroup.ir @@ -4,10 +4,10 @@ ; RUN: %tinytc-oc -O0 < %s | filecheck %s func @t1() { parallel { - %0 = num_subgroups - %1 = subgroup_id - %2 = subgroup_local_id - %3 = subgroup_size + %0 = builtin.num_subgroups : i32 + %1 = builtin.subgroup_id : i32 + %2 = builtin.subgroup_local_id : i32 + %3 = builtin.subgroup_size : i32 } ; CHECK: int x = get_num_sub_groups(); ; CHECK-NEXT: int x1 = get_sub_group_id(); diff --git a/test/opt/check-ir/nesting2.ir b/test/opt/check-ir/nesting2.ir index f9aa0714..8f8e89fa 100644 --- a/test/opt/check-ir/nesting2.ir +++ b/test/opt/check-ir/nesting2.ir @@ -3,6 +3,6 @@ ; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @illegal_nesting() { - %0 = subgroup_id -; CHECK: 6.10-20: SPMD instruction must not be called from collective region + %0 = builtin.subgroup_id : i32 +; CHECK: 6.10-34: SPMD instruction must not be called from collective region } diff --git a/test/opt/insert-barrier.ir b/test/opt/insert-barrier.ir index cb718c0c..636473fc 100644 --- a/test/opt/insert-barrier.ir +++ b/test/opt/insert-barrier.ir @@ -160,7 +160,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %c3 = constant 3 : index %c4 = constant 4 : index parallel { - %0 = subgroup_id + %0 = builtin.subgroup_id : i32 %1 = cmp.eq %0, %c0 : bool if %1 { %2 = load %A[%c3,%c4] : f32 @@ -170,7 +170,7 @@ func @no_barrier_spmd(%a: f32, %b: f32, %A: memref, %B: memref %3 = load %A[%c3,%c4] : f32 ; CHECK-LABEL: func @no_barrier_spmd({{.*}} ; CHECK: parallel { -; CHECK-NEXT: %0 = subgroup_id +; CHECK-NEXT: %0 = builtin.subgroup_id : i32 ; CHECK-NEXT: %1 = cmp.eq %0, %c0 : bool ; CHECK-NEXT: if %1 { ; CHECK-NEXT: %2 = load %A[%c3,%c4] : f32 diff --git a/test/spv/builtin.ir b/test/spv/builtin.ir index d1214d52..6093b3a0 100644 --- a/test/spv/builtin.ir +++ b/test/spv/builtin.ir @@ -31,13 +31,13 @@ ; CHECK: %[[#VAR6]] = OpVariable %[[#PTR_TO_I32]] Input func @tbuiltin() { - %0 = group_id - %1 = group_size - %2 = num_subgroups - %3 = subgroup_size + %0 = builtin.group_id : index + %1 = builtin.group_size : index + %2 = builtin.num_subgroups : i32 + %3 = builtin.subgroup_size : i32 parallel { - %4 = subgroup_id - %5 = subgroup_local_id + %4 = builtin.subgroup_id : i32 + %5 = builtin.subgroup_local_id : i32 } ; CHECK: %[[#VAR1_LOAD:]] = OpLoad %[[#I64V3]] %[[#VAR1]] Aligned 32 ; CHECK: %[[#]] = OpCompositeExtract %[[#I64]] %[[#VAR1_LOAD]] 2 From 36fc12f88f5f88a5243283460285600a37e37293 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 09:20:41 +0100 Subject: [PATCH 126/297] Add experimental env to test SPIR-V backend Signed-off-by: Carsten Uphoff --- src/ze/kernel.cpp | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/ze/kernel.cpp b/src/ze/kernel.cpp index 55063eb1..f116cc76 100644 --- a/src/ze/kernel.cpp +++ b/src/ze/kernel.cpp @@ -96,8 +96,11 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte return tinytc_status_invalid_arguments; } + const bool use_spirv_backend = getenv("TINYTC_SPIRV") != nullptr; + tinytc_core_info_t info = nullptr; tinytc_source_t src = nullptr; + tinytc_binary_t bin = nullptr; tinytc_status_t status = tinytc_status_success; if (status = tinytc_ze_core_info_create(&info, device); status != tinytc_status_success) { @@ -107,15 +110,31 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info); status != tinytc_status_success) { - goto err; - } - if (status = tinytc_ze_kernel_bundle_create_with_source(bundle, context, device, src); - status != tinytc_status_success) { - goto err; + if (!use_spirv_backend) { + if (status = tinytc_prog_compile_to_opencl(&src, prg, info); + status != tinytc_status_success) { + goto err; + } + if (status = tinytc_ze_kernel_bundle_create_with_source(bundle, context, device, src); + status != tinytc_status_success) { + goto err; + } + } else { + if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); + status != tinytc_status_success) { + goto err; + } + if (status = tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin); + status != tinytc_status_success) { + goto err; + } } err: - tinytc_source_release(src); + if (!use_spirv_backend) { + tinytc_source_release(src); + } else { + tinytc_binary_release(bin); + } tinytc_core_info_release(info); return status; From ebe4f489bef0ac9c36ce5307ac7b3d1b634c46c1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 01:33:43 -0800 Subject: [PATCH 127/297] Example bugfix Signed-off-by: Carsten Uphoff --- examples/benchmark/main.cpp | 30 ++++++++++++++++----------- examples/matrix_chain/test_ader.cpp | 20 +++++++++++------- examples/matrix_chain/test_volume.cpp | 30 ++++++++++++++++++--------- examples/matrix_chain/test_volume.hpp | 2 ++ src/recipe/small_gemm_batched.cpp | 12 +++++------ 5 files changed, 59 insertions(+), 35 deletions(-) diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 2aaf077d..1f51670b 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -58,6 +58,9 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t std::array C_stride, std::int32_t repetitions, bool dump, queue q) -> source { auto ctx = make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); char const *file_name = std::source_location::current().file_name(); auto const source_id = ctx.add_source(file_name, ""); @@ -70,23 +73,25 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t ++l.end.column; return l; }; - auto const make_type = [](data_type element_ty, transpose t, int64_t A, std::int64_t B, - std::array const &stride, location const &loc) { + auto const make_memref = [](data_type element_ty, transpose t, int64_t A, std::int64_t B, + std::array const &stride, location const &loc) { auto s = std::array{A, B}; if (t == transpose::T) { std::swap(s[0], s[1]); } - auto mr = get_memref(element_ty, s, stride, address_space::global, loc); - return get_group(mr, 0, loc); + return get_memref(element_ty, s, stride, address_space::global, loc); }; auto kernel = [&](compiler_context const &ctx) { auto index_ty = get_scalar(ctx, scalar_type::index); auto element_ty = get_scalar(ctx, ty); - auto A_ty = make_type(element_ty, tA, M, K, A_stride, my_loc()); - auto B_ty = make_type(element_ty, tB, K, N, B_stride, my_loc()); - auto C_ty = make_type(element_ty, transpose::N, M, N, C_stride, my_loc()); - auto f = make_func("gemm", {A_ty, B_ty, C_ty}, my_loc()); + auto A_ty = make_memref(element_ty, tA, M, K, A_stride, my_loc()); + auto B_ty = make_memref(element_ty, tB, K, N, B_stride, my_loc()); + auto C_ty = make_memref(element_ty, transpose::N, M, N, C_stride, my_loc()); + auto f = make_func("gemm", + {get_group(A_ty, 0, my_loc()), get_group(B_ty, 0, my_loc()), + get_group(C_ty, 0, my_loc())}, + my_loc()); auto fn_body = f.get_body(); auto params = std::array{}; fn_body.get_parameters(params); @@ -98,9 +103,9 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t auto calpha = bb.add(make_constant_one(element_ty, my_loc())); auto cbeta = bb.add(update ? make_constant_one(element_ty, my_loc()) : make_constant_zero(element_ty, my_loc())); - auto a = bb.add(make_load(params[0], {gid}, element_ty, my_loc())); - auto b = bb.add(make_load(params[1], {gid}, element_ty, my_loc())); - auto c = bb.add(make_load(params[2], {gid}, element_ty, my_loc())); + auto a = bb.add(make_load(params[0], {gid}, A_ty, my_loc())); + auto b = bb.add(make_load(params[1], {gid}, B_ty, my_loc())); + auto c = bb.add(make_load(params[2], {gid}, C_ty, my_loc())); bb.for_loop( index_ty, from, to, [&](region_builder &bb, value const &) { @@ -123,7 +128,8 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t return compile_to_opencl(std::move(p), info); } catch (builder_error const &e) { ctx.report_error(e.loc(), e.what()); - std::cerr << "Error (" << static_cast(e.code()) << "): " << std::endl; + std::cerr << "Error (" << static_cast(e.code()) << "): " << error_string(e.code()) + << std::endl; } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; } diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 6bb0a471..112135ab 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -103,16 +103,19 @@ auto test_ader::make_optimized_kernel(bool dump) auto const static_sizes3 = [](matrix_batch const &b) -> std::array { return {b.nrows(), b.ncols(), 0}; }; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; auto const offsets3 = array_view(gid); - auto dqt = get_memref(element_ty, static_sizes3(dQ_[0])); + auto dqt = get_memref(element_ty, static_sizes2(dQ_[0]), {1, dynamic}); auto dq = bb.add(make_subview(Q, static_offsets3, static_sizes3(dQ_[0]), offsets3, {}, dqt)); for (std::size_t d = 0; d < dim; ++d) { - auto At = get_memref(element_ty, static_sizes3(A_[d])); + auto At = get_memref(element_ty, static_sizes2(A_[d])); A(d) = bb.add(make_subview(A(d), static_offsets3, static_sizes3(A_[d]), offsets3, {}, At)); } - auto it = get_memref(element_ty, static_sizes3(I_opt_)); + auto it = get_memref(element_ty, static_sizes2(I_opt_), {1, dynamic}); auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3, {}, it)); bb.add(make_axpby(transpose::N, false, c1, dq, c1, i)); @@ -126,20 +129,20 @@ auto test_ader::make_optimized_kernel(bool dump) auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom, cnum.get_type())); auto bn = Bd_aligned(N_ - n); auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); - auto dq_nextvt = get_memref(element_ty, {bn, P_}, {}, address_space::local); + auto dq_nextvt = get_memref(element_ty, {bn, P_}, {1, dynamic}, address_space::local); auto dq_nextv = bb.add(make_subview(dq_next, static_offsets2, {bn, P_}, {}, {}, dq_nextvt)); auto tmp = bb.add( - make_alloca(get_memref(element_ty, {bn, P_}, {1, bn}, address_space::local))); + make_alloca(get_memref(element_ty, {bn, P_}, {1, dynamic}, address_space::local))); for (std::size_t d = 0; d < dim; ++d) { - auto Kvt = get_memref(element_ty, {bn, Bd(N_ - n + 1)}); + auto Kvt = get_memref(element_ty, {bn, Bd(N_ - n + 1)}, {1, dynamic}); auto Kv = bb.add(make_subview(K(d), static_offsets2, {bn, Bd(N_ - n + 1)}, {}, {}, Kvt)); bb.add(make_gemm(transpose::N, transpose::N, false, c1, Kv, dq, c0, tmp)); bb.add(make_gemm(transpose::N, transpose::N, false, c1, tmp, A(d), d > 0 ? c1 : c0, dq_nextv)); } - auto ivt = get_memref(element_ty, {Bd(N_ - n), P_}); + auto ivt = get_memref(element_ty, {Bd(N_ - n), P_}, {1, dynamic}); auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_}, {}, {}, ivt)); bb.add(make_axpby(transpose::N, false, cfactor, dq_next, c1, iv)); dq = dq_next; @@ -148,6 +151,9 @@ auto test_ader::make_optimized_kernel(bool dump) return f; }; auto ctx = make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); auto p = make_prog(ctx); p.add_function(opt_kernel(ctx)); if (dump) { diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index 3b24f601..926cddd4 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -22,7 +22,7 @@ test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany Q_ref_(B3_, P_, B3_aligned_, howmany_, q_), Q_opt_(B3_, P_, B3_aligned_, howmany_, q_), I_(B3_, P_, B3_aligned_, howmany_, q_), tmp_(B3_, P_, B2_aligned_, howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), - K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), + K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), ctx_(make_compiler_context()), opt_bundle_(make_optimized_kernel(dump)), opt_kernel_(make_kernel(opt_bundle_, "volume_kernel")) { Q_ref_.random(); @@ -39,11 +39,19 @@ test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany g_.emplace_back(make_recipe_handler( q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, B2_aligned_, P_, P_, B3_aligned_, B3_aligned_ * P_, P_, P_ * P_, - B2_aligned_, B2_aligned_ * P_))); + B2_aligned_, B2_aligned_ * P_, ctx_))); g_.emplace_back(make_recipe_handler( q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, B3_aligned_, P_, B2_, B3_aligned_, 0, B2_aligned_, - B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_))); + B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_, ctx_))); +} + +template auto test_volume::make_compiler_context() -> compiler_context { + auto ctx = ::tinytc::make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); + return ctx; } template @@ -85,6 +93,9 @@ auto test_volume::make_optimized_kernel(bool dump) auto gid = bb.add(make_builtin(builtin::group_id, get_scalar(ctx, scalar_type::index))); auto const static_offsets2 = std::array{0, 0}; auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; auto const static_sizes3 = [](matrix_batch const &b) -> std::array { return {b.nrows(), b.ncols(), 0}; }; @@ -93,14 +104,14 @@ auto test_volume::make_optimized_kernel(bool dump) auto tmp = bb.add( make_alloca(get_memref(element_ty, {B2_aligned_, P_}, {}, address_space::local))); - auto a0t = get_memref(element_ty, static_sizes3(A_[0])); - auto a1t = get_memref(element_ty, static_sizes3(A_[1])); - auto a2t = get_memref(element_ty, static_sizes3(A_[2])); + auto a0t = get_memref(element_ty, static_sizes2(A_[0])); + auto a1t = get_memref(element_ty, static_sizes2(A_[1])); + auto a2t = get_memref(element_ty, static_sizes2(A_[2])); auto k0t = get_memref(element_ty, sizeK2); auto k1t = get_memref(element_ty, sizeK2); auto k2t = get_memref(element_ty, sizeK2); auto qvt = get_memref(element_ty, {B3_aligned_, P_}); - auto ivt = get_memref(element_ty, {B2_aligned_, P_}); + auto ivt = get_memref(element_ty, {B2_aligned_, P_}, {1, dynamic}); auto tmpvt = get_memref(element_ty, {B2_, P_}, {}, address_space::local); auto a0 = bb.add(make_subview(A(0), static_offsets3, static_sizes3(A_[0]), offsets3, {}, a0t)); @@ -125,9 +136,8 @@ auto test_volume::make_optimized_kernel(bool dump) return f; }; - auto ctx = make_compiler_context(); - auto p = make_prog(ctx); - p.add_function(opt_kernel(ctx)); + auto p = make_prog(ctx_); + p.add_function(opt_kernel(ctx_)); if (dump) { p.dump(); } diff --git a/examples/matrix_chain/test_volume.hpp b/examples/matrix_chain/test_volume.hpp index ad163e44..d0d07d40 100644 --- a/examples/matrix_chain/test_volume.hpp +++ b/examples/matrix_chain/test_volume.hpp @@ -36,12 +36,14 @@ template class test_volume : public test { private: constexpr static std::size_t dim = 3; auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; + auto make_compiler_context() -> tinytc::compiler_context; std::int64_t B3_, B2_, P_, howmany_, B3_aligned_, B2_aligned_; sycl::queue q_; tinytc::core_info dev_info_; matrix_batch Q_ref_, Q_opt_, I_, tmp_; std::vector> A_, K_; + tinytc::compiler_context ctx_; sycl::kernel_bundle opt_bundle_; sycl::kernel opt_kernel_; std::vector g_; diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index cc1bc3d9..deb8e4c3 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -112,12 +112,12 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( auto bb = region_builder{fn_body}; auto gid = bb.add(make_builtin(builtin::group_id, index_ty, my_loc())); - auto at = - get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); - auto bt = - get_memref(ty_, B_static_sizes, {1, ldB}, address_space::global, my_loc()); - auto ct = - get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); + auto at = get_memref(ty_, array_view(A_static_sizes.data(), 2), {1, ldA}, + address_space::global, my_loc()); + auto bt = get_memref(ty_, array_view(B_static_sizes.data(), 2), {1, ldB}, + address_space::global, my_loc()); + auto ct = get_memref(ty_, array_view(C_static_sizes.data(), 2), {1, ldC}, + address_space::global, my_loc()); auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, array_view{gid}, {}, at, my_loc())); auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, From 009b721a75a61ecf5e66b62f1010dac18ea888eb Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 11:23:55 +0100 Subject: [PATCH 128/297] Started removing OpenCL-C backend Signed-off-by: Carsten Uphoff --- docs/api/cl/capi.rst | 7 -- docs/api/cl/capi.yaml | 1 - docs/api/cl/cxxapi.rst | 7 -- docs/api/cl/cxxapi.yaml | 1 - docs/api/core_capi.rst | 92 ++---------------- docs/api/core_capi.yaml | 14 +-- docs/api/core_cxxapi.rst | 22 ----- docs/api/core_cxxapi.yaml | 4 - docs/api/sycl/cxxapi.rst | 7 -- docs/api/sycl/cxxapi.yaml | 1 - docs/api/ze/capi.rst | 14 --- docs/api/ze/capi.yaml | 2 - docs/api/ze/cxxapi.rst | 14 --- docs/api/ze/cxxapi.yaml | 2 - examples/benchmark/main.cpp | 6 +- examples/jit/main.cpp | 2 +- examples/matrix_chain/test_ader.cpp | 3 +- examples/matrix_chain/test_volume.cpp | 3 +- include/tinytc/tinytc.h | 103 +------------------- include/tinytc/tinytc.hpp | 88 ++--------------- include/tinytc/tinytc_cl.h | 15 --- include/tinytc/tinytc_cl.hpp | 16 ---- include/tinytc/tinytc_sycl.hpp | 13 --- include/tinytc/tinytc_ze.h | 29 ------ include/tinytc/tinytc_ze.hpp | 33 ------- include/tinytc/types.h | 8 -- src/CMakeLists.txt | 1 - src/cl/kernel.cpp | 57 +---------- src/cl/recipe_handler.cpp | 2 +- src/compiler.cpp | 29 ------ src/recipe.cpp | 6 +- src/recipe.hpp | 8 +- src/recipe/small_gemm_batched.cpp | 10 +- src/recipe/small_gemm_batched.hpp | 2 +- src/recipe/tall_and_skinny.cpp | 10 +- src/recipe/tall_and_skinny.hpp | 2 +- src/source.cpp | 84 ----------------- src/source.hpp | 40 -------- src/sycl/kernel.cpp | 4 - src/sycl/recipe_handler.cpp | 2 +- src/ze/CMakeLists.txt | 3 - src/ze/error.hpp | 11 +-- src/ze/kernel.cpp | 101 ++------------------ src/ze/opencl_cc.cpp | 125 ------------------------- src/ze/opencl_cc.hpp | 52 ---------- src/ze/recipe_handler.cpp | 2 +- test/CMakeLists.txt | 2 +- test/spv/alloca.ir | 2 +- test/spv/arith.ir | 2 +- test/spv/arith_unary.ir | 2 +- test/spv/barrier.ir | 2 +- test/spv/builtin.ir | 2 +- test/spv/calling_convention.ir | 2 +- test/spv/cast.ir | 2 +- test/spv/compare.ir | 2 +- test/spv/cooperative_matrix_load.ir | 2 +- test/spv/cooperative_matrix_mul_add.ir | 2 +- test/spv/cooperative_matrix_scale.ir | 2 +- test/spv/cooperative_matrix_store.ir | 2 +- test/spv/expand.ir | 2 +- test/spv/for.ir | 2 +- test/spv/fuse.ir | 2 +- test/spv/if.ir | 2 +- test/spv/load.ir | 2 +- test/spv/size.ir | 2 +- test/spv/store.ir | 2 +- test/spv/subview.ir | 2 +- test/spv/unique_function_type.ir | 2 +- test/spv/work_group.ir | 2 +- tools/offline_compiler/main.cpp | 42 ++------- 70 files changed, 94 insertions(+), 1052 deletions(-) delete mode 100644 src/source.cpp delete mode 100644 src/source.hpp delete mode 100644 src/ze/opencl_cc.cpp delete mode 100644 src/ze/opencl_cc.hpp diff --git a/docs/api/cl/capi.rst b/docs/api/cl/capi.rst index f4d2d8a2..98472964 100644 --- a/docs/api/cl/capi.rst +++ b/docs/api/cl/capi.rst @@ -51,8 +51,6 @@ Kernel * :ref:`tinytc_cl_get_group_size` - * :ref:`tinytc_cl_kernel_bundle_create_with_source` - * :ref:`tinytc_cl_kernel_bundle_create_with_program` * :ref:`tinytc_cl_kernel_bundle_create_with_binary` @@ -70,11 +68,6 @@ tinytc_cl_get_group_size .. doxygenfunction:: tinytc_cl_get_group_size -tinytc_cl_kernel_bundle_create_with_source -.......................................... - -.. doxygenfunction:: tinytc_cl_kernel_bundle_create_with_source - tinytc_cl_kernel_bundle_create_with_program ........................................... diff --git a/docs/api/cl/capi.yaml b/docs/api/cl/capi.yaml index 67ed7e61..b16a3346 100644 --- a/docs/api/cl/capi.yaml +++ b/docs/api/cl/capi.yaml @@ -12,7 +12,6 @@ C-API: function: - tinytc_cl_get_global_size - tinytc_cl_get_group_size - - tinytc_cl_kernel_bundle_create_with_source - tinytc_cl_kernel_bundle_create_with_program - tinytc_cl_kernel_bundle_create_with_binary Recipe: diff --git a/docs/api/cl/cxxapi.rst b/docs/api/cl/cxxapi.rst index 0bfa8e83..cd1461e4 100644 --- a/docs/api/cl/cxxapi.rst +++ b/docs/api/cl/cxxapi.rst @@ -57,8 +57,6 @@ Kernel * :ref:`make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,source const&)` - Kernel Functions ---------------- @@ -87,11 +85,6 @@ make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) .. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(cl_context,cl_device_id,source const&) -......................................................... - -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&) - Recipe ====== diff --git a/docs/api/cl/cxxapi.yaml b/docs/api/cl/cxxapi.yaml index fc1e9d6e..e53fa0e0 100644 --- a/docs/api/cl/cxxapi.yaml +++ b/docs/api/cl/cxxapi.yaml @@ -15,7 +15,6 @@ C++-API: - tinytc::make_kernel(cl_program,char const*) - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&) - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&) Recipe: function: - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&) diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 44a2668f..9afb93f1 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -46,8 +46,6 @@ Common * :ref:`tinytc_recipe_handler_t` - * :ref:`tinytc_source_t` - * :ref:`tinytc_spv_mod_t` * :ref:`tinytc_compiler_context_t` @@ -60,8 +58,6 @@ Common * :ref:`const_tinytc_recipe_handler_t` - * :ref:`const_tinytc_source_t` - * :ref:`const_tinytc_spv_mod_t` * :ref:`const_tinytc_compiler_context_t` @@ -155,11 +151,6 @@ tinytc_recipe_handler_t .. doxygentypedef:: tinytc_recipe_handler_t -tinytc_source_t -............... - -.. doxygentypedef:: tinytc_source_t - tinytc_spv_mod_t ................ @@ -190,11 +181,6 @@ const_tinytc_recipe_handler_t .. doxygentypedef:: const_tinytc_recipe_handler_t -const_tinytc_source_t -..................... - -.. doxygentypedef:: const_tinytc_source_t - const_tinytc_spv_mod_t ...................... @@ -275,8 +261,6 @@ Compiler * :ref:`tinytc_list_function_passes` - * :ref:`tinytc_prog_compile_to_opencl` - * :ref:`tinytc_prog_compile_to_spirv` * :ref:`tinytc_prog_compile_to_spirv_and_assemble` @@ -309,11 +293,6 @@ tinytc_list_function_passes .. doxygenfunction:: tinytc_list_function_passes -tinytc_prog_compile_to_opencl -............................. - -.. doxygenfunction:: tinytc_prog_compile_to_opencl - tinytc_prog_compile_to_spirv ............................ @@ -540,9 +519,9 @@ Recipe * Functions - * :ref:`tinytc_recipe_get_prog` + * :ref:`tinytc_recipe_get_binary` - * :ref:`tinytc_recipe_get_source` + * :ref:`tinytc_recipe_get_prog` * :ref:`tinytc_recipe_handler_get_recipe` @@ -577,16 +556,16 @@ tinytc_mem_type_t Recipe Functions ---------------- +tinytc_recipe_get_binary +........................ + +.. doxygenfunction:: tinytc_recipe_get_binary + tinytc_recipe_get_prog ...................... .. doxygenfunction:: tinytc_recipe_get_prog -tinytc_recipe_get_source -........................ - -.. doxygenfunction:: tinytc_recipe_get_source - tinytc_recipe_handler_get_recipe ................................ @@ -685,60 +664,3 @@ tinytc_spv_mod_retain .. doxygenfunction:: tinytc_spv_mod_retain -Source -====== - -* Functions - - * :ref:`tinytc_source_get_code` - - * :ref:`tinytc_source_get_compiler_context` - - * :ref:`tinytc_source_get_core_features` - - * :ref:`tinytc_source_get_location` - - * :ref:`tinytc_source_get_extensions` - - * :ref:`tinytc_source_release` - - * :ref:`tinytc_source_retain` - -Source Functions ----------------- - -tinytc_source_get_code -...................... - -.. doxygenfunction:: tinytc_source_get_code - -tinytc_source_get_compiler_context -.................................. - -.. doxygenfunction:: tinytc_source_get_compiler_context - -tinytc_source_get_core_features -............................... - -.. doxygenfunction:: tinytc_source_get_core_features - -tinytc_source_get_location -.......................... - -.. doxygenfunction:: tinytc_source_get_location - -tinytc_source_get_extensions -............................ - -.. doxygenfunction:: tinytc_source_get_extensions - -tinytc_source_release -..................... - -.. doxygenfunction:: tinytc_source_release - -tinytc_source_retain -.................... - -.. doxygenfunction:: tinytc_source_retain - diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 5b0fe1f6..1983792e 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -21,14 +21,12 @@ Core C-API: - tinytc_core_info_t - tinytc_recipe_t - tinytc_recipe_handler_t - - tinytc_source_t - tinytc_spv_mod_t - tinytc_compiler_context_t - const_tinytc_binary_t - const_tinytc_core_info_t - const_tinytc_recipe_t - const_tinytc_recipe_handler_t - - const_tinytc_source_t - const_tinytc_spv_mod_t - const_tinytc_compiler_context_t - tinytc_error_reporter_t @@ -47,7 +45,6 @@ Core C-API: function: - tinytc_run_function_pass - tinytc_list_function_passes - - tinytc_prog_compile_to_opencl - tinytc_prog_compile_to_spirv - tinytc_prog_compile_to_spirv_and_assemble - tinytc_spirv_assemble @@ -87,8 +84,8 @@ Core C-API: enum: - tinytc_mem_type_t function: + - tinytc_recipe_get_binary - tinytc_recipe_get_prog - - tinytc_recipe_get_source - tinytc_recipe_handler_get_recipe - tinytc_recipe_small_gemm_batched_create - tinytc_recipe_small_gemm_batched_set_args @@ -107,12 +104,3 @@ Core C-API: - tinytc_spv_mod_print_to_string - tinytc_spv_mod_release - tinytc_spv_mod_retain - Source: - function: - - tinytc_source_get_code - - tinytc_source_get_compiler_context - - tinytc_source_get_core_features - - tinytc_source_get_location - - tinytc_source_get_extensions - - tinytc_source_release - - tinytc_source_retain diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 15ce00d5..6c1128a4 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -160,8 +160,6 @@ Compiler * :ref:`list_function_passes` - * :ref:`compile_to_opencl` - * :ref:`compile_to_spirv` * :ref:`compile_to_spirv_and_assemble` @@ -181,11 +179,6 @@ list_function_passes .. doxygenfunction:: tinytc::list_function_passes -compile_to_opencl -................. - -.. doxygenfunction:: tinytc::compile_to_opencl - compile_to_spirv ................ @@ -465,18 +458,3 @@ spv_mod .. doxygenclass:: tinytc::spv_mod -Source -====== - -* Classes - - * :ref:`source` - -Source Classes --------------- - -source -...... - -.. doxygenclass:: tinytc::source - diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index f8cb1c6a..e49bae81 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -29,7 +29,6 @@ Core C++-API: function: - tinytc::run_function_pass - tinytc::list_function_passes - - tinytc::compile_to_opencl - tinytc::compile_to_spirv - tinytc::compile_to_spirv_and_assemble - tinytc::spirv_assemble @@ -77,6 +76,3 @@ Core C++-API: SPIR-V module: class: - tinytc::spv_mod - Source: - class: - - tinytc::source diff --git a/docs/api/sycl/cxxapi.rst b/docs/api/sycl/cxxapi.rst index 2db5996f..3458b74e 100644 --- a/docs/api/sycl/cxxapi.rst +++ b/docs/api/sycl/cxxapi.rst @@ -44,8 +44,6 @@ Kernel * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,source const &)` - Kernel Functions ---------------- @@ -79,11 +77,6 @@ make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_f .. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) -............................................................................. - -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) - Recipe ====== diff --git a/docs/api/sycl/cxxapi.yaml b/docs/api/sycl/cxxapi.yaml index 2d2a7a05..3649a2fb 100644 --- a/docs/api/sycl/cxxapi.yaml +++ b/docs/api/sycl/cxxapi.yaml @@ -13,7 +13,6 @@ C++-API: - tinytc::make_kernel(sycl::kernel_bundle const &,char const *) - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &) Recipe: function: - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) diff --git a/docs/api/ze/capi.rst b/docs/api/ze/capi.rst index d0c0e6ca..5be3377b 100644 --- a/docs/api/ze/capi.rst +++ b/docs/api/ze/capi.rst @@ -57,10 +57,6 @@ Kernel * :ref:`tinytc_ze_kernel_bundle_create_with_program` - * :ref:`tinytc_ze_kernel_bundle_create_with_source` - - * :ref:`tinytc_ze_source_compile_to_binary` - Kernel Functions ---------------- @@ -89,16 +85,6 @@ tinytc_ze_kernel_bundle_create_with_program .. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_program -tinytc_ze_kernel_bundle_create_with_source -.......................................... - -.. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_source - -tinytc_ze_source_compile_to_binary -.................................. - -.. doxygenfunction:: tinytc_ze_source_compile_to_binary - Recipe ====== diff --git a/docs/api/ze/capi.yaml b/docs/api/ze/capi.yaml index acfde096..4696a841 100644 --- a/docs/api/ze/capi.yaml +++ b/docs/api/ze/capi.yaml @@ -15,8 +15,6 @@ C-API: - tinytc_ze_kernel_create - tinytc_ze_kernel_bundle_create_with_binary - tinytc_ze_kernel_bundle_create_with_program - - tinytc_ze_kernel_bundle_create_with_source - - tinytc_ze_source_compile_to_binary Recipe: function: - tinytc_ze_recipe_handler_create diff --git a/docs/api/ze/cxxapi.rst b/docs/api/ze/cxxapi.rst index eedae3b3..f37d9a11 100644 --- a/docs/api/ze/cxxapi.rst +++ b/docs/api/ze/cxxapi.rst @@ -47,8 +47,6 @@ Kernel * Functions - * :ref:`compile_to_binary` - * :ref:`get_group_count` * :ref:`get_group_size(ze_kernel_handle_t)` @@ -59,16 +57,9 @@ Kernel * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&)` - Kernel Functions ---------------- -compile_to_binary -................. - -.. doxygenfunction:: tinytc::compile_to_binary - get_group_count ............... @@ -94,11 +85,6 @@ make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_featu .. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) -........................................................................ - -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) - Recipe ====== diff --git a/docs/api/ze/cxxapi.yaml b/docs/api/ze/cxxapi.yaml index ad480350..f01608a9 100644 --- a/docs/api/ze/cxxapi.yaml +++ b/docs/api/ze/cxxapi.yaml @@ -10,13 +10,11 @@ C++-API: - tinytc::make_core_info(ze_device_handle_t) Kernel: function: - - tinytc::compile_to_binary - tinytc::get_group_count - tinytc::get_group_size(ze_kernel_handle_t) - tinytc::make_kernel(ze_module_handle_t,char const *) - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&) Recipe: function: - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 1f51670b..6f01a236 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -56,7 +56,7 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t std::array A_stride, std::array B_stride, bool update, std::array C_stride, - std::int32_t repetitions, bool dump, queue q) -> source { + std::int32_t repetitions, bool dump, queue q) -> binary { auto ctx = make_compiler_context(); ctx.set_error_reporter( [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, @@ -125,7 +125,7 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t auto info = make_core_info(q.get_device()); info.set_core_features(tinytc_core_feature_flag_large_register_file); - return compile_to_opencl(std::move(p), info); + return compile_to_spirv_and_assemble(std::move(p), info); } catch (builder_error const &e) { ctx.report_error(e.loc(), e.what()); std::cerr << "Error (" << static_cast(e.code()) << "): " << error_string(e.code()) @@ -133,7 +133,7 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; } - return source{nullptr}; + return binary{nullptr}; } template void test(queue q, args &a) { diff --git a/examples/jit/main.cpp b/examples/jit/main.cpp index 01444c4e..887c2535 100644 --- a/examples/jit/main.cpp +++ b/examples/jit/main.cpp @@ -21,7 +21,7 @@ int main(int argc, char **argv) { if (!prog) { return -1; } - compile_to_opencl(std::move(prog), info); + compile_to_spirv_and_assemble(std::move(prog), info); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; return 1; diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 112135ab..bb745852 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -159,7 +159,8 @@ auto test_ader::make_optimized_kernel(bool dump) if (dump) { p.dump(); } - return make_kernel_bundle(q_.get_context(), q_.get_device(), compile_to_opencl(p, dev_info_)); + return make_kernel_bundle(q_.get_context(), q_.get_device(), + compile_to_spirv_and_assemble(p, dev_info_)); } template diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index 926cddd4..86a712d8 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -141,7 +141,8 @@ auto test_volume::make_optimized_kernel(bool dump) if (dump) { p.dump(); } - return make_kernel_bundle(q_.get_context(), q_.get_device(), compile_to_opencl(p, dev_info_)); + return make_kernel_bundle(q_.get_context(), q_.get_device(), + compile_to_spirv_and_assemble(p, dev_info_)); } template std::vector test_volume::reference() { diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 853f4134..eede8224 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1485,18 +1485,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_run_function_pass(char const *pass_name, ti TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *const **names); -/** - * @brief Compile tensor language to OpenCL-C - * - * @param src [out] pointer to the source object created - * @param prg [inout] tensor program; modified as compiler passes are run - * @param info [in] core info object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info); - /** * @brief Compile tensor language to SPIR-V * @@ -1532,67 +1520,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble( TINYTC_EXPORT tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, const_tinytc_spv_mod_t mod); -/** - * @brief Get source text - * - * @param src [in] source object - * @param length [out] pointer to code length - * @param code [out] code contains a pointer to the source text; the pointer is only valid as - * long as the source object is alive - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code); - -/** - * @brief Get context object from source object - * - * @param src [in] source object - * @param ctx [out] pointer to context object; reference count is increased so the user needs to - * call tinytc_compiler_context_release to clean up - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_compiler_context(const_tinytc_source_t src, - tinytc_compiler_context_t *ctx); - -/** - * @brief Get source location - * - * @param src [in] source object - * @param loc [out] pointer to location - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, - tinytc_location_t *loc); - -/** - * @brief Get core features - * - * @param src [in] source object - * @param core_features [out] pointer to core features - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_core_features( - const_tinytc_source_t src, tinytc_core_feature_flags_t *core_features); - -/** - * @brief Get required OpenCL extensions - * - * @param src [in] source object - * @param extensions_size [out] pointer to number of extensions - * @param extensions [out][range(0,extensions_size)] pointer to array of C-strings; array owned - * by source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, - uint32_t *extensions_size, - char const *const **extensions); - /** * @brief Create binary * @@ -1648,26 +1575,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, TINYTC_EXPORT tinytc_status_t tinytc_binary_get_core_features( const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features); -/** - * @brief Release source object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_release(tinytc_source_t obj); - -/** - * @brief Increase reference count of source object by 1 - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_retain(tinytc_source_t obj); - /** * @brief Release binary object * @@ -1890,16 +1797,16 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recip tinytc_prog_t *prg); /** - * @brief Get source object + * @brief Get binary * * @param recipe [in] recipe object - * @param src [out] pointer to source object; reference count is increased so the user needs to - * call tinytc_source_release to clean up + * @param bin [out] pointer to binary; reference count is increased so the user needs to + * call tinytc_binary_release to clean up * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, - tinytc_source_t *src); +TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, + tinytc_binary_t *bin); /** * @brief Release recipe object diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 9c8ef7d7..92ac5b95 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -2158,68 +2158,6 @@ inline auto parse_string(std::string const &src, compiler_context const &ctx = { ///////// Compiler ///////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_retain(handle); - } - static auto release(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_source_t -class source : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get code - * - * @return Pointer to C-string that is bound to the lifetime of the source object - */ - inline auto get_code() const -> std::string_view { - char const *code = nullptr; - std::size_t length = 0; - CHECK_STATUS(tinytc_source_get_code(obj_, &length, &code)); - return std::string_view(code, length); - } - - /** - * @brief Get compiler context - * - * @return Compiler context - */ - inline auto get_compiler_context() const -> compiler_context { - tinytc_compiler_context_t ctx; - CHECK_STATUS(tinytc_source_get_compiler_context(obj_, &ctx)); - return compiler_context{ctx, true}; - } - - /** - * @brief Get location - * - * @return Location - */ - inline auto get_location() const -> location { - location loc = {}; - CHECK_STATUS(tinytc_source_get_location(obj_, &loc)); - return loc; - } - - /** - * @brief Get OpenCL extension - * - * @param extensions_size Number of extensions - * @param extensions Array of extensions - */ - inline void get_extensions(std::uint32_t &extensions_size, - char const *const *&extensions) const { - CHECK_STATUS(tinytc_source_get_extensions(obj_, &extensions_size, &extensions)); - } -}; - namespace internal { template <> struct shared_handle_traits { static auto retain(tinytc_binary_t handle) -> tinytc_status_t { @@ -2319,20 +2257,6 @@ inline void list_function_passes(std::uint32_t &names_size, char const *const *& CHECK_STATUS(tinytc_list_function_passes(&names_size, &names)); } -/** - * @brief Compile program to OpenCL-C - * - * @param prg Program - * @param info Core info - * - * @return Source - */ -inline auto compile_to_opencl(prog prg, core_info const &info) -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, prg.get(), info.get())); - return source{src}; -} - /** * @brief Convert tensor language to SPIR-V * @@ -2484,14 +2408,14 @@ class recipe : public shared_handle { } /** - * @brief Get source + * @brief Get binary * - * @return Source + * @return Binary */ - auto get_source() const -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_recipe_get_source(obj_, &src)); - return source{src}; + auto get_binary() const -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_recipe_get_binary(obj_, &bin)); + return binary{bin}; } }; diff --git a/include/tinytc/tinytc_cl.h b/include/tinytc/tinytc_cl.h index 191e0b40..dc77aaf3 100644 --- a/include/tinytc/tinytc_cl.h +++ b/include/tinytc/tinytc_cl.h @@ -57,21 +57,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (cl_program) object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, - cl_context context, - cl_device_id device, - const_tinytc_source_t src); - /** * @brief Compile tensor program * diff --git a/include/tinytc/tinytc_cl.hpp b/include/tinytc/tinytc_cl.hpp index 2165c4fd..23d03537 100644 --- a/include/tinytc/tinytc_cl.hpp +++ b/include/tinytc/tinytc_cl.hpp @@ -77,22 +77,6 @@ template <> struct shared_handle_traits { }; } // namespace internal -/** - * @brief Make an OpenCL program from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * - * @return cl_program (shared handle) - */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, - source const &src) -> shared_handle { - cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_source(&obj, context, device, src.get())); - return shared_handle{obj}; -} - /** * @brief Make an OpenCL program from a tinytc program * diff --git a/include/tinytc/tinytc_sycl.hpp b/include/tinytc/tinytc_sycl.hpp index 9b53af7c..7c7bb21d 100644 --- a/include/tinytc/tinytc_sycl.hpp +++ b/include/tinytc/tinytc_sycl.hpp @@ -41,19 +41,6 @@ TINYTC_EXPORT auto make_core_info(sycl::device const &dev) -> core_info; ////////// Kernel ////////// //////////////////////////// -/** - * @brief Make SYCL kernel bundle from tinytc source - * - * @param ctx Context - * @param dev Device - * @param src Source - * - * @return SYCL kernel bundle - */ -TINYTC_EXPORT auto -make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - source const &src) -> sycl::kernel_bundle; - /** * @brief Make SYCL kernel bundle from tinytc program * diff --git a/include/tinytc/tinytc_ze.h b/include/tinytc/tinytc_ze.h index 2539076a..0564eca7 100644 --- a/include/tinytc/tinytc_ze.h +++ b/include/tinytc/tinytc_ze.h @@ -57,35 +57,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bin [out] pointer to the binary object created - * @param src [in] source text - * @param ip_version [in] IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format [in] binary format (SPIR-V or native) - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, - const_tinytc_source_t src, - uint32_t ip_version, - tinytc_bundle_format_t format); - -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (ze_module_handle_t) object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t -tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *bundle, ze_context_handle_t context, - ze_device_handle_t device, const_tinytc_source_t src); - /** * @brief Compile tensor program * diff --git a/include/tinytc/tinytc_ze.hpp b/include/tinytc/tinytc_ze.hpp index a3bc4b6f..94264282 100644 --- a/include/tinytc/tinytc_ze.hpp +++ b/include/tinytc/tinytc_ze.hpp @@ -58,23 +58,6 @@ inline auto make_core_info(ze_device_handle_t device) -> core_info { ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile source to binary - * - * @param src Source object - * @param ip_version IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format Bundle format (SPIR-V or Native) - * - * @return Binary - */ -inline auto compile_to_binary(source const &src, std::uint32_t ip_version, - bundle_format format) -> binary { - tinytc_binary_t bin; - CHECK_STATUS(tinytc_ze_source_compile_to_binary(&bin, src.get(), ip_version, - static_cast(format))); - return binary{bin}; -} - namespace internal { template <> struct unique_handle_traits { static void destroy(ze_kernel_handle_t obj) { zeKernelDestroy(obj); } @@ -84,22 +67,6 @@ template <> struct unique_handle_traits { }; } // namespace internal -/** - * @brief Make a Level Zero module from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * - * @return Level Zero module (unique handle) - */ -inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - source const &src) -> unique_handle { - ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_source(&obj, context, device, src.get())); - return unique_handle{obj}; -} - /** * @brief Make a Level Zero module from a tinytc program * diff --git a/include/tinytc/types.h b/include/tinytc/types.h index b11c2470..248f64c5 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -486,14 +486,6 @@ typedef struct tinytc_core_info *tinytc_core_info_t; //! @brief const core_info handle typedef const struct tinytc_core_info *const_tinytc_core_info_t; -//! @struct tinytc_source; -//! @brief Opaque struct for source text -struct tinytc_source; -//! @brief source handle -typedef struct tinytc_source *tinytc_source_t; -//! @brief const source handle -typedef const struct tinytc_source *const_tinytc_source_t; - //! @struct tintyc_compiler_context //! @brief Opaque struct for compiler context struct tinytc_compiler_context; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 80c330e3..967ba264 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -77,7 +77,6 @@ set(SOURCES spv/pass/dump_asm.cpp spv/pass/capex.cpp spv/uniquifier.cpp - source.cpp tiling.cpp value.cpp support/walk.cpp diff --git a/src/cl/kernel.cpp b/src/cl/kernel.cpp index 741dd780..b3340c57 100644 --- a/src/cl/kernel.cpp +++ b/src/cl/kernel.cpp @@ -21,54 +21,6 @@ using tinytc::compiler_context; extern "C" { -tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, cl_context context, - cl_device_id device, - const_tinytc_source_t src) { - if (bundle == nullptr || src == nullptr) { - return tinytc_status_invalid_arguments; - } - - size_t length = 0; - char const *code = nullptr; - tinytc_core_feature_flags_t core_features = 0; - tinytc_compiler_context_t ctx = nullptr; - TINYTC_CHECK_STATUS(tinytc_source_get_code(src, &length, &code)); - TINYTC_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); - TINYTC_CHECK_STATUS(tinytc_source_get_compiler_context(src, &ctx)); - auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope - - cl_int err; - cl_program p = clCreateProgramWithSource(context, 1, &code, &length, &err); - TINYTC_CL_CHECK_STATUS(err); - - auto options = std::ostringstream{}; - for (auto const &opt : tinytc::default_compiler_options) { - options << opt << " "; - } - if (core_features & tinytc_core_feature_flag_large_register_file) { - options << tinytc::large_register_file_compiler_option_cl; - } - auto options_str = std::move(options).str(); - if (err = clBuildProgram(p, 1, &device, options_str.c_str(), nullptr, nullptr); - err != CL_SUCCESS) { - if (ctx_.get()) { - std::string log; - std::size_t log_size; - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); - log.resize(log_size); - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); - - tinytc_location_t loc = {}; - tinytc_source_get_location(src, &loc); - tinytc_compiler_context_report_error(ctx_.get(), &loc, log.c_str()); - } - clReleaseProgram(p); - TINYTC_CL_CHECK_STATUS(err); - } - *bundle = p; - return tinytc_status_success; -} - tinytc_status_t tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, @@ -78,7 +30,7 @@ tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context conte } tinytc_core_info_t info = nullptr; - tinytc_source_t src = nullptr; + tinytc_binary_t bin = nullptr; tinytc_status_t status = tinytc_status_success; if (status = tinytc_cl_core_info_create(&info, device); status != tinytc_status_success) { @@ -88,15 +40,16 @@ tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context conte status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info); status != tinytc_status_success) { + if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); + status != tinytc_status_success) { goto err; } - if (status = tinytc_cl_kernel_bundle_create_with_source(bundle, context, device, src); + if (status = tinytc_cl_kernel_bundle_create_with_binary(bundle, context, device, bin); status != tinytc_status_success) { goto err; } err: - tinytc_source_release(src); + tinytc_binary_release(bin); tinytc_core_info_release(info); return status; diff --git a/src/cl/recipe_handler.cpp b/src/cl/recipe_handler.cpp index 59761897..db091bbf 100644 --- a/src/cl/recipe_handler.cpp +++ b/src/cl/recipe_handler.cpp @@ -20,7 +20,7 @@ namespace tinytc { cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source()); + module_ = make_kernel_bundle(context, device, get_recipe().get_binary()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); diff --git a/src/compiler.cpp b/src/compiler.cpp index a8ed5aa3..d58546d0 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -21,7 +21,6 @@ #include "passes.hpp" #include "reference_counted.hpp" #include "required_extensions.hpp" -#include "source.hpp" #include "spv/pass/assemble.hpp" #include "spv/pass/assign_ids.hpp" #include "tinytc/tinytc.h" @@ -129,34 +128,6 @@ tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *co return tinytc_status_success; } -tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info) { - if (src == nullptr || prg == nullptr || info == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { - apply_default_optimization_pipeline(prg, info); - - // opencl - auto ast = convert_to_opencl_pass{info}.run_on_program(*prg); - clir::make_names_unique(ast); - - auto oss = std::ostringstream{}; - auto ext = required_extensions(ast); - for (auto const &e : ext) { - oss << "#pragma OPENCL EXTENSION " << e << " : enable" << std::endl; - } - - clir::generate_opencl(oss, std::move(ast)); - - *src = std::make_unique<::tinytc_source>(prg->share_context(), oss.str(), prg->loc(), - std::move(ext), info->core_features()) - .release(); - }, - prg->context()); -} - tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, const_tinytc_core_info_t info) { if (mod == nullptr || prg == nullptr || info == nullptr) { diff --git a/src/recipe.cpp b/src/recipe.cpp index 222e9e27..0727ab3d 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -53,12 +53,12 @@ tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recipe, tinytc_prog [&] { *prg = tinytc::prog(recipe->get_program()).release(); }); } -tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, tinytc_source_t *src) { - if (recipe == nullptr || src == nullptr) { +tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, tinytc_binary_t *bin) { + if (recipe == nullptr || bin == nullptr) { return tinytc_status_invalid_arguments; } return tinytc::exception_to_status_code( - [&] { *src = tinytc::source(recipe->get_source()).release(); }); + [&] { *bin = tinytc::binary(recipe->get_binary()).release(); }); } tinytc_status_t tinytc_recipe_release(tinytc_recipe_t obj) { diff --git a/src/recipe.hpp b/src/recipe.hpp index 879245a0..6d3ef819 100644 --- a/src/recipe.hpp +++ b/src/recipe.hpp @@ -19,19 +19,19 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va struct tinytc_recipe : tinytc::reference_counted { public: - inline tinytc_recipe(tinytc::prog prg, tinytc::source src) - : prg_(std::move(prg)), src_(std::move(src)) {} + inline tinytc_recipe(tinytc::prog prg, tinytc::binary bin) + : prg_(std::move(prg)), bin_(std::move(bin)) {} virtual ~tinytc_recipe() = default; inline auto get_program() const -> tinytc::prog const & { return prg_; } - inline auto get_source() const -> tinytc::source const & { return src_; } + inline auto get_binary() const -> tinytc::binary const & { return bin_; } virtual auto num_kernels() const -> int = 0; virtual auto kernel_name(int kernel_num) const -> char const * = 0; private: tinytc::prog prg_; - tinytc::source src_; + tinytc::binary bin_; }; struct tinytc_recipe_handler : tinytc::reference_counted { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index deb8e4c3..43ac3724 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -32,8 +32,8 @@ auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const * } throw status::invalid_arguments; } -small_gemm_batched_recipe::small_gemm_batched_recipe(prog prg, source src, scalar_type ty) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty) {} +small_gemm_batched_recipe::small_gemm_batched_recipe(prog prg, binary bin, scalar_type ty) + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty) {} auto small_gemm_batched_recipe::num_kernels() const -> int { return static_cast(small_gemm_batched_kernel::num_kernels); } @@ -135,9 +135,9 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_create( kernel(small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), true)); p.add_function(kernel( small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), false)); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); - *recipe = std::make_unique(std::move(p), source(src), + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, p.get(), info)); + *recipe = std::make_unique(std::move(p), binary(bin), enum_cast(ty)) .release(); }, diff --git a/src/recipe/small_gemm_batched.hpp b/src/recipe/small_gemm_batched.hpp index 0ffdc2da..dab7c6f9 100644 --- a/src/recipe/small_gemm_batched.hpp +++ b/src/recipe/small_gemm_batched.hpp @@ -15,7 +15,7 @@ auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const * struct small_gemm_batched_recipe : ::tinytc_recipe { public: - small_gemm_batched_recipe(prog prg, source src, scalar_type ty); + small_gemm_batched_recipe(prog prg, binary bin, scalar_type ty); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 1975a442..b1aafd22 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -34,10 +34,10 @@ auto tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const * { } throw status::invalid_arguments; } -tall_and_skinny_recipe::tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, +tall_and_skinny_recipe::tall_and_skinny_recipe(prog prg, binary bin, scalar_type ty, std::int64_t M, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty), M_dyn_(is_dynamic_value(M)), + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty), M_dyn_(is_dynamic_value(M)), ldA_dyn_(is_dynamic_value(ldA)), ldB_dyn_(is_dynamic_value(ldB)), ldC_dyn_(is_dynamic_value(ldC)), M_block_size_(M_block_size) {} auto tall_and_skinny_recipe::num_kernels() const -> int { @@ -186,9 +186,9 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( p.add_function(kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), true)); p.add_function( kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), false)); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info)); - *recipe = std::make_unique(std::move(p), source(src), + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, p.get(), info)); + *recipe = std::make_unique(std::move(p), binary(bin), enum_cast(ty), M, ldA, ldB, ldC, M_block_size) .release(); diff --git a/src/recipe/tall_and_skinny.hpp b/src/recipe/tall_and_skinny.hpp index dd2aaca4..2e8c2b76 100644 --- a/src/recipe/tall_and_skinny.hpp +++ b/src/recipe/tall_and_skinny.hpp @@ -17,7 +17,7 @@ auto tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const *; struct tall_and_skinny_recipe : ::tinytc_recipe { public: - tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, std::int64_t ldA, + tall_and_skinny_recipe(prog prg, binary bin, scalar_type ty, std::int64_t M, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; diff --git a/src/source.cpp b/src/source.cpp deleted file mode 100644 index fcea0153..00000000 --- a/src/source.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "source.hpp" -#include "error.hpp" -#include "tinytc/tinytc.h" - -#include -#include - -using namespace tinytc; - -tinytc_source::tinytc_source(compiler_context ctx, std::string code, - tinytc_location const &code_loc, - std::vector required_extensions, - tinytc_core_feature_flags_t core_features) - : ctx_{std::move(ctx)}, code_(std::move(code)), code_loc_(code_loc), - required_extensions_(std::move(required_extensions)), core_features_(core_features) {} - -extern "C" { -tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code) { - if (src == nullptr || length == nullptr || code == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *length = src->size(); - *code = src->code(); - }); -} - -tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, tinytc_location_t *loc) { - if (src == nullptr || loc == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *loc = src->code_loc(); }); -} - -tinytc_status_t tinytc_source_get_compiler_context(const_tinytc_source_t src, - tinytc_compiler_context_t *ctx) { - if (src == nullptr || ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *ctx = src->share_context().release(); }); -} - -tinytc_status_t tinytc_source_get_core_features(const_tinytc_source_t src, - tinytc_core_feature_flags_t *core_features) { - if (src == nullptr || core_features == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *core_features = src->core_features(); }); -} - -tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, uint32_t *extensions_size, - char const *const **extensions) { - if (src == nullptr || extensions_size == nullptr || extensions == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *extensions_size = src->required_extensions().size(); - *extensions = src->required_extensions().data(); - }); -} - -tinytc_status_t tinytc_source_release(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_retain(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/source.hpp b/src/source.hpp deleted file mode 100644 index 13e4797a..00000000 --- a/src/source.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SOURCE_20240412_HPP -#define SOURCE_20240412_HPP - -#include "compiler_context.hpp" -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" - -#include -#include -#include - -struct tinytc_source : tinytc::reference_counted { - public: - tinytc_source(tinytc::compiler_context ctx, std::string code, tinytc_location const &code_loc, - std::vector required_extensions, - tinytc_core_feature_flags_t core_features); - - inline auto code() const -> char const * { return code_.c_str(); } - inline auto code_loc() const -> tinytc_location const & { return code_loc_; } - inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } - inline auto share_context() const -> tinytc::compiler_context { return ctx_; } - inline auto size() const -> std::size_t { return code_.size(); } - inline auto const &required_extensions() const { return required_extensions_; } - inline auto core_features() const noexcept -> tinytc_core_feature_flags_t { - return core_features_; - } - - private: - tinytc::compiler_context ctx_; - std::string code_; - tinytc_location code_loc_; - std::vector required_extensions_; - tinytc_core_feature_flags_t core_features_; -}; - -#endif // SOURCE_20240412_HPP diff --git a/src/sycl/kernel.cpp b/src/sycl/kernel.cpp index 995e341c..6b97b4da 100644 --- a/src/sycl/kernel.cpp +++ b/src/sycl/kernel.cpp @@ -60,10 +60,6 @@ template <> struct kernel_bundle_dispatcher { } }; -auto make_kernel_bundle(context const &ctx, device const &dev, - source const &src) -> kernel_bundle { - return dispatch(dev.get_backend(), ctx, dev, src); -} auto make_kernel_bundle(context const &ctx, device const &dev, prog prg, tinytc_core_feature_flags_t core_features) -> kernel_bundle { diff --git a/src/sycl/recipe_handler.cpp b/src/sycl/recipe_handler.cpp index e005b72e..a275cb9f 100644 --- a/src/sycl/recipe_handler.cpp +++ b/src/sycl/recipe_handler.cpp @@ -26,7 +26,7 @@ template <> struct arg_handler_dispatcher { sycl_recipe_handler_impl::sycl_recipe_handler_impl(sycl::context const &context, sycl::device const &device, recipe rec) : ::tinytc_recipe_handler(std::move(rec)), - module_(make_kernel_bundle(context, device, get_recipe().get_source())) { + module_(make_kernel_bundle(context, device, get_recipe().get_binary())) { auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); diff --git a/src/ze/CMakeLists.txt b/src/ze/CMakeLists.txt index a4b0ed31..901c3024 100644 --- a/src/ze/CMakeLists.txt +++ b/src/ze/CMakeLists.txt @@ -6,13 +6,11 @@ include(GNUInstallDirs) include(InstallLib) find_package(LevelZero REQUIRED) -find_package(ocloc REQUIRED) set(SOURCES device_info.cpp error.cpp kernel.cpp - opencl_cc.cpp recipe_handler.cpp ) set(PUBLIC_HEADERS @@ -25,7 +23,6 @@ add_library(tinytc_ze ${SOURCES}) add_library(tinytc::tinytc_ze ALIAS tinytc_ze) set_cxx_common_options(tinytc_ze) -target_link_libraries(tinytc_ze PRIVATE ocloc::ocloc) target_link_libraries(tinytc_ze PUBLIC tinytc LevelZero::LevelZero) # include directories diff --git a/src/ze/error.hpp b/src/ze/error.hpp index 62c8cbf7..6c7113a7 100644 --- a/src/ze/error.hpp +++ b/src/ze/error.hpp @@ -4,7 +4,6 @@ #ifndef ZE_ERROR_20240419_HPP #define ZE_ERROR_20240419_HPP -#include "opencl_cc.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" @@ -12,21 +11,13 @@ namespace tinytc { -template -auto exception_to_status_code_ze(F &&f, - tinytc_compiler_context_t context = nullptr) -> tinytc_status_t { +template auto exception_to_status_code_ze(F &&f) -> tinytc_status_t { try { f(); } catch (status const &st) { return static_cast(st); } catch (builder_error const &e) { return static_cast(e.code()); - } catch (opencl_c_compilation_error const &e) { - if (context) { - auto const loc = location{}; - tinytc_compiler_context_report_error(context, &loc, e.what()); - } - return tinytc_status_compilation_error; } catch (std::bad_alloc const &e) { return tinytc_status_bad_alloc; } catch (...) { diff --git a/src/ze/kernel.cpp b/src/ze/kernel.cpp index f116cc76..03714c32 100644 --- a/src/ze/kernel.cpp +++ b/src/ze/kernel.cpp @@ -4,7 +4,6 @@ #include "../compiler_options.hpp" #include "../support/util.hpp" #include "error.hpp" -#include "opencl_cc.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_ze.h" @@ -22,72 +21,6 @@ using namespace tinytc; extern "C" { -tinytc_status_t tinytc_ze_source_compile_to_binary(tinytc_binary_t *bin, const_tinytc_source_t src, - uint32_t ip_version, - tinytc_bundle_format_t format) { - - if (bin == nullptr || src == nullptr) { - return tinytc_status_invalid_arguments; - } - - size_t code_size = 0; - char const *code = nullptr; - tinytc_compiler_context_t ctx = nullptr; - tinytc_core_feature_flags_t core_features = 0; - std::uint32_t extensions_size = 0; - char const *const *extensions = nullptr; - - TINYTC_CHECK_STATUS(tinytc_source_get_code(src, &code_size, &code)); - TINYTC_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); - TINYTC_CHECK_STATUS(tinytc_source_get_extensions(src, &extensions_size, &extensions)); - TINYTC_CHECK_STATUS(tinytc_source_get_compiler_context(src, &ctx)); - auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope - - return exception_to_status_code_ze( - [&] { - auto compiler_options = - std::vector(default_compiler_options.begin(), default_compiler_options.end()); - if (core_features & tinytc_core_feature_flag_large_register_file) { - compiler_options.push_back(large_register_file_compiler_option_ze); - } - auto fmt = enum_cast(format); - auto bin_data = - compile_opencl_c(code_size, code, fmt, ip_version, compiler_options.size(), - compiler_options.data(), extensions_size, extensions); - CHECK_STATUS(tinytc_binary_create(bin, ctx_.get(), format, bin_data.size(), - bin_data.data(), core_features)); - }, - ctx_.get()); -} - -tinytc_status_t tinytc_ze_kernel_bundle_create_with_source(ze_module_handle_t *bundle, - ze_context_handle_t context, - ze_device_handle_t device, - const_tinytc_source_t src) { - if (bundle == nullptr || src == nullptr) { - return tinytc_status_invalid_arguments; - } - - // Get IP version - auto dev_ip_ver = ze_device_ip_version_ext_t{}; - dev_ip_ver.stype = ZE_STRUCTURE_TYPE_DEVICE_IP_VERSION_EXT; - dev_ip_ver.pNext = nullptr; - auto dev_props = ze_device_properties_t{}; - dev_props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; - dev_props.pNext = &dev_ip_ver; - TINYTC_ZE_CHECK_STATUS(zeDeviceGetProperties(device, &dev_props)); - - // Get binary - tinytc_binary_t bin = nullptr; - TINYTC_CHECK_STATUS(tinytc_ze_source_compile_to_binary(&bin, src, dev_ip_ver.ipVersion, - tinytc_bundle_format_native)); - - tinytc_status_t status = - tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin); - tinytc_binary_release(bin); - return status; -} - tinytc_status_t tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, tinytc_prog_t prg, @@ -96,10 +29,7 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte return tinytc_status_invalid_arguments; } - const bool use_spirv_backend = getenv("TINYTC_SPIRV") != nullptr; - tinytc_core_info_t info = nullptr; - tinytc_source_t src = nullptr; tinytc_binary_t bin = nullptr; tinytc_status_t status = tinytc_status_success; @@ -110,31 +40,16 @@ tinytc_ze_kernel_bundle_create_with_program(ze_module_handle_t *bundle, ze_conte status != tinytc_status_success) { goto err; } - if (!use_spirv_backend) { - if (status = tinytc_prog_compile_to_opencl(&src, prg, info); - status != tinytc_status_success) { - goto err; - } - if (status = tinytc_ze_kernel_bundle_create_with_source(bundle, context, device, src); - status != tinytc_status_success) { - goto err; - } - } else { - if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); - status != tinytc_status_success) { - goto err; - } - if (status = tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin); - status != tinytc_status_success) { - goto err; - } + if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); + status != tinytc_status_success) { + goto err; } -err: - if (!use_spirv_backend) { - tinytc_source_release(src); - } else { - tinytc_binary_release(bin); + if (status = tinytc_ze_kernel_bundle_create_with_binary(bundle, context, device, bin); + status != tinytc_status_success) { + goto err; } +err: + tinytc_binary_release(bin); tinytc_core_info_release(info); return status; diff --git a/src/ze/opencl_cc.cpp b/src/ze/opencl_cc.cpp deleted file mode 100644 index b7ff2a07..00000000 --- a/src/ze/opencl_cc.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -// Code COPIED from Double-Batched FFT Library -// Copyright (C) 2022 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "opencl_cc.hpp" - -#include "ocloc_api.h" - -#include -#include -#include -#include - -namespace tinytc { - -std::vector compile_opencl_c(std::size_t source_length, char const *source_text, - bundle_format format, std::uint32_t ip_version, - std::uint32_t options_size, char const *const *options, - std::uint32_t extensions_size, - char const *const *extensions) { - auto const format_options = [](std::uint32_t options_size, - char const *const *options) -> std::string { - if (options_size == 0) { - return {}; - } - auto oss = std::ostringstream{}; - std::uint32_t opt_it = 0; - oss << options[opt_it++]; - for (; opt_it < options_size; ++opt_it) { - oss << " " << options[opt_it]; - } - return oss.str(); - }; - auto const format_ext_list = [](std::uint32_t extensions_size, - char const *const *extensions) -> std::string { - if (extensions_size == 0) { - return {}; - } - auto oss = std::ostringstream{}; - std::uint32_t ext_it = 0; - oss << "-cl-ext=+" << extensions[ext_it++]; - for (; ext_it < extensions_size; ++ext_it) { - oss << ",+" << extensions[ext_it]; - } - return oss.str(); - }; - unsigned int num_args = 2; - constexpr unsigned int max_num_args = 11; - char const *argv[max_num_args] = {"ocloc", "compile"}; - auto const ext_list = format_ext_list(extensions_size, extensions); - if (!ext_list.empty()) { - argv[num_args++] = "-internal_options"; - argv[num_args++] = ext_list.c_str(); - } - auto const cl_options = format_options(options_size, options); - if (!cl_options.empty()) { - argv[num_args++] = "-options"; - argv[num_args++] = cl_options.c_str(); - } - char device[16]; - snprintf(device, sizeof(device), "%d", ip_version); - if (ip_version != 0) { - argv[num_args++] = "-device"; - argv[num_args++] = device; - } - if (format == bundle_format::spirv) { - argv[num_args++] = "-spv_only"; - } - argv[num_args++] = "-file"; - argv[num_args++] = "kernel.cl"; - - const std::uint32_t num_sources = 1; - const std::uint8_t *data_sources = reinterpret_cast(source_text); - const std::uint64_t len_sources = source_length + 1; - char const *name_sources = argv[num_args - 1]; - std::uint32_t num_input_headers = 0; - std::uint32_t num_outputs = 0; - std::uint8_t **data_outputs = nullptr; - std::uint64_t *len_outputs = nullptr; - char **name_outputs = nullptr; - oclocInvoke(num_args, argv, num_sources, &data_sources, &len_sources, &name_sources, - num_input_headers, nullptr, nullptr, nullptr, &num_outputs, &data_outputs, - &len_outputs, &name_outputs); - - auto const ends_with = [](char const *str, char const *ending) { - auto lstr = strlen(str); - auto lend = strlen(ending); - if (lend > lstr) { - return false; - } - return strncmp(str + (lstr - lend), ending, lend) == 0; - }; - - constexpr std::uint32_t invalid_index = std::numeric_limits::max(); - std::uint32_t log_file = invalid_index; - std::uint32_t bin_file = invalid_index; - for (std::uint32_t o = 0; o < num_outputs; ++o) { - if (strcmp(name_outputs[o], "stdout.log") == 0) { - log_file = o; - } else if (format == bundle_format::spirv && ends_with(name_outputs[o], ".spv")) { - bin_file = o; - } else if (format == bundle_format::native && - (ends_with(name_outputs[o], ".bin") || ends_with(name_outputs[o], ".ar"))) { - bin_file = o; - } - } - if (bin_file == invalid_index) { - if (log_file != invalid_index) { - char *log_ptr = reinterpret_cast(data_outputs[log_file]); - auto log = std::string(log_ptr, len_outputs[log_file]); - throw opencl_c_compilation_error(std::move(log)); - } - throw opencl_c_compilation_error("OpenCL-C compilation failed (no log available)"); - } - - auto result = std::vector(data_outputs[bin_file], - data_outputs[bin_file] + len_outputs[bin_file]); - oclocFreeOutput(&num_outputs, &data_outputs, &len_outputs, &name_outputs); - return result; -} - -} // namespace tinytc diff --git a/src/ze/opencl_cc.hpp b/src/ze/opencl_cc.hpp deleted file mode 100644 index db026c76..00000000 --- a/src/ze/opencl_cc.hpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -// Code COPIED from Double-Batched FFT Library -// Copyright (C) 2022 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef OPENCL_CC_20240307_HPP -#define OPENCL_CC_20240307_HPP - -#include "tinytc/types.hpp" - -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -class opencl_c_compilation_error : public std::exception { - public: - opencl_c_compilation_error(std::string build_log) : build_log_(std::move(build_log)) {} - inline char const *what() const noexcept override { return build_log_.c_str(); } - - private: - std::string build_log_; -}; - -/** - * @brief Takes OpenCL-C code and outputs a SPIR-V or native device binary - * - * @param source_length Source text length (excluding zero terminator) - * @param source_text OpenCL-C source code (zero-terminated) - * @param format Target binary format - * @param ip_version Device ip version; you may pass 0 for format==spirv - * @param options List of compiler options - * @param extensions List of OpenCL-C extensions - * - * @return binary - */ -std::vector compile_opencl_c(std::size_t source_length, char const *source_text, - bundle_format format, std::uint32_t ip_version, - std::uint32_t options_size = 0, - char const *const *options = nullptr, - std::uint32_t extensions_size = 0, - char const *const *extensions = nullptr); - -} // namespace tinytc - -#endif // OPENCL_CC_20240307_HPP diff --git a/src/ze/recipe_handler.cpp b/src/ze/recipe_handler.cpp index 9bddeb88..9ff611c8 100644 --- a/src/ze/recipe_handler.cpp +++ b/src/ze/recipe_handler.cpp @@ -20,7 +20,7 @@ ze_recipe_handler::ze_recipe_handler(ze_context_handle_t context, ze_device_hand recipe rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source()); + module_ = make_kernel_bundle(context, device, get_recipe().get_binary()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0774c5e..05cd0e6f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -76,7 +76,7 @@ if(SPIRVTools_FOUND) ) foreach(SOURCE IN LISTS SPIRV_VAL_SOURCES) get_filename_component(TEST_NAME ${SOURCE} NAME_WE) - set(CHECK_COMMAND $ -O0 -gspirv ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE} | ${SPIRVTools_SPIRV_VAL} -) + set(CHECK_COMMAND $ -O0 ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE} | ${SPIRVTools_SPIRV_VAL} -) list(JOIN CHECK_COMMAND " " CHECK_COMMAND_STR) add_test(NAME spirv-val-${TEST_NAME} COMMAND bash -c "${CHECK_COMMAND_STR}") add_custom_target(spirv-val-${TEST_NAME} COMMAND ${CHECK_COMMAND}) diff --git a/test/spv/alloca.ir b/test/spv/alloca.ir index 99ebd264..c4c71156 100644 --- a/test/spv/alloca.ir +++ b/test/spv/alloca.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpEntryPoint Kernel %[[#]] "alloca" %[[#STACK_VAR:]] ; CHECK: OpDecorate %[[#STACK_PTR_TY:]] Alignment 16 diff --git a/test/spv/arith.ir b/test/spv/arith.ir index 3b600b5b..cf3878d2 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: %[[#BOOL:]] = OpTypeBool diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index d3c41d22..82a766fd 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: %[[#EXT:]] = OpExtInstImport "OpenCL.std" diff --git a/test/spv/barrier.ir b/test/spv/barrier.ir index 28654897..af0e07be 100644 --- a/test/spv/barrier.ir +++ b/test/spv/barrier.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#SCOPE:]] = OpConstant %[[#I32]] 2 diff --git a/test/spv/builtin.ir b/test/spv/builtin.ir index 6093b3a0..4af32834 100644 --- a/test/spv/builtin.ir +++ b/test/spv/builtin.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpEntryPoint Kernel %[[#]] "tbuiltin" %[[#VAR1:]] %[[#VAR2:]] %[[#VAR3:]] %[[#VAR4:]] %[[#VAR5:]] %[[#VAR6:]] diff --git a/test/spv/calling_convention.ir b/test/spv/calling_convention.ir index 05b4c608..f8527b48 100644 --- a/test/spv/calling_convention.ir +++ b/test/spv/calling_convention.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpDecorate %[[#PTR_F32:]] Alignment 4 ; CHECK: OpDecorate %[[#PTR_I16:]] Alignment 2 diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 62114064..984014b8 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpCapability Int64 ; CHECK: OpCapability Int8 diff --git a/test/spv/compare.ir b/test/spv/compare.ir index 1957c0cf..2447cbc4 100644 --- a/test/spv/compare.ir +++ b/test/spv/compare.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#BOOL:]] = OpTypeBool ; CHECK: %[[#BOOL2:]] = OpTypeVector %[[#BOOL]] 2 diff --git a/test/spv/cooperative_matrix_load.ir b/test/spv/cooperative_matrix_load.ir index 8679884e..e0561b38 100644 --- a/test/spv/cooperative_matrix_load.ir +++ b/test/spv/cooperative_matrix_load.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#I32_PTR:]] = OpTypePointer CrossWorkgroup %[[#I32]] diff --git a/test/spv/cooperative_matrix_mul_add.ir b/test/spv/cooperative_matrix_mul_add.ir index 1fb06168..6b8c464b 100644 --- a/test/spv/cooperative_matrix_mul_add.ir +++ b/test/spv/cooperative_matrix_mul_add.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -O0 -gspirv -S < %s | filecheck %s +; RUN: %tinytc-oc -O0 -S < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F32_C1:]] = OpConstant %[[#F32]] 0x1p+0 diff --git a/test/spv/cooperative_matrix_scale.ir b/test/spv/cooperative_matrix_scale.ir index 235a6f37..12de6441 100644 --- a/test/spv/cooperative_matrix_scale.ir +++ b/test/spv/cooperative_matrix_scale.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F32_CPI:]] = OpConstant %[[#F32]] 0x1.921fb6p+1 diff --git a/test/spv/cooperative_matrix_store.ir b/test/spv/cooperative_matrix_store.ir index bc90ae03..d6b6334c 100644 --- a/test/spv/cooperative_matrix_store.ir +++ b/test/spv/cooperative_matrix_store.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#I32_PTR:]] = OpTypePointer CrossWorkgroup %[[#I32]] diff --git a/test/spv/expand.ir b/test/spv/expand.ir index a7417352..54e2a03a 100644 --- a/test/spv/expand.ir +++ b/test/spv/expand.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s +; RUN: %tinytc-oc -O0 -S < %s | filecheck %s ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#I64_C32:]] = OpConstant %[[#I64]] 32 diff --git a/test/spv/for.ir b/test/spv/for.ir index 659fe509..a20642af 100644 --- a/test/spv/for.ir +++ b/test/spv/for.ir @@ -16,7 +16,7 @@ ; CHECK: %[[#I16_C6:]] = OpConstant %[[#I16]] 6 ; CHECK: %[[#I16_C1:]] = OpConstant %[[#I16]] 1 -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s func @for1() { %lb = constant 0 : i16 %ub = constant 10 : i16 diff --git a/test/spv/fuse.ir b/test/spv/fuse.ir index 2120747a..cd9136af 100644 --- a/test/spv/fuse.ir +++ b/test/spv/fuse.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s +; RUN: %tinytc-oc -O0 -S < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F32_PTR:]] = OpTypePointer CrossWorkgroup %[[#F32]] diff --git a/test/spv/if.ir b/test/spv/if.ir index 1a4182e0..fbb832e1 100644 --- a/test/spv/if.ir +++ b/test/spv/if.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#I32:]] = OpTypeInt 32 0 ; CHECK: %[[#BOOL:]] = OpTypeBool diff --git a/test/spv/load.ir b/test/spv/load.ir index 7e493443..a01026eb 100644 --- a/test/spv/load.ir +++ b/test/spv/load.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#PTR_F32:]] = OpTypePointer CrossWorkgroup %[[#F32]] diff --git a/test/spv/size.ir b/test/spv/size.ir index b282a39e..1c6b4dc1 100644 --- a/test/spv/size.ir +++ b/test/spv/size.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#I64_C8:]] = OpConstant %[[#I64]] 8 diff --git a/test/spv/store.ir b/test/spv/store.ir index fa4755ce..1dc96472 100644 --- a/test/spv/store.ir +++ b/test/spv/store.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpCapability AtomicFloat32AddEXT ; CHECK: OpCapability AtomicFloat64AddEXT diff --git a/test/spv/subview.ir b/test/spv/subview.ir index 9d6bfd80..58094b59 100644 --- a/test/spv/subview.ir +++ b/test/spv/subview.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -O0 -S -gspirv < %s | filecheck %s +; RUN: %tinytc-oc -O0 -S < %s | filecheck %s ; CHECK: %[[#F32:]] = OpTypeFloat 32 ; CHECK: %[[#F32_PTR:]] = OpTypePointer CrossWorkgroup %[[#F32]] diff --git a/test/spv/unique_function_type.ir b/test/spv/unique_function_type.ir index c0649d5c..607ec9b2 100644 --- a/test/spv/unique_function_type.ir +++ b/test/spv/unique_function_type.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S < %s | filecheck %s +; RUN: %tinytc-oc -S < %s | filecheck %s func @f1() {} func @f2() {} func @f3(%a: i32, %b: f32) {} diff --git a/test/spv/work_group.ir b/test/spv/work_group.ir index 7f9b840b..b5feadea 100644 --- a/test/spv/work_group.ir +++ b/test/spv/work_group.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: %tinytc-oc -gspirv -S -O0 < %s | filecheck %s +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s ; CHECK: OpCapability Group ; CHECK: %[[#I16:]] = OpTypeInt 16 0 diff --git a/tools/offline_compiler/main.cpp b/tools/offline_compiler/main.cpp index 8ae9aa6c..46a58e6b 100644 --- a/tools/offline_compiler/main.cpp +++ b/tools/offline_compiler/main.cpp @@ -17,32 +17,15 @@ using namespace tinytc; -enum class generator { opencl, spirv }; - int main(int argc, char **argv) { char const *filename = nullptr; auto info = core_info{}; tinytc_core_feature_flags_t core_features = 0; std::int32_t opt_level = 2; auto flags = cmd::optflag_states{}; - auto gen = generator::opencl; bool emit_asm = false; bool help = false; - auto const convert_string_to_generator = [](char const *str, generator &val) { - switch (fnv1a(str, std::strlen(str))) { - case "opencl"_fnv1a: - val = generator::opencl; - break; - case "spirv"_fnv1a: - val = generator::spirv; - break; - default: - return cmd::parser_status::invalid_argument; - }; - return cmd::parser_status::success; - }; - auto parser = cmd::arg_parser{}; try { info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); @@ -59,8 +42,6 @@ int main(int argc, char **argv) { } return cmd::parser_status::success; }); - parser.set_short_opt('g', &gen, "Code generation backend (opencl or spirv)") - .converter(convert_string_to_generator); parser.set_short_opt('S', &emit_asm, "Compile only; do not assemble"); parser.set_short_opt('h', &help, "Show help"); parser.set_long_opt("help", &help, "Show help"); @@ -105,21 +86,14 @@ int main(int argc, char **argv) { p = parse_file(filename, ctx); } - switch (gen) { - case generator::opencl: - std::cout << compile_to_opencl(std::move(p), info).get_code(); - break; - case generator::spirv: - if (emit_asm) { - auto mod = compile_to_spirv(std::move(p), info); - auto spvasm = mod.print_to_string(); - std::cout << spvasm.get(); - } else { - auto bin = compile_to_spirv_and_assemble(std::move(p), info); - auto raw_data = bin.get_raw(); - std::cout.write(reinterpret_cast(raw_data.data), raw_data.data_size); - } - break; + if (emit_asm) { + auto mod = compile_to_spirv(std::move(p), info); + auto spvasm = mod.print_to_string(); + std::cout << spvasm.get(); + } else { + auto bin = compile_to_spirv_and_assemble(std::move(p), info); + auto raw_data = bin.get_raw(); + std::cout.write(reinterpret_cast(raw_data.data), raw_data.data_size); } } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; From 509e72189e8c8e5f50e4e6fa6d5892ceb3f9a55c Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 11:53:58 +0100 Subject: [PATCH 129/297] Remove bbfft dependency Signed-off-by: Carsten Uphoff --- src/CMakeLists.txt | 6 - src/codegen_tools.cpp | 476 +--------- src/codegen_tools.hpp | 122 +-- src/compiler.cpp | 5 - src/gemm_generator.cpp | 398 -------- src/gemm_generator.hpp | 100 -- src/pass/convert_to_opencl.cpp | 1631 -------------------------------- src/pass/convert_to_opencl.hpp | 132 --- src/pass/lower_foreach.cpp | 4 +- src/pass/lower_linalg.cpp | 46 +- src/required_extensions.cpp | 30 - src/required_extensions.hpp | 19 - src/scalar_type.cpp | 81 -- src/scalar_type.hpp | 12 - test/generator.cpp | 17 - 15 files changed, 35 insertions(+), 3044 deletions(-) delete mode 100644 src/gemm_generator.cpp delete mode 100644 src/gemm_generator.hpp delete mode 100644 src/pass/convert_to_opencl.cpp delete mode 100644 src/pass/convert_to_opencl.hpp delete mode 100644 src/required_extensions.cpp delete mode 100644 src/required_extensions.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 967ba264..ce47bf6c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,7 +12,6 @@ else () set(type static) endif () -find_package(clir 0.5.1 REQUIRED ${type}) find_package(re2c REQUIRED) find_package(BISON 3.8.2 REQUIRED) @@ -31,7 +30,6 @@ set(SOURCES device_info.cpp error.cpp func.cpp - gemm_generator.cpp gemm_tools.cpp inst.cpp location.cpp @@ -46,7 +44,6 @@ set(SOURCES pass/clone.cpp pass/constant_folding.cpp pass/constant_propagation.cpp - pass/convert_to_opencl.cpp pass/convert_to_spirv.cpp pass/dead_code_elimination.cpp pass/dump_cfg.cpp @@ -64,7 +61,6 @@ set(SOURCES recipe/small_gemm_batched.cpp recipe/tall_and_skinny.cpp region.cpp - required_extensions.cpp scalar_type.cpp spv/capex_util.cpp spv/converter.cpp @@ -99,13 +95,11 @@ add_flag_if_available_to_source_files(CXX "${BISON_parser_OUTPUTS}" "-Wno-unused add_library(tinytc-objects OBJECT ${SOURCES} ${BISON_parser_OUTPUTS}) add_re2c_to_target(TARGET tinytc-objects SOURCES ${RE2C_SOURCES}) set_cxx_common_options(tinytc-objects) -target_link_libraries(tinytc-objects PUBLIC clir::clir) target_compile_definitions(tinytc-objects PUBLIC "$<$>:TINYTC_STATIC_DEFINE>") add_library(tinytc $) add_library(tinytc::tinytc ALIAS tinytc) -target_link_libraries(tinytc PRIVATE clir::clir) set_cxx_common_options(tinytc) # Generate export header diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 0210ead4..aa7312ba 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -13,482 +13,15 @@ #include "support/visit.hpp" #include "tinytc/types.h" -#include -#include -#include -#include -#include -#include - #include #include #include #include -using namespace clir; - namespace tinytc { -short bits(scalar_type ty) { return size(ty) * 8; } -expr constant(scalar_type ty, std::int64_t value) { return expr(value, bits(ty)); } -expr constant(scalar_type ty, double value) { - if (is_complex_type(ty)) { - const auto ety = element_type(ty); - return init_vector(to_clir_ty(ty), {constant(ety, value), constant(ety, 0.0)}); - } - return expr(value, bits(ty)); -} -expr multiply(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { - if (is_complex_type(ty_a) && is_complex_type(ty_b)) { - return a * b.s(0) + init_vector(to_clir_ty(ty_a), {-a.s(1), a.s(0)}) * b.s(1); - } - return a * b; -} -expr divide(scalar_type ty_a, scalar_type ty_b, expr a, expr b) { - if (is_complex_type(ty_a) && is_complex_type(ty_b)) { - return (a * b.s(0) - init_vector(to_clir_ty(ty_a), {-a.s(1), a.s(0)}) * b.s(1)) / - (b.s(0) * b.s(0) + b.s(1) * b.s(1)); - } - if (is_complex_type(ty_b)) { - return a * init_vector(to_clir_ty(ty_b), {b.s(0), -b.s(1)}) / - (b.s(0) * b.s(0) + b.s(1) * b.s(1)); - } - return a / b; -} - -expr vload_helper(short vec_size, expr offset, expr ptr) { - switch (vec_size) { - case 1: - return ptr[std::move(offset)]; - case 2: - return vload2(std::move(offset), std::move(ptr)); - case 3: - return vload3(std::move(offset), std::move(ptr)); - case 4: - return vload4(std::move(offset), std::move(ptr)); - case 8: - return vload8(std::move(offset), std::move(ptr)); - case 16: - return vload16(std::move(offset), std::move(ptr)); - default: - break; - }; - return nullptr; -} - -struct block_rw_config { - builtin_type cast_type; - expr (*sub_group_block_read)(expr); - expr (*sub_group_block_write)(expr, expr); - expr (*as_type)(expr); -}; - -auto get_block_rw_config(scalar_type ty) { - switch (ty) { - case scalar_type::i16: - return block_rw_config{builtin_type::ushort_t, &intel_sub_group_block_read_us, - &intel_sub_group_block_write_us, &as_short}; - case scalar_type::i32: - return block_rw_config{builtin_type::uint_t, &intel_sub_group_block_read_ui, - &intel_sub_group_block_write_ui, &as_int}; - case scalar_type::f32: - return block_rw_config{builtin_type::uint_t, &intel_sub_group_block_read_ui, - &intel_sub_group_block_write_ui, &as_float}; - case scalar_type::i64: - return block_rw_config{builtin_type::ulong_t, &intel_sub_group_block_read_ul, - &intel_sub_group_block_write_ul, &as_long}; - case scalar_type::f64: - return block_rw_config{builtin_type::ulong_t, &intel_sub_group_block_read_ul, - &intel_sub_group_block_write_ul, &as_double}; - default: - break; - } - return block_rw_config{builtin_type::void_t, nullptr, nullptr, nullptr}; -} - -expr sub_group_block_read_helper(expr pointer, scalar_type ty, clir::address_space as) { - const auto cfg = get_block_rw_config(ty); - if (cfg.sub_group_block_read == nullptr) { - return pointer[get_sub_group_local_id()]; - } - pointer = cast(pointer_to(clir::data_type(cfg.cast_type, as)), std::move(pointer)); - auto inst = (*cfg.sub_group_block_read)(std::move(pointer)); - return (*cfg.as_type)(std::move(inst)); -} -expr sub_group_block_write_helper(expr pointer, expr data, scalar_type ty, clir::address_space as) { - const auto cfg = get_block_rw_config(ty); - if (cfg.sub_group_block_write == nullptr) { - return pointer[get_sub_group_local_id()] = std::move(data); - } - pointer = cast(pointer_to(clir::data_type(cfg.cast_type, as)), std::move(pointer)); - data = (*cfg.as_type)(std::move(data)); - return (*cfg.sub_group_block_write)(std::move(pointer), std::move(data)); -} - -void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, - clir::address_space as, expr value, scalar_type beta_ty, expr beta) { - if (is_atomic) { - atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), beta_ty, std::move(beta)); - } else { - const auto c_scaled = multiply(ty, beta_ty, dereference(dst), beta); - bb.assign(dereference(dst), std::move(value) + std::move(c_scaled)); - } -} - -void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, clir::address_space as, - expr value, scalar_type beta_ty, expr beta) { - int mode = -1; - visit(overloaded{ - [&](clir::internal::int_imm &c) { - mode = c.value() == 0 ? 0 : (c.value() == 1 ? 1 : -1); - }, - [&](clir::internal::uint_imm &c) { - mode = c.value() == 0u ? 0 : (c.value() == 1u ? 1 : -1); - }, - [&](clir::internal::float_imm &c) { - mode = c.value() == 0.0 ? 0 : (c.value() == 1.0 ? 1 : -1); - }, - [&](auto &) {}, - }, - *beta); - auto pointer_ty = pointer_to(to_clir_atomic_ty(ty, as, type_qualifier::volatile_t)); - auto atomic_dst = cast(std::move(pointer_ty), dst); - if (mode == 0) { - bb.add(call_builtin(builtin_function::atomic_store_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else if (mode == 1) { - bb.add(call_builtin(builtin_function::atomic_fetch_add_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else { - auto expected = bb.declare_assign(to_clir_ty(ty), "expected", dereference(dst)); - auto desired = bb.declare(to_clir_ty(ty), "desired"); - auto cmpxchg = - call_builtin(builtin_function::atomic_compare_exchange_strong_explicit, - {std::move(atomic_dst), address_of(std::move(expected)), desired, - memory_order::relaxed, memory_order::relaxed, memory_scope::work_group}); - bb.add(while_loop_builder(std::move(cmpxchg), true) - .body([&](block_builder &bb) { bb.assign(desired, value + beta * expected); }) - .get_product()); - } -} - -auto atomic_store_helper_new(store_flag flag, memref_data_type const *ty, expr pointer, - expr value) -> std::vector { - const auto make_atomic_store = [&](auto fun, expr pointer, expr value) -> std::vector { - constexpr auto mem_order = clir::memory_order::relaxed; - constexpr auto mem_scope = clir::memory_scope::work_group; - constexpr auto qualifier = clir::type_qualifier::volatile_t; - - const auto sty = ty->element_ty(); - const auto addrspace = to_clir_address_space(ty->addrspace()); - if (is_complex_type(sty)) { - const auto atomic_pointer_ty = - pointer_to(to_clir_atomic_ty(element_type(sty)), addrspace, qualifier); - return {expression_statement(call_builtin( - fun, {cast(atomic_pointer_ty, address_of(dereference(pointer).s(0))), value, - mem_order, mem_scope})), - expression_statement(call_builtin( - fun, {cast(atomic_pointer_ty, address_of(dereference(pointer).s(1))), value, - mem_order, mem_scope}))}; - } else { - const auto atomic_pointer_ty = pointer_to(to_clir_atomic_ty(sty, addrspace, qualifier)); - return { - expression_statement(call_builtin(fun, {cast(atomic_pointer_ty, std::move(pointer)), - std::move(value), mem_order, mem_scope}))}; - } - }; - - switch (flag) { - case store_flag::regular: - return { - expression_statement(assignment(dereference(std::move(pointer)), std::move(value)))}; - case store_flag::atomic: - return make_atomic_store(clir::builtin_function::atomic_store_explicit, std::move(pointer), - std::move(value)); - case store_flag::atomic_add: - return make_atomic_store(clir::builtin_function::atomic_fetch_add_explicit, - std::move(pointer), std::move(value)); - } - return {}; -} - -void dispatch_constant_dynamic(expr e, std::function const &const_case, - std::function const &dyn_case) { - visit( - overloaded{ - [&](clir::internal::int_imm &c) { const_case(c.value()); }, - [&](clir::internal::uint_imm &c) { const_case(static_cast(c.value())); }, - [&](auto &) { dyn_case(std::move(e)); }, - }, - *e); -} - -void tile_loop_by_sgs(block_builder &bb, expr loop_trip_count, unsigned sgs, unsigned num_tiles, - var sg_id, sgs_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_by_sgs_constant(bb, c, sgs, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_by_sgs_dynamic(bb, std::move(d), sgs, num_tiles, std::move(sg_id), body); - }); -} - -void tile_loop_by_sgs_constant(block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = loop_trip_count / sgs; - auto rem = loop_trip_count % sgs; - - auto block = bb.declare(generic_uint(), "blck"); - if (blocks > 0) { - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - if (rem > 0) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - .get_product()); - } -} - -void tile_loop_by_sgs_dynamic(block_builder &bb, expr loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = bb.declare_assign(generic_uint(), "blocks", loop_trip_count / sgs); - auto rem = bb.declare_assign(generic_uint(), "rem", std::move(loop_trip_count) % sgs); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - bb.add(if_selection_builder(rem > 0) - .then([&](block_builder &bb) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - .get_product()); - }) - .get_product()); -} - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles) { - auto blocks = 1 + (loop_trip_count - 1) / block_size; - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto rem = loop_trip_count % blocks; - return rem > 0 ? bs + 1 : bs; -} - -void tile_loop_uniformly(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, uniform_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_uniformly_constant(bb, c, block_size, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_uniformly_dynamic(bb, std::move(d), block_size, num_tiles, std::move(sg_id), - body); - }); -} - -void tile_loop_uniformly_constant(block_builder &bb, unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - // Find minimum number of blocks such that the block sizes are smaller or equal block_size - auto blocks = 1 + (loop_trip_count - 1) / block_size; - // Increase the number of blocks if such that the number of blocks is a multiple - // of the number of tiles - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto bs_1 = bs + 1; - auto rem = loop_trip_count % blocks; - - auto block = bb.declare(generic_uint(), "blck"); - if (rem > 0) { - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < loop_trip_count, add_into(block, bs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); -} - -void tile_loop_uniformly_dynamic(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - auto blocks = - bb.declare_assign(generic_uint(), "blocks", 1 + (loop_trip_count - 1) / block_size); - bb.assign(blocks, (1 + (blocks - 1) / num_tiles) * num_tiles); - auto bs = bb.declare_assign(generic_uint(), "bs", loop_trip_count / blocks); - auto bs_1 = bb.declare_assign(generic_uint(), "bs_1", bs + 1); - auto rem = bb.declare_assign(generic_uint(), "rem", loop_trip_count % blocks); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < std::move(loop_trip_count), add_into(block, bs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); -} - -block_accessor_regular::block_accessor_regular(expr block, int Kb) - : block_(std::move(block)), offset_{clir::expr{nullptr}}, Kb_(Kb) {} -auto block_accessor_regular::get(int m_block, int k) const -> expr { - const auto i = k + m_block * Kb_; - if (offset_) { - return block_[offset_ + i]; - } - return block_[i]; -} - -block_accessor_vector::block_accessor_vector(expr block) : block_(std::move(block)) {} -auto block_accessor_vector::get(int m_block, int k) const -> expr { return block_[m_block].s(k); } - -int matrix_block_description::first_block_with_check(std::int32_t subgroup_size) const { - int fb = 0; - dispatch_constant_dynamic( - M, [&](std::int64_t m) { fb = m / subgroup_size; }, [](expr const &) {}); - return fb; -} - -bool matrix_block_description::is_unit_stride(int mode) const { - bool is_unit = false; - dispatch_constant_dynamic( - stride[mode], [&](std::int64_t s) { is_unit = s == 1; }, [](expr const &) {}); - return is_unit; -} - -expr matrix_block_description::condition(int m_block, std::int32_t subgroup_size) const { - return get_sub_group_local_id() + m_block * subgroup_size < M; -} - -auto read_matrix_block_regular(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, - char const *block_name) -> std::unique_ptr { - assert(M_mode == 0 || M_mode == 1); - - const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; - auto block = bb.declare(array_of(to_clir_ty(d.ty), m_blocks * d.Kb), block_name); - - const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); - const bool enable_sub_group_reads = - core_cfg.block_read_write_supported && d.is_unit_stride(M_mode); - for (int k = 0; k < d.Kb; ++k) { - for (int m_block = 0; m_block < m_blocks; ++m_block) { - auto const store = [&](expr rhs) { - bb.assign(block[k + m_block * d.Kb], std::move(rhs)); - }; - if (enable_sub_group_reads && m_block < first_m_block_with_check) { - store(sub_group_block_read_helper(d.pointer + m_block * core_cfg.subgroup_size, - d.ty, d.as)); - } else { - auto rhs = d.pointer[d.stride[M_mode] * - (get_sub_group_local_id() + m_block * core_cfg.subgroup_size)]; - if (m_block >= first_m_block_with_check) { - rhs = ternary_conditional(d.condition(m_block, core_cfg.subgroup_size), - std::move(rhs), 0); - } - store(std::move(rhs)); - } - } - bb.add(add_into(d.pointer, d.stride[1 - M_mode])); - } - return std::make_unique(std::move(block), d.Kb); -} - -auto read_matrix_block_vector(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, - char const *block_name) -> std::unique_ptr { - assert(M_mode == 0 || M_mode == 1); - - const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; - const auto dt = to_clir_ty(d.ty, d.Kb); - auto block = bb.declare(array_of(dt, m_blocks), block_name); - - int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); - for (int m_block = 0; m_block < m_blocks; ++m_block) { - auto rhs = vload_helper(d.Kb, 0, - d.pointer + d.stride[M_mode] * (get_sub_group_local_id() + - m_block * core_cfg.subgroup_size)); - if (!bool(rhs)) { - throw internal_compiler_error(); - } - if (m_block >= first_m_block_with_check) { - rhs = ternary_conditional(d.condition(m_block, core_cfg.subgroup_size), rhs, - init_vector(dt, {0})); - } - bb.assign(block[m_block], std::move(rhs)); - } - bb.add(add_into(d.pointer, d.Kb * d.stride[1 - M_mode])); - - return std::make_unique(std::move(block)); -} - -auto read_matrix_block(block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, - char const *block_name) -> std::unique_ptr { - assert(M_mode == 0 || M_mode == 1); - - if (d.is_unit_stride(1 - M_mode) && !is_complex_type(d.ty) && - (d.Kb == 2 || d.Kb == 3 || d.Kb == 4 || d.Kb == 8 || d.Kb == 16)) { - return read_matrix_block_vector(bb, d, M_mode, core_cfg, block_name); - } - return read_matrix_block_regular(bb, d, M_mode, core_cfg, block_name); -} - -void write_matrix_block(block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, - expr beta, core_config const &core_cfg) { - const int m_blocks = 1 + (d.Mb - 1) / core_cfg.subgroup_size; - - const int first_m_block_with_check = d.first_block_with_check(core_cfg.subgroup_size); - for (int k = 0; k < d.Kb; ++k) { - for (int m_block = 0; m_block < m_blocks; ++m_block) { - const auto write = [&](block_builder &bb) { - store_helper(bb, is_atomic, - d.pointer + d.stride[0] * (get_sub_group_local_id() + - m_block * core_cfg.subgroup_size), - d.ty, d.as, block.get(m_block, k), beta_ty, beta); - }; - if (m_block >= first_m_block_with_check) { - bb.add(if_selection_builder(d.condition(m_block, core_cfg.subgroup_size)) - .then(write) - .get_product()); - } else { - write(bb); - } - } - bb.add(add_into(d.pointer, d.stride[1])); - } -} - -void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, - value sg_id, sgs_loop_body_builder_new const &body) { +void tile_loop_by_sgs(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder const &body) { auto ity = loop_trip_count->ty(); auto bool_ty = boolean_data_type::get(ity->context()); auto c_sgs = bb.add(make_constant(sgs, ity)); @@ -525,9 +58,8 @@ void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, in }); } -void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, - int num_tiles, value sg_id, - uniform_loop_body_builder_new const &body) { +void tile_loop_uniformly(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, + value sg_id, uniform_loop_body_builder const &body) { auto ity = loop_trip_count->ty(); auto bool_ty = boolean_data_type::get(ity->context()); auto c0 = bb.add(make_constant(0, ity)); diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 6867d071..55cd9cd8 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -10,12 +10,6 @@ #include "tinytc/types.h" #include "tinytc/types.hpp" -#include -#include -#include -#include -#include - #include #include #include @@ -29,119 +23,15 @@ namespace tinytc { // tools for OpenCL codegen short bits(scalar_type ty); -clir::expr constant(scalar_type ty, std::int64_t value); -clir::expr constant(scalar_type ty, double value); -clir::expr multiply(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); -clir::expr divide(scalar_type ty_a, scalar_type ty_b, clir::expr a, clir::expr b); -clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); -clir::expr sub_group_block_read_helper(clir::expr pointer, scalar_type ty, clir::address_space as); -clir::expr sub_group_block_write_helper(clir::expr pointer, clir::expr data, scalar_type ty, - clir::address_space as); - -void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, scalar_type beta_ty, clir::expr beta); -void atomic_store_helper(clir::block_builder &bb, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, scalar_type beta_ty, - clir::expr beta); -auto atomic_store_helper_new(store_flag flag, memref_data_type const *ty, clir::expr pointer, - clir::expr value) -> std::vector; - -void dispatch_constant_dynamic(clir::expr e, std::function const &const_case, - std::function const &dyn_case); - -using sgs_loop_body_builder = - std::function; -using uniform_loop_body_builder = - std::function; - -void tile_loop_by_sgs(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, sgs_loop_body_builder const &body); -void tile_loop_by_sgs_constant(clir::block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - sgs_loop_body_builder const &body); -void tile_loop_by_sgs_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - sgs_loop_body_builder const &body); - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles); -void tile_loop_uniformly(clir::block_builder &bb, clir::expr loop_trip_count, unsigned block_size, - unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_constant(clir::block_builder &bb, unsigned loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); - -class block_accessor { - public: - virtual ~block_accessor() = default; - virtual auto get(int m_block, int k) const -> clir::expr = 0; -}; - -class block_accessor_regular : public block_accessor { - public: - block_accessor_regular(clir::expr block, int Kb); - auto get(int m_block, int k) const -> clir::expr override; - inline void offset(clir::expr offset) { offset_ = std::move(offset); } - - private: - clir::expr block_, offset_; - int Kb_; -}; - -class block_accessor_vector : public block_accessor { - public: - block_accessor_vector(clir::expr block); - auto get(int m_block, int k) const -> clir::expr override; - - private: - clir::expr block_; -}; - -struct matrix_block_description { - scalar_type ty; ///< Matrix scalar type - clir::address_space as; ///< Matrix address space - int Mb; ///< Number of rows if M_mode == 0; number of columns if M_mode == 1 - int Kb; ///< Number of columns if M_mode == 0; number of rows if M_mode == 0 - clir::expr pointer; ///< Pointer to block start - clir::expr M; ///< Size of row mode if M_mode == 0; size of column mode if M_mode == 1 - std::array stride; ///< Matrix stride - - int first_block_with_check(std::int32_t subgroup_size) const; - clir::expr condition(int m_block, std::int32_t subgroup_size) const; - bool is_unit_stride(int mode) const; -}; - -auto read_matrix_block_regular(clir::block_builder &bb, matrix_block_description const &d, - int M_mode, core_config const &core_cfg, - char const *block_name) -> std::unique_ptr; -auto read_matrix_block_vector(clir::block_builder &bb, matrix_block_description const &d, - int M_mode, core_config const &core_cfg, - char const *block_name) -> std::unique_ptr; - -// Read MbxKb block -auto read_matrix_block(clir::block_builder &bb, matrix_block_description const &d, int M_mode, - core_config const &core_cfg, - char const *block_name) -> std::unique_ptr; - -// Write MbxKb block -void write_matrix_block(clir::block_builder &bb, block_accessor const &block, - matrix_block_description const &d, bool is_atomic, scalar_type beta_ty, - clir::expr beta, core_config const &core_cfg); - -// tools for tinytc lowering -using sgs_loop_body_builder_new = std::function; -using uniform_loop_body_builder_new = std::function; +using sgs_loop_body_builder = std::function; +using uniform_loop_body_builder = std::function; -void tile_loop_by_sgs_new(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, - value sg_id, sgs_loop_body_builder_new const &body); +void tile_loop_by_sgs(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder const &body); -void tile_loop_uniformly_new(region_builder &bb, value loop_trip_count, int block_size, - int num_tiles, value sg_id, uniform_loop_body_builder_new const &body); +void tile_loop_uniformly(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, + value sg_id, uniform_loop_body_builder const &body); auto mixed_precision_arithmetic(region_builder &bb, arithmetic operation, value a, value b, location const &loc) -> value; diff --git a/src/compiler.cpp b/src/compiler.cpp index d58546d0..bebe75ec 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -6,7 +6,6 @@ #include "node/program_node.hpp" #include "pass/check_ir.hpp" #include "pass/constant_propagation.hpp" -#include "pass/convert_to_opencl.hpp" #include "pass/convert_to_spirv.hpp" #include "pass/dead_code_elimination.hpp" #include "pass/dump_cfg.hpp" @@ -20,16 +19,12 @@ #include "pass/work_group_size.hpp" #include "passes.hpp" #include "reference_counted.hpp" -#include "required_extensions.hpp" #include "spv/pass/assemble.hpp" #include "spv/pass/assign_ids.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include -#include - #include #include #include diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp deleted file mode 100644 index e6c8608c..00000000 --- a/src/gemm_generator.cpp +++ /dev/null @@ -1,398 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "gemm_generator.hpp" -#include "codegen_tools.hpp" -#include "device_info.hpp" -#include "gemm_tools.hpp" -#include "scalar_type.hpp" -#include "tiling.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -using namespace clir; - -namespace tinytc { - -gemm_scalar_type::gemm_scalar_type(scalar_type ty) : alpha(ty), A(ty), B(ty), beta(ty), C(ty) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alphaAB, scalar_type betaC) - : alpha(alphaAB), A(alphaAB), B(alphaAB), beta(betaC), C(betaC) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, - scalar_type beta, scalar_type C) - : alpha(alpha), A(A), B(B), beta(beta), C(C) {} - -std::string gemm_configuration::identifier(std::string_view prefix) const { - std::ostringstream oss; - auto const dyn_val = [&oss](std::int64_t v) { - if (v == dynamic) { - oss << "d"; - } else { - oss << v; - } - }; - auto const stride = [&oss, &dyn_val](char X, std::array const &s) { - oss << "_" << X << "stride"; - dyn_val(s[0]); - oss << "_"; - dyn_val(s[1]); - }; - oss << prefix << "_"; - if (atomic) { - oss << "atomic_"; - } - oss << to_string(ty.alpha) << to_string(ty.A) << to_string(ty.B) << to_string(ty.beta) - << to_string(ty.C) << "_A" << to_string(transA) << "_B" << to_string(transB) << "_M"; - dyn_val(M); - oss << "_N"; - dyn_val(N); - oss << "_K"; - dyn_val(K); - stride('A', A_stride); - stride('B', B_stride); - stride('C', C_stride); - auto const format_optional = [&](std::optional const &val) { - if (val) { - auto f = oss.flags(); - auto v = *val; - oss << std::hex << std::bit_cast(v); - oss.flags(f); - } else { - oss << "d"; - } - }; - oss << "_alpha"; - format_optional(alpha); - oss << "_beta"; - format_optional(beta); - return oss.str(); -} - -class generator { - public: - generator(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, clir::address_space As, clir::address_space Bs, - clir::address_space Cs) - : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), - Cspace(Cs) {} - bool use_double_buffering() const; - void multiply_update(block_builder &bb, expr a, expr b, int n_offset, expr c, expr c_im); - void add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, expr C_offset, - expr alpha, expr beta); - void add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta); - void add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta); - ::clir::func function(std::string_view name); - - private: - gemm_configuration const gemm_cfg; - local_tiling const tiling; - core_config const core_cfg; - clir::address_space Aspace, Bspace, Cspace; - int row_blocks_in_register = 1; - int cols_in_register = 1; - var c_acc, c_acc_im, m; - std::array MNK; - std::array A_stride, B_stride, C_stride; -}; - -bool generator::use_double_buffering() const { - return is_complex_type(gemm_cfg.ty.A) && is_complex_type(gemm_cfg.ty.B); -} - -void generator::multiply_update(block_builder &bb, expr a, expr b, int n_offset, expr c, - expr c_im) { - if (is_complex_type(gemm_cfg.ty.A)) { - if (is_complex_type(gemm_cfg.ty.B)) { - assert(use_double_buffering()); - auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); - auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); - bb.add(add_into(c, a * b_bc_re)); - bb.add(add_into(c_im, a * b_bc_im)); - } else { - auto b_bc = sub_group_broadcast(b, n_offset); - bb.add(add_into(std::move(c), std::move(a) * std::move(b_bc))); - } - } else if (is_complex_type(gemm_cfg.ty.B)) { - auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); - auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); - bb.add(add_into(c.s(0), a * b_bc_re)); - bb.add(add_into(c.s(1), a * b_bc_im)); - } else { - auto b_bc = sub_group_broadcast(b, n_offset); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(c, fma(std::move(a), std::move(b_bc), c)); - } else { - bb.add(add_into(std::move(c), std::move(a) * std::move(b_bc))); - } - } -} - -void generator::add_microkernel(block_builder &bb, expr M, expr N, var A, var B, var C, - expr C_offset, expr alpha, expr beta) { - int n_bs = 0; - dispatch_constant_dynamic( - N, [&](std::int64_t n) { n_bs = n; }, - [&](expr) { n_bs = static_cast(cols_in_register); }); - - auto my_row_blocks_in_register = row_blocks_in_register; - dispatch_constant_dynamic( - M, - [&](std::int64_t m) { my_row_blocks_in_register = 1 + (m - 1) / core_cfg.subgroup_size; }, - [&](expr) {}); - auto const Mb = my_row_blocks_in_register * core_cfg.subgroup_size; - - auto Ab = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), "Ab", A); - auto Bb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), "Bb", B); - - auto c_block = block_accessor_regular(c_acc, n_bs); - auto c_block_im = block_accessor_regular(c_acc_im, n_bs); - - for (int n = 0; n < n_bs; ++n) { - for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - bb.assign(c_block.get(m_block, n), constant(gemm_cfg.ty.C, 0.0)); - if (use_double_buffering()) { - bb.assign(c_block_im.get(m_block, n), constant(gemm_cfg.ty.C, 0.0)); - } - } - } - - auto const compute_c = [&](block_builder &bb, int Kb, ::clir::expr K0, ::clir::expr K1) { - auto kb = var("kb"); - bb.add( - for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), - kb < std::move(K1), add_into(kb, Kb)) - .body([&](block_builder &bb) { - auto const a_descr = - matrix_block_description{gemm_cfg.ty.A, Aspace, Mb, Kb, Ab, M, A_stride}; - auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; - auto const a = read_matrix_block(bb, a_descr, am, core_cfg, "a"); - - auto const b_descr = - matrix_block_description{gemm_cfg.ty.B, Bspace, n_bs, Kb, Bb, N, B_stride}; - auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; - auto const b = read_matrix_block(bb, b_descr, bn, core_cfg, "b"); - - const int nbb = 4; - for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - for (int nb = 0; nb < n_bs; nb += nbb) { - for (int k = 0; k < Kb; ++k) { - for (int n = 0; n < nbb; ++n) { - if (nb + n < n_bs) { - auto const n_block = (nb + n) / core_cfg.subgroup_size; - auto const n_offset = (nb + n) % core_cfg.subgroup_size; - /*auto my_a = a->get(m_block, k); - auto my_b = - sub_group_broadcast(b->get(n_block, k), n_offset); - auto my_c = c_block.get(m_block, nb + n); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && - gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(my_c, fma(std::move(my_a), std::move(my_b), - my_c)); - } else { - bb.add(add_into(std::move(my_c), - std::move(my_a) * std::move(my_b))); - }*/ - auto my_a = a->get(m_block, k); - auto my_b = b->get(n_block, k); - auto c_re = c_block.get(m_block, nb + n); - auto c_im = c_block_im.get(m_block, nb + n); - multiply_update(bb, std::move(my_a), std::move(my_b), - n_offset, std::move(c_re), std::move(c_im)); - } - } - } - } - } - }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - }; - dispatch_constant_dynamic( - MNK[2], - [&](std::int64_t K) { - static_assert(max_K_unrolling % 2 == 0, "max_K_unrolling must be a multiple of 2"); - auto Kb = max_K_unrolling; - while (K < Kb && Kb > 1) { - Kb /= 2; - } - auto KmultipleKb = (K / Kb) * Kb; - compute_c(bb, Kb, 0, KmultipleKb); - if (K - KmultipleKb > 0) { - compute_c(bb, 1, KmultipleKb, K); - } - }, - [&](expr K) { - auto KmultipleKb = bb.declare_assign(generic_uint(), "KmultipleKb", - (K / max_K_unrolling) * max_K_unrolling); - compute_c(bb, max_K_unrolling, 0, KmultipleKb); - bb.add(if_selection_builder(K - KmultipleKb > 0) - .then([&](block_builder &bb) { compute_c(bb, 1, KmultipleKb, K); }) - .get_product()); - }); - - auto Cb = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.C, Cspace)), "Cb", C + C_offset); - auto const c_descr = matrix_block_description{gemm_cfg.ty.C, Cspace, Mb, 1, Cb, M, C_stride}; - auto n = var("n"); - c_block.offset(n); - c_block_im.offset(n); - bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < N, ++n) - .body([&](block_builder &bb) { - if (use_double_buffering()) { - for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - auto c_im = c_block_im.get(m_block, 0); - auto c_ty = to_clir_ty(gemm_cfg.ty.C); - bb.add(add_into(c_block.get(m_block, 0), - init_vector(c_ty, {-c_im.s(1), c_im.s(0)}))); - } - } - for (int m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - auto c = c_block.get(m_block, 0); - bb.assign(c, multiply(gemm_cfg.ty.alpha, gemm_cfg.ty.C, alpha, c)); - } - write_matrix_block(bb, c_block, c_descr, gemm_cfg.atomic, gemm_cfg.ty.beta, beta, - core_cfg); - }) - .get_product()); -} - -void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta) { - auto sg_m = bb.declare_assign(generic_uint(), "sg_m", get_sub_group_id() % tiling.m_tiles()); - tile_loop_by_sgs( - bb, MNK[0], core_cfg.subgroup_size * row_blocks_in_register, tiling.m_tiles(), - std::move(sg_m), [&](block_builder &bb, expr block, bool, expr inner_trip_count) { - auto Astride_m = gemm_cfg.transA == transpose::T ? A_stride[1] : A_stride[0]; - auto Ab = bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), "Ab", - A + std::move(Astride_m) * block); - add_microkernel(bb, std::move(inner_trip_count), N, std::move(Ab), B, C, - C_stride[0] * std::move(block) + C_offset, alpha, beta); - }); -} - -void generator::add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta) { - m = bb.declare_assign(generic_uint(), "m", get_sub_group_local_id()); - c_acc = var("c"); - c_acc_im = var("c_im"); - - auto register_space = core_cfg.register_space; - if (use_double_buffering()) { - // We buffer the real / imag part separately, so we only have half the register space - // available for one of the buffers - register_space /= 2; - } - auto [max_rows, max_cols] = - max_register_block_gemm(size(gemm_cfg.ty.C), core_cfg.subgroup_size, register_space); - const auto max_row_blocks = max_rows / core_cfg.subgroup_size; - row_blocks_in_register = max_row_blocks; - cols_in_register = max_cols; - if (!is_dynamic_value(gemm_cfg.M)) { - auto const row_blocks_needed_to_cover_M = 1 + (gemm_cfg.M - 1) / core_cfg.subgroup_size; - if (row_blocks_needed_to_cover_M < max_row_blocks) { - row_blocks_in_register = row_blocks_needed_to_cover_M; - } else { - auto blocks = gemm_cfg.M / row_blocks_in_register; - auto sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - while (sg_blocks < tiling.m_tiles() && row_blocks_in_register >= 2) { - row_blocks_in_register /= 2; - blocks = gemm_cfg.M / row_blocks_in_register; - sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - } - } - } - if (!is_dynamic_value(gemm_cfg.N)) { - cols_in_register = - tile_loop_uniformly_max_block_size(gemm_cfg.N, cols_in_register, tiling.n_tiles()); - } - bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), - c_acc); - if (use_double_buffering()) { - bb.declare(array_of(to_clir_ty(gemm_cfg.ty.C), row_blocks_in_register * cols_in_register), - c_acc_im); - } - - auto sg_n = bb.declare_assign(generic_uint(), "sg_n", get_sub_group_id() / tiling.m_tiles()); - tile_loop_uniformly(bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), - [&](block_builder &bb, expr block, expr inner_trip_count) { - auto Bstride_n = - gemm_cfg.transB == transpose::T ? B_stride[0] : B_stride[1]; - auto Bb = - bb.declare_assign(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), - "Bb", B + std::move(Bstride_n) * block); - add_mloop(bb, std::move(inner_trip_count), A, std::move(Bb), C, - C_stride[1] * std::move(block), alpha, beta); - }); -} - -::clir::func generator::function(std::string_view name) { - auto A = var("A"); - auto B = var("B"); - auto C = var("C"); - - auto fb = ::clir::function_builder{std::string(name)}; - auto const scalar = [&](scalar_type ty, std::optional const &val, - std::string const &prefix) -> expr { - auto v = var{prefix}; - fb.argument(to_clir_ty(ty), v); - return val ? constant(ty, *val) : v; - }; - auto const shape = [&](std::int64_t shape, expr &target, std::string const &prefix) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target = is_dynamic_value(shape) ? expr{std::move(v)} : expr{shape}; - }; - auto const stride = [&](std::array const &stride, std::array &target, - std::string const &prefix) { - for (std::size_t i = 0; i < stride.size(); ++i) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target[i] = is_dynamic_value(stride[i]) ? expr{std::move(v)} : expr{stride[i]}; - } - }; - - shape(gemm_cfg.M, MNK[0], "M"); - shape(gemm_cfg.N, MNK[1], "N"); - shape(gemm_cfg.K, MNK[2], "K"); - expr alpha = scalar(gemm_cfg.ty.alpha, gemm_cfg.alpha, "alpha"); - fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.A, Aspace)), A); - stride(gemm_cfg.A_stride, A_stride, "A_stride"); - fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.B, Bspace)), B); - stride(gemm_cfg.B_stride, B_stride, "B_stride"); - expr beta = scalar(gemm_cfg.ty.beta, gemm_cfg.beta, "beta"); - fb.argument(pointer_to(to_clir_ty(gemm_cfg.ty.C, Cspace)), C); - stride(gemm_cfg.C_stride, C_stride, "C_stride"); - - fb.body([&](block_builder &bb) { add_function_body(bb, A, B, C, alpha, beta); }); - - auto f = fb.get_product(); - make_names_unique(f); - unsafe_simplify(f); - - return f; -} - -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, - clir::address_space As, clir::address_space Bs, clir::address_space Cs) { - return generator{gemm_cfg, tiling, core_cfg, As, Bs, Cs}.function(name); -} - -} // namespace tinytc diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp deleted file mode 100644 index 7a89d7d5..00000000 --- a/src/gemm_generator.hpp +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef GEMM_GENERATOR_20240314_HPP -#define GEMM_GENERATOR_20240314_HPP - -#include "device_info.hpp" -#include "tiling.hpp" -#include "tinytc/types.hpp" - -#include -#include - -#include -#include -#include -#include -#include - -namespace tinytc { - -//! Struct to handle mixed precision GEMMs -struct gemm_scalar_type { - //! alpha, A, B, beta, C all have the same type - gemm_scalar_type(scalar_type ty); - //! alpha's, A's, and B's type is different from beta's and C's type - gemm_scalar_type(scalar_type alphaAB, scalar_type betaC); - //! All operands potentially have a different type - gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, scalar_type beta, - scalar_type C); - scalar_type alpha, ///< @f$\alpha@f$ type - A, ///< A element type - B, ///< B element type - beta, ///< @f$\beta@f$ type - C; ///< C element type -}; - -/** - * @brief GEMM configuration struct - * - * The interface supports the operation - * - * C = alpha * opA(A) * opB(B) + beta * C, - * - * where - * - * opA/B(X) = transA/B == T ? X^T : X - * - * C is an MxN matrix, A is a MxK matrix, and B is a KxN matrix. - * - * The address of a matrix is calculated as following. Let X be element of {A,B,C}, then - * - * X(i,j) = X[i * X_stride[0] + j * X_stride[1]] - * - * If the atomic flag is set, C is updated atomically, either using - * - * * beta = 0: atomic store - * * beta = 1: atomic fetch add - * * general beta: atomic compare exchange - */ -struct gemm_configuration { - gemm_scalar_type ty; ///< scalar types of alpha, A, B, beta, C - transpose transA; ///< Transposition of A - transpose transB; ///< Transposition of B - std::int64_t M; ///< M, can be set to dynamic - std::int64_t N; ///< N, can be set to dynamic - std::int64_t K; ///< K, can be set to dynamic - std::array A_stride; ///< stride of A, entries can be set to dynamic - std::array B_stride; ///< stride of B, entries can be set to dynamic - std::array C_stride; ///< stride of C, entries can be set to dynamic - std::optional alpha; ///< fixed alpha if set; dynamic alpha if std::nullopt - std::optional beta; ///< fixed beta if set; dynamic beta if std::nullopt - bool atomic = false; ///< update C atomically - - std::string identifier( - std::string_view prefix = "gemm") const; ///< convert configuration to identification string -}; - -/** - * @brief Generate GEMM - * - * @param gemm_cfg configuration - * @param tiling Size of 2D subgroup grid - * @param core_cfg Core configuration - * @param name Routine prefix - * @param As Memory space of A (global or local) - * @param Bs Memory space of B (global or local) - * @param Cs Memory space of C (global or local) - * - * @return OpenCL-C AST - */ -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, - ::clir::address_space As = ::clir::address_space::global_t, - ::clir::address_space Bs = ::clir::address_space::global_t, - ::clir::address_space Cs = ::clir::address_space::global_t); - -} // namespace tinytc - -#endif // GEMM_GENERATOR_20240314_HPP diff --git a/src/pass/convert_to_opencl.cpp b/src/pass/convert_to_opencl.cpp deleted file mode 100644 index ee85b01a..00000000 --- a/src/pass/convert_to_opencl.cpp +++ /dev/null @@ -1,1631 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "pass/convert_to_opencl.hpp" -#include "codegen_tools.hpp" -#include "error.hpp" -#include "gemm_generator.hpp" -#include "scalar_type.hpp" -#include "support/casting.hpp" -#include "support/ilist.hpp" -#include "support/ilist_base.hpp" -#include "support/util.hpp" -#include "support/visit.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -std::string var_name(std::string name) { - if (name.empty() || !isalpha(name[0])) { - // we use clir unique names to clean up possible duplicates - name = "x" + name; - } - return name; -} - -dope_vector dope_vector::from_value(value_node const &v, decl_fun_t declare) { - memref_data_type const *m = nullptr; - auto dt = clir::data_type{}; - visit(overloaded{[&](memref_data_type const &mr) { - m = &mr; - dt = to_clir_ty(scalar_type::index); - }, - [&](group_data_type const &g) { - m = dyn_cast(g.ty()); - dt = clir::pointer_to( - to_clir_ty(scalar_type::index, clir::address_space::global_t)); - }, - [](auto const &) {}}, - *v.ty()); - if (m == nullptr) { - throw compilation_error( - v.loc(), status::internal_compiler_error, - "dope_vector::from_value must only be called for memref or group type"); - } - auto dv = dope_vector::from_memref_type(std::string(v.name()), *m, std::move(dt), declare); - visit(overloaded{[&](memref_data_type const &) {}, - [&](group_data_type const &g) { - if (is_dynamic_value(g.offset())) { - auto var = clir::var( - (std::ostringstream{} << var_name(v.name()) << "_offset").str()); - declare(to_clir_ty(scalar_type::index), var, type::offset, 0); - dv.offset(std::move(var)); - } else { - dv.offset(g.offset()); - } - }, - [](auto const &) {}}, - *v.ty()); - return dv; -} - -dope_vector dope_vector::from_memref_type(std::string const &prefix, memref_data_type const &m, - clir::data_type dt, decl_fun_t declare) { - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.resize(m.dim()); - stride.resize(m.dim()); - for (std::int64_t j = 0; j < m.dim(); ++j) { - if (is_dynamic_value(m.shape(j))) { - auto oss = std::ostringstream{}; - oss << var_name(prefix) << "_shape" << j; - auto var = clir::var(oss.str()); - declare(dt, var, type::shape, j); - shape[j] = std::move(var); - } else { - shape[j] = m.shape(j); - } - if (is_dynamic_value(m.stride(j))) { - auto oss = std::ostringstream{}; - oss << var_name(prefix) << "_stride" << j; - auto var = clir::var(oss.str()); - declare(dt, var, type::stride, j); - stride[j] = std::move(var); - } else { - stride[j] = m.stride(j); - } - } - return dope_vector(std::move(shape), std::move(stride)); -} - -convert_to_opencl_pass::convert_to_opencl_pass(::tinytc_core_info const *info) - : info_(std::move(info)) { - if (info_ == nullptr) { - throw std::invalid_argument("info must not be nullptr"); - } - declared_vars_.push_back({}); -} - -auto convert_to_opencl_pass::get_dope_vector(value_node const &v) -> dope_vector & { - auto dv = dope_vector_.find(std::bit_cast(&v)); - if (dv == dope_vector_.end()) { - throw compilation_error(v.loc(), status::internal_compiler_error, - "Dope vector for value is missing"); - } - return dv->second; -} - -void convert_to_opencl_pass::set_dope_vector(value_node const &v, dope_vector dv) { - uintptr_t u = std::bit_cast(&v); - dope_vector_[u] = std::move(dv); -} - -clir::var convert_to_opencl_pass::declare(value_node const &v) { - uintptr_t u = std::bit_cast(&v); - for (auto it = declared_vars_.rbegin(); it != declared_vars_.rend(); ++it) { - if (it->find(u) != it->end()) { - throw compilation_error(v.loc(), status::internal_compiler_error, - "Variable already declared"); - } - } - - auto name = var_name(std::string(v.name())); - declared_vars_.back()[u] = clir::var(std::move(name)); - return declared_vars_.back()[u]; -} - -auto convert_to_opencl_pass::get_coopmatrix_type(value_node const &v) const - -> const coopmatrix_data_type * { - auto t = dyn_cast(v.ty()); - if (t == nullptr) { - throw compilation_error(v.loc(), status::ir_expected_coopmatrix); - } - return t; -} - -auto convert_to_opencl_pass::get_memref_type(value_node const &v) const - -> const memref_data_type * { - auto t = dyn_cast(v.ty()); - if (t == nullptr) { - throw compilation_error(v.loc(), status::ir_expected_memref); - } - return t; -} - -auto convert_to_opencl_pass::get_scalar_type(value_node const &v) -> scalar_type { - auto st = dyn_cast(v.ty()); - if (!st) { - throw compilation_error(v.loc(), status::ir_expected_scalar); - } - return st->ty(); -}; - -/* Data type nodes */ -clir::data_type convert_to_opencl_pass::operator()(void_data_type const &) { - return clir::builtin_type::void_t; -} -clir::data_type convert_to_opencl_pass::operator()(boolean_data_type const &) { - return clir::builtin_type::bool_t; -} -clir::data_type convert_to_opencl_pass::operator()(coopmatrix_data_type const &ct) { - return array_of(to_clir_ty(ct.component_ty()), ct.length(core_cfg_.subgroup_size)); -} -clir::data_type convert_to_opencl_pass::operator()(group_data_type const &g) { - auto ptr_ty = visit(*this, *g.ty()); - ptr_ty = clir::visit(overloaded{[](clir::internal::pointer &t) { - return clir::pointer_to(clir::pointer_to( - t.ty(), clir::address_space::global_t)); - }, - [](auto &) { return clir::data_type{}; }}, - *ptr_ty); - if (!ptr_ty) { - throw compilation_error(location{}, status::internal_compiler_error, - "Could not determine OpenCL type of group type"); - } - return ptr_ty; -} -clir::data_type convert_to_opencl_pass::operator()(memref_data_type const &d) { - return clir::pointer_to(to_clir_ty(d.element_ty(), to_clir_address_space(d.addrspace()))); -} -clir::data_type convert_to_opencl_pass::operator()(scalar_data_type const &s) { - return to_clir_ty(s.ty()); -} - -/* Value nodes */ -auto convert_to_opencl_pass::val(value_node const &v) -> clir::expr { - uintptr_t u = std::bit_cast(&v); - for (auto it = declared_vars_.rbegin(); it != declared_vars_.rend(); ++it) { - if (auto j = it->find(u); j != it->end()) { - return j->second; - } - } - - throw compilation_error(v.loc(), status::internal_compiler_error, - "Undeclared variable: " + std::string(v.name())); -} - -/* Stmt nodes */ -std::vector convert_to_opencl_pass::operator()(alloca_inst const &a) { - if (a.stack_ptr() < 0) { - throw compilation_error(a.loc(), status::internal_compiler_error, - "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); - } - auto result_var = declare(*a.result()); - auto t = dyn_cast(a.result()->ty()); - if (t == nullptr) { - throw compilation_error(a.loc(), status::ir_expected_memref); - } - auto ptr_ty = operator()(*t); - auto result = declaration_assignment(ptr_ty, std::move(result_var), - clir::cast(ptr_ty, stack_ + a.stack_ptr())); - stack_high_water_mark_ = std::max(stack_high_water_mark_, - static_cast(a.stack_ptr()) + t->size_in_bytes()); - - // no declarations are necessary as alloca only accepts fixed-size memrefs - set_dope_vector(a.result(0), - dope_vector::from_value(*a.result(), [](clir::data_type, clir::var, - dope_vector::type, std::int64_t) {})); - return {std::move(result)}; -} - -std::vector convert_to_opencl_pass::operator()(axpby_inst const &inst) { - auto at = get_memref_type(inst.A()); - auto bt = get_memref_type(inst.B()); - auto alpha_ty = get_scalar_type(inst.alpha()); - auto beta_ty = get_scalar_type(inst.beta()); - auto &adv = get_dope_vector(inst.A()); - auto &bdv = get_dope_vector(inst.B()); - - auto pA = inst.tA() == transpose::T && at->dim() == 2 ? 1 : 0; - - auto alpha = val(inst.alpha()); - auto beta = val(inst.beta()); - auto const inner_loop = [&](clir::block_builder &bb, clir::expr Ab, clir::expr Bb, - clir::expr trip_count, std::size_t num_tiles, clir::var sg_id) { - auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - tile_loop_by_sgs( - bb, std::move(trip_count), core_cfg_.subgroup_size, num_tiles, std::move(sg_id), - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto a = Ab[(block + m) * adv.stride(pA)]; - auto b = bb.declare_assign((*this)(*bt), "b", Bb + (block + m) * bdv.stride(0)); - const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, std::move(a)); - store_helper(bb, inst.atomic(), b, bt->element_ty(), - to_clir_address_space(bt->addrspace()), std::move(a_scaled), - beta_ty, beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - }; - - auto A = val(inst.A()); - auto B = val(inst.B()); - if (bt->dim() == 0) { - auto bb = clir::block_builder{}; - const auto a_scaled = multiply(alpha_ty, at->element_ty(), alpha, A[0]); - store_helper(bb, inst.atomic(), B, bt->element_ty(), to_clir_address_space(bt->addrspace()), - std::move(a_scaled), beta_ty, std::move(beta)); - return {bb.get_product()}; - } - - if (bt->dim() == 1) { - auto bb = clir::block_builder{}; - auto sg_m = bb.declare_assign(clir::generic_uint(), "sg_m", clir::get_sub_group_id()); - inner_loop(bb, std::move(A), std::move(B), bdv.shape(0), - tiling_.m_tiles() * tiling_.n_tiles(), std::move(sg_m)); - return {bb.get_product()}; - } else if (bt->dim() == 2) { - auto bb = clir::block_builder{}; - auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", - clir::get_sub_group_id() / tiling_.m_tiles()); - auto sg_m = bb.declare_assign(clir::generic_uint(), "sg_m", - clir::get_sub_group_id() % tiling_.m_tiles()); - tile_loop_uniformly( - bb, bdv.shape(1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), - [&](clir::block_builder &bb, clir::expr block, clir::expr trip_count) { - auto n = clir::var("n"); - bb.add( - clir::for_loop_builder(clir::declaration_assignment(clir::generic_int(), n, 0), - n < std::move(trip_count), ++n) - .body([&](clir::block_builder &bb) { - auto Ab = bb.declare_assign(this->operator()(*at), "Ab", - A + (block + n) * adv.stride(1 - pA)); - auto Bb = bb.declare_assign(this->operator()(*bt), "Bb", - B + (block + n) * bdv.stride(1)); - inner_loop(bb, Ab, Bb, bdv.shape(0), tiling_.m_tiles(), sg_m); - }) - .get_product()); - }); - return {bb.get_product()}; - } - return {}; -} - -std::vector convert_to_opencl_pass::operator()(barrier_inst const &b) { - clir::expr fence = 0; - if (b.has_fence(address_space::global)) { - fence = fence | clir::cl_mem_fence_flags::CLK_GLOBAL_MEM_FENCE; - } - if (b.has_fence(address_space::local)) { - fence = fence | clir::cl_mem_fence_flags::CLK_LOCAL_MEM_FENCE; - } - return {clir::expression_statement( - clir::call_builtin(clir::builtin_function::barrier, {std::move(fence)}))}; -} - -std::vector convert_to_opencl_pass::operator()(arith_inst const &a) { - auto const make_boolean = [](arithmetic op, clir::expr a, clir::expr b) -> clir::expr { - switch (op) { - case arithmetic::and_: - return std::move(a) && std::move(b); - case arithmetic::or_: - return std::move(a) || std::move(b); - case arithmetic::xor_: - return std::move(a) != std::move(b); - default: - return nullptr; - } - }; - auto const make = [](arithmetic op, clir::expr a, clir::expr b, scalar_type sty) -> clir::expr { - switch (op) { - case arithmetic::add: - return std::move(a) + std::move(b); - case arithmetic::sub: - return std::move(a) - std::move(b); - case arithmetic::mul: - return multiply(sty, sty, std::move(a), std::move(b)); - case arithmetic::div: - return divide(sty, sty, std::move(a), std::move(b)); - case arithmetic::rem: - if (is_floating_type(sty)) { - return clir::fmod(std::move(a), std::move(b)); - } - return std::move(a) % std::move(b); - case arithmetic::shl: - return std::move(a) << std::move(b); - case arithmetic::shr: - return std::move(a) >> std::move(b); - case arithmetic::and_: - return std::move(a) & std::move(b); - case arithmetic::or_: - return std::move(a) | std::move(b); - case arithmetic::xor_: - return std::move(a) ^ std::move(b); - } - return {}; - }; - - auto lhs = declare(a.result(0)); - auto lhs_ty = visit(*this, *a.result()->ty()); - auto av = val(a.a()); - auto bv = val(a.b()); - if (isa(*a.result(0).ty())) { - auto op = make_boolean(a.operation(), av, bv); - if (!bool(op)) { - throw compilation_error(a.loc(), status::ir_boolean_unsupported); - } - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; - } else if (auto st = dyn_cast(a.result(0).ty()); st) { - auto op = make(a.operation(), av, bv, st->ty()); - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; - } else if (auto ct = dyn_cast(a.result(0).ty()); ct) { - auto clinst = std::vector{}; - auto const len = ct->length(core_cfg_.subgroup_size); - clinst.reserve(len + 1); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - const auto sty = ct->component_ty(); - for (std::int64_t i = 0; i < len; ++i) { - auto op = make(a.operation(), av[i], bv[i], sty); - clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); - } - return clinst; - } - throw compilation_error(a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); -} - -std::vector convert_to_opencl_pass::operator()(arith_unary_inst const &a) { - auto const make = [](arithmetic_unary op, clir::expr a, scalar_type sty) -> clir::expr { - switch (op) { - case arithmetic_unary::abs: - if (is_complex_type(sty)) { - return clir::call_builtin(clir::builtin_function::sqrt, - {a.s(0) * a.s(0) + a.s(1) * a.s(1)}); - } - if (is_floating_type(sty)) { - return clir::call_builtin(clir::builtin_function::fabs, {std::move(a)}); - } - return clir::call_builtin(clir::builtin_function::abs, {std::move(a)}); - case arithmetic_unary::neg: - return -std::move(a); - case arithmetic_unary::not_: - return ~std::move(a); - case arithmetic_unary::conj: - return clir::init_vector(to_clir_ty(sty), {a.s(0), -a.s(1)}); - case arithmetic_unary::im: - return std::move(a).s(1); - case arithmetic_unary::re: - return std::move(a).s(0); - } - return {}; - }; - - auto lhs = declare(a.result(0)); - auto lhs_ty = visit(*this, *a.result()->ty()); - auto av = val(a.a()); - if (isa(*a.result(0).ty())) { - if (a.operation() != arithmetic_unary::not_) { - throw compilation_error(a.loc(), status::ir_boolean_unsupported); - } - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), !std::move(av))}; - } else if (auto st = dyn_cast(a.a().ty()); st) { - auto op = make(a.operation(), av, st->ty()); - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; - } else if (auto ct = dyn_cast(a.a().ty()); ct) { - auto clinst = std::vector{}; - auto const len = ct->length(core_cfg_.subgroup_size); - clinst.reserve(len + 1); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - const auto sty = ct->component_ty(); - for (std::int64_t i = 0; i < len; ++i) { - auto op = make(a.operation(), av[i], sty); - clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); - } - return clinst; - } - throw compilation_error(a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); -} - -std::vector convert_to_opencl_pass::operator()(builtin_inst const &in) { - auto lhs = declare(*in.result()); - auto rhs = [&]() -> clir::expr { - switch (in.builtin_type()) { - case builtin::group_id: - return clir::get_global_id(2); - case builtin::group_size: - return clir::get_global_size(2); - case builtin::num_subgroups: - return clir::get_num_sub_groups(); - case builtin::subgroup_size: - return clir::get_sub_group_size(); - case builtin::subgroup_id: - return clir::get_sub_group_id(); - case builtin::subgroup_local_id: - return clir::get_sub_group_local_id(); - } - throw compilation_error(in.loc(), status::internal_compiler_error); - }; - return {declaration_assignment(visit(*this, *in.result(0).ty()), std::move(lhs), rhs())}; -} - -std::vector convert_to_opencl_pass::operator()(cast_inst const &c) { - auto const make = [](clir::expr a, scalar_type aty, scalar_type rty) -> clir::expr { - if (is_complex_type(aty) && is_complex_type(rty)) { - switch (rty) { - case scalar_type::c32: - return clir::call("convert_float2", {std::move(a)}); - case scalar_type::c64: - return clir::call("convert_double2", {std::move(a)}); - default: - throw status::internal_compiler_error; - } - } else if (is_complex_type(rty)) { - return clir::init_vector(to_clir_ty(rty), {std::move(a), 0}); - } - return cast(to_clir_ty(rty), std::move(a)); - }; - - auto lhs = declare(c.result(0)); - auto lhs_ty = visit(*this, *c.result(0).ty()); - auto av = val(c.a()); - - if (auto rt = dyn_cast(c.result(0).ty()); rt) { - auto aty = get_scalar_type(c.a()); - auto op = make(av, aty, rt->ty()); - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), std::move(op))}; - } else if (auto ct = dyn_cast(c.result(0).ty()); ct) { - const auto rty = ct->component_ty(); - auto at = dyn_cast(c.a().ty()); - if (!at) { - throw compilation_error(c.loc(), status::ir_expected_coopmatrix); - } - const auto aty = at->component_ty(); - - auto clinst = std::vector{}; - auto const len = ct->length(core_cfg_.subgroup_size); - clinst.reserve(len + 1); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - for (std::int64_t i = 0; i < len; ++i) { - auto op = make(av[i], aty, rty); - clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); - } - return clinst; - } - throw compilation_error(c.loc(), status::ir_expected_coopmatrix_or_scalar); -} - -std::vector convert_to_opencl_pass::operator()(compare_inst const &c) { - auto const make = [](cmp_condition cond, clir::expr a, clir::expr b) -> clir::expr { - switch (cond) { - case cmp_condition::eq: - return std::move(a) == std::move(b); - case cmp_condition::ne: - return std::move(a) != std::move(b); - case cmp_condition::gt: - return std::move(a) > std::move(b); - case cmp_condition::ge: - return std::move(a) >= std::move(b); - case cmp_condition::lt: - return std::move(a) < std::move(b); - case cmp_condition::le: - return std::move(a) <= std::move(b); - } - return {}; - }; - auto v = declare(*c.result()); - return {declaration_assignment(visit(*this, *c.result()->ty()), std::move(v), - make(c.cond(), val(c.a()), val(c.b())))}; -} - -std::vector convert_to_opencl_pass::operator()(constant_inst const &c) { - auto const get_rhs = [&c](scalar_type ty, short ty_bits) { - return std::visit(overloaded{ - [&](bool) -> clir::expr { - throw compilation_error(c.loc(), status::internal_compiler_error); - }, - [&](std::int64_t i) { return clir::expr(i, ty_bits); }, - [&](double d) { return clir::expr(d, ty_bits); }, - [&](std::complex d) { - return init_vector(to_clir_ty(ty), - {clir::expr(d.real(), ty_bits), - clir::expr(d.imag(), ty_bits)}); - }, - }, - c.value()); - }; - auto lhs = declare(c.result(0)); - auto lhs_ty = visit(*this, *c.result()->ty()); - if (isa(*c.result(0).ty())) { - if (!std::holds_alternative(c.value())) { - throw compilation_error(c.loc(), status::internal_compiler_error); - } - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), - clir::expr(std::int8_t{std::get(c.value())}))}; - } else if (auto st = dyn_cast(c.result(0).ty()); st) { - auto ty_bits = static_cast(size(st->ty()) * 8); - return { - declaration_assignment(std::move(lhs_ty), std::move(lhs), get_rhs(st->ty(), ty_bits))}; - } else if (auto ct = dyn_cast(c.result(0).ty()); ct) { - auto ty_bits = static_cast(size(ct->component_ty()) * 8); - auto rhs = get_rhs(ct->component_ty(), ty_bits); - auto clinst = std::vector{}; - auto const len = ct->length(core_cfg_.subgroup_size); - clinst.reserve(len + 1); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - for (std::int64_t i = 0; i < len; ++i) { - clinst.emplace_back(expression_statement(assignment(lhs[i], rhs))); - } - return clinst; - } - throw compilation_error(c.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); -} - -std::vector convert_to_opencl_pass::operator()(cooperative_matrix_load_inst const &c) { - auto lhs = declare(c.result(0)); - auto lhs_ty = visit(*this, *c.result(0).ty()); - auto ot = get_memref_type(c.operand()); - auto rt = get_coopmatrix_type(c.result(0)); - auto &odv = get_dope_vector(c.operand()); - - const int rmode = rt->distributed_mode(); - const int omode = c.t() == transpose::T ? 1 - rmode : rmode; - const bool check_m = c.checked() == checked_flag::both || - (rmode == 0 && c.checked() == checked_flag::rows) || - (rmode == 1 && c.checked() == checked_flag::cols); - const bool check_k = c.checked() == checked_flag::both || - (rmode == 1 && c.checked() == checked_flag::rows) || - (rmode == 0 && c.checked() == checked_flag::cols); - const bool enable_sub_group_reads = - core_cfg_.block_read_write_supported && c.t() == transpose::N && ot->stride(omode) == 1; - - auto clinst = std::vector{}; - auto const len = rt->length(core_cfg_.subgroup_size); - clinst.reserve(len + 5); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - - clir::expr pv[] = {val(c.pos0()), val(c.pos1())}; - auto pointer = clir::var{}; - clinst.emplace_back( - declaration_assignment(visit(*this, *c.operand().ty()), pointer, - val(c.operand()) + pv[0] * odv.stride(0) + pv[1] * odv.stride(1))); - clir::var rem[2] = {}; - if (check_m || check_k) { - clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[0], odv.shape(0) - pv[0])); - clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[1], odv.shape(1) - pv[1])); - } - - const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); - for (std::int64_t block = 0; block < num_blocks; ++block) { - auto row_in_bounds = clir::var{}; - if (check_m) { - auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; - clinst.emplace_back(declaration_assignment(clir::builtin_type::bool_t, row_in_bounds, - m >= -pv[omode] && m < rem[omode])); - } - for (std::int64_t k = 0; k < rt->shape(1 - rmode); ++k) { - auto col_cond = [&] { return k >= -pv[1 - omode] && k < rem[1 - omode]; }; - - auto const store = [&](clir::expr rhs) -> clir::stmt { - return expression_statement( - assignment(lhs[k + block * rt->shape(1 - rmode)], std::move(rhs))); - }; - auto const remainder = rt->shape(rmode) - core_cfg_.subgroup_size * block; - const bool needs_mask = remainder < core_cfg_.subgroup_size; - if (enable_sub_group_reads && !needs_mask && !check_m) { - auto rhs = sub_group_block_read_helper( - pointer + block * core_cfg_.subgroup_size + k * odv.stride(1), ot->element_ty(), - to_clir_address_space(ot->addrspace())); - if (check_k) { - rhs = ternary_conditional(col_cond(), std::move(rhs), 0); - } - clinst.emplace_back(store(std::move(rhs))); - } else { - auto rhs = pointer[odv.stride(omode) * (clir::get_sub_group_local_id() + - block * core_cfg_.subgroup_size) + - k * odv.stride(1 - omode)]; - clir::expr cond = {}; - if (check_m) { - cond = row_in_bounds; - } - if (check_k) { - cond = cond ? cond && col_cond() : col_cond(); - } - if (needs_mask) { - auto mask_cond = clir::get_sub_group_local_id() < remainder; - cond = cond ? cond && mask_cond : mask_cond; - } - - if (cond) { - rhs = ternary_conditional(cond, std::move(rhs), 0); - } - clinst.emplace_back(store(std::move(rhs))); - } - } - } - return clinst; -} -std::vector -convert_to_opencl_pass::operator()(cooperative_matrix_mul_add_inst const &c) { - auto lhs = declare(c.result(0)); - auto lhs_ty = visit(*this, *c.result(0).ty()); - auto rt = get_coopmatrix_type(c.result(0)); - auto at = get_coopmatrix_type(c.a()); - auto bt = get_coopmatrix_type(c.b()); - auto ct = get_coopmatrix_type(c.c()); - auto av = val(c.a()); - auto bv = val(c.b()); - auto cv = val(c.c()); - - const auto a_ty = at->component_ty(); - const auto b_ty = bt->component_ty(); - const auto c_ty = ct->component_ty(); - const auto r_ty = rt->component_ty(); - const bool use_double_buffering = is_complex_type(a_ty) && is_complex_type(b_ty); - - const std::int64_t M = rt->rows(), N = rt->cols(), K = at->cols(); - auto clinst = std::vector{}; - clinst.reserve(M * N + 2); - clinst.emplace_back(declaration(lhs_ty, lhs)); - - auto c_acc_im = clir::var{}; - if (use_double_buffering) { - clinst.emplace_back(declaration(lhs_ty, c_acc_im)); - } - - const std::int64_t num_blocks = rt->num_blocks(core_cfg_.subgroup_size); - const std::int64_t nbb = 4; - for (std::int64_t m_block = 0; m_block < num_blocks; ++m_block) { - for (std::int64_t nb = 0; nb < N; nb += nbb) { - for (std::int64_t k = 0; k < K; ++k) { - for (std::int64_t n = 0; n < nbb; ++n) { - if (nb + n < N) { - auto const n_block = (nb + n) / core_cfg_.subgroup_size; - auto const n_offset = (nb + n) % core_cfg_.subgroup_size; - - auto a = av[k + m_block * K]; - auto b = bv[k + n_block * K]; - auto c_next = lhs[nb + n + m_block * N]; - auto c = [&] { - if (k == 0) { - auto c = cv[nb + n + m_block * N]; - if (c_ty != r_ty) { - if (is_complex_type(r_ty) && !is_complex_type(c_ty)) { - c = clir::init_vector(to_clir_ty(r_ty), {c, 0}); - } else if (r_ty != c_ty) { - return clir::cast(to_clir_ty(r_ty), c); - } - } - return c; - } - return c_next; - }(); - const auto c_next_im = [&] { return c_acc_im[nb + n + m_block * N]; }; - const auto c_im = [&] { - if (k == 0) { - return init_vector(to_clir_ty(r_ty), {0, 0}); - } - return c_next_im(); - }; - - auto const add = [&](auto a_ty, auto b_ty, auto c_ty, auto a, auto b, - auto c, auto c_next) { - if (a_ty == b_ty && b_ty == c_ty) { - clinst.emplace_back(expression_statement(assignment( - std::move(c_next), - clir::fma(std::move(a), std::move(b), std::move(c))))); - } else { - clinst.emplace_back(expression_statement( - assignment(std::move(c_next), - std::move(c) + std::move(a) * std::move(b)))); - } - }; - - if (is_complex_type(a_ty)) { - if (is_complex_type(b_ty)) { - auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); - auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); - add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_re), c, - c_next); - add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_im), c_im(), - c_next_im()); - } else { - auto b_bc = sub_group_broadcast(b, n_offset); - add(a_ty, b_ty, r_ty, std::move(a), std::move(b_bc), c, c_next); - } - } else if (is_complex_type(b_ty)) { - auto b_bc_re = sub_group_broadcast(b.s(0), n_offset); - auto b_bc_im = sub_group_broadcast(b.s(1), n_offset); - add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_re), c.s(0), - c_next.s(0)); - add(a_ty, element_type(b_ty), r_ty, a, std::move(b_bc_im), c.s(1), - c_next.s(1)); - } else { - auto b_bc = sub_group_broadcast(std::move(b), n_offset); - add(a_ty, b_ty, r_ty, std::move(a), std::move(b_bc), std::move(c), - std::move(c_next)); - } - } - } - } - } - } - if (use_double_buffering) { - for (std::int64_t i = 0; i < rt->length(core_cfg_.subgroup_size); ++i) { - clinst.emplace_back(expression_statement( - add_into(lhs[i], clir::init_vector(to_clir_ty(r_ty), - {-c_acc_im[i].s(1), c_acc_im[i].s(0)})))); - } - } - return clinst; -} -std::vector convert_to_opencl_pass::operator()(cooperative_matrix_scale_inst const &c) { - auto lhs = declare(c.result(0)); - auto lhs_ty = visit(*this, *c.result()->ty()); - auto av = val(c.a()); - auto bv = val(c.b()); - auto at = get_scalar_type(c.a()); - auto bt = get_coopmatrix_type(c.b()); - - auto clinst = std::vector{}; - auto const len = bt->length(core_cfg_.subgroup_size); - clinst.reserve(len + 1); - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - for (std::int64_t i = 0; i < len; ++i) { - auto op = multiply(at, bt->component_ty(), av, bv[i]); - clinst.emplace_back(expression_statement(assignment(lhs[i], std::move(op)))); - } - return clinst; -} -std::vector convert_to_opencl_pass::operator()(cooperative_matrix_store_inst const &c) { - auto ot = get_memref_type(c.operand()); - auto vt = get_coopmatrix_type(c.val()); - auto &odv = get_dope_vector(c.operand()); - auto valv = val(c.val()); - - const int vmode = vt->distributed_mode(); - const int omode = vmode; - const bool check_m = c.checked() == checked_flag::both || - (vmode == 0 && c.checked() == checked_flag::rows) || - (vmode == 1 && c.checked() == checked_flag::cols); - const bool check_k = c.checked() == checked_flag::both || - (vmode == 1 && c.checked() == checked_flag::rows) || - (vmode == 0 && c.checked() == checked_flag::cols); - - auto clinst = std::vector{}; - auto const len = vt->length(core_cfg_.subgroup_size); - clinst.reserve(len + 4); - - clir::expr pv[] = {val(c.pos0()), val(c.pos1())}; - auto base_pointer = clir::var{}; - clinst.emplace_back( - declaration_assignment(visit(*this, *c.operand().ty()), base_pointer, - val(c.operand()) + pv[0] * odv.stride(0) + pv[1] * odv.stride(1))); - clir::var rem[2] = {}; - if (check_m || check_k) { - clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[0], odv.shape(0) - pv[0])); - clinst.emplace_back( - declaration_assignment(to_clir_ty(scalar_type::index), rem[1], odv.shape(1) - pv[1])); - } - - const std::int64_t num_blocks = vt->num_blocks(core_cfg_.subgroup_size); - auto const num_k = vt->shape(1 - vmode); - auto store_block = std::vector{}; - store_block.reserve(num_k); - for (std::int64_t block = 0; block < num_blocks; ++block) { - store_block.clear(); - for (std::int64_t k = 0; k < num_k; ++k) { - auto const remainder = vt->shape(vmode) - core_cfg_.subgroup_size * block; - const bool needs_mask = remainder < core_cfg_.subgroup_size; - - auto pointer = base_pointer + - odv.stride(omode) * - (clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size) + - k * odv.stride(1 - omode); - auto rhs = valv[k + block * vt->shape(1 - vmode)]; - clir::expr cond = {}; - if (check_k) { - cond = k >= -pv[1 - omode] && k < rem[1 - omode]; - } - if (needs_mask) { - auto mask_cond = clir::get_sub_group_local_id() < remainder; - cond = cond ? cond && mask_cond : mask_cond; - } - - if (cond) { - store_block.emplace_back( - clir::if_selection_builder(cond) - .then([&](clir::block_builder &bb) { - for (auto &s : atomic_store_helper_new(c.flag(), ot, std::move(pointer), - std::move(rhs))) { - bb.add(std::move(s)); - } - }) - .get_product()); - } else { - for (auto &s : - atomic_store_helper_new(c.flag(), ot, std::move(pointer), std::move(rhs))) { - store_block.emplace_back(std::move(s)); - } - } - } - - if (check_m) { - auto m = clir::get_sub_group_local_id() + block * core_cfg_.subgroup_size; - clinst.emplace_back(clir::if_selection_builder(m >= -pv[omode] && m < rem[omode]) - .then([&](clir::block_builder &bb) { - for (auto &i : store_block) { - bb.add(i); - } - }) - .get_product()); - } else { - for (auto &i : store_block) { - clinst.emplace_back(i); - } - } - } - - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(expand_inst const &e) { - auto result_var = declare(*e.result()); - auto m = get_memref_type(e.operand()); - auto &dv = get_dope_vector(e.operand()); - auto static_shape = e.static_expand_shape(); - auto dyn_shape = e.expand_shape(); - - auto rhs = val(e.operand()); - auto clinst = std::vector{}; - clinst.emplace_back( - clir::declaration_assignment(this->operator()(*m), std::move(result_var), std::move(rhs))); - - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim() + static_shape.size() - 1); - stride.reserve(m->dim() + static_shape.size() - 1); - std::int64_t i = 0; - for (; i < e.expanded_mode(); ++i) { - shape.push_back(dv.shape(i)); - stride.push_back(dv.stride(i)); - } - - auto eshape_cl = std::vector{}; - eshape_cl.reserve(static_shape.size()); - int j = 0; - for (auto &s : static_shape) { - if (is_dynamic_value(s)) { - eshape_cl.emplace_back(val(dyn_shape[j++])); - } else { - eshape_cl.emplace_back(clir::expr(s, static_cast(size(scalar_type::index) * 8))); - } - } - - stride.push_back(m->stride(e.expanded_mode())); - shape.push_back(eshape_cl[0]); - for (std::size_t j = 1; j < eshape_cl.size(); ++j) { - stride.push_back(stride.back() * shape.back()); - shape.push_back(eshape_cl[j]); - } - for (i = e.expanded_mode() + 1; i < m->dim(); ++i) { - shape.push_back(dv.shape(i)); - stride.push_back(dv.stride(i)); - } - - set_dope_vector(e.result(0), - dope_vector::from_value(*e.result(), [&](clir::data_type a, clir::var b, - dope_vector::type t, std::int64_t j) { - auto init = t == dope_vector::type::stride ? stride[j] : shape[j]; - clinst.emplace_back(clir::declaration_assignment(std::move(a), std::move(b), - std::move(init))); - })); - return clinst; -} -std::vector convert_to_opencl_pass::operator()(fuse_inst const &f) { - auto result_var = declare(*f.result()); - auto m = get_memref_type(f.operand()); - auto &dv = get_dope_vector(f.operand()); - - auto rhs = val(f.operand()); - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - std::int64_t i = 0; - for (; i < f.from(); ++i) { - shape.push_back(dv.shape(i)); - stride.push_back(dv.stride(i)); - } - clir::expr prod = dv.shape(i++); - for (; i <= f.to(); ++i) { - prod = prod * dv.shape(i); - } - shape.push_back(prod); - stride.push_back(dv.stride(f.from())); - for (i = f.to() + 1; i < m->dim(); ++i) { - shape.push_back(dv.shape(i)); - stride.push_back(dv.stride(i)); - } - - auto clinst = std::vector{}; - clinst.emplace_back( - clir::declaration_assignment(this->operator()(*m), std::move(result_var), std::move(rhs))); - - set_dope_vector(*f.result(), - dope_vector::from_value(*f.result(), [&](clir::data_type a, clir::var b, - dope_vector::type t, std::int64_t j) { - auto init = t == dope_vector::type::stride ? stride[j] : shape[j]; - clinst.emplace_back(clir::declaration_assignment(std::move(a), std::move(b), - std::move(init))); - })); - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(load_inst const &e) { - auto rhs = val(e.operand()); - - auto clinst = std::vector{}; - - visit(overloaded{[&](group_data_type const &) { - if (e.index_list().size() != 1) { - throw compilation_error(e.loc(), status::ir_invalid_number_of_indices); - } - - auto idx = val(e.index_list().front()); - rhs = rhs + idx; - - auto &dv = get_dope_vector(e.operand()); - rhs = clir::dereference(std::move(rhs)) + dv.offset(); - - set_dope_vector( - *e.result(), - dope_vector::from_value( - *e.result(), [&](clir::data_type a, clir::var b, - dope_vector::type t, std::int64_t j) { - auto init = t == dope_vector::type::stride ? dv.stride(j) - : dv.shape(j); - clinst.emplace_back(clir::declaration_assignment( - std::move(a), std::move(b), std::move(init)[idx])); - })); - }, - [&](memref_data_type const &m) { - if (static_cast(e.index_list().size()) != m.dim()) { - throw compilation_error(e.loc(), status::ir_invalid_number_of_indices); - } - auto &dv = get_dope_vector(e.operand()); - for (std::int64_t i = 0; i < m.dim(); ++i) { - rhs = rhs + val(e.index_list()[i]) * dv.stride(i); - } - rhs = clir::dereference(std::move(rhs)); - }, - [&e](auto const &) { - throw compilation_error(e.loc(), status::ir_expected_memref_or_group); - }}, - *e.operand().ty()); - - auto lhs = declare(*e.result()); - auto result_type = e.result()->ty(); - if (result_type == nullptr) { - throw compilation_error(e.loc(), status::internal_compiler_error, "Expected type"); - } - - clinst.emplace(clinst.begin(), declaration_assignment(visit(*this, *result_type), - std::move(lhs), std::move(rhs))); - - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(lifetime_stop_inst const &) { - return {}; -} - -std::vector convert_to_opencl_pass::operator()(gemm_inst const &g) { - auto a = get_memref_type(g.A()); - auto b = get_memref_type(g.B()); - auto c = get_memref_type(g.C()); - auto &adv = get_dope_vector(g.A()); - auto &bdv = get_dope_vector(g.B()); - auto &cdv = get_dope_vector(g.C()); - - auto const M = c->shape(0); - auto const N = c->shape(1); - auto const ak = g.tA() == transpose::T ? 0 : 1; - auto const K = a->shape(ak); - - auto gemm_ty = gemm_scalar_type{get_scalar_type(g.alpha()), a->element_ty(), b->element_ty(), - get_scalar_type(g.beta()), c->element_ty()}; - auto cfg = gemm_configuration{std::move(gemm_ty), - g.tA(), - g.tB(), - M, - N, - K, - {a->stride(0), a->stride(1)}, - {b->stride(0), b->stride(1)}, - {c->stride(0), c->stride(1)}, - std::nullopt, - std::nullopt, - g.atomic()}; - auto name = cfg.identifier(); - int name_counter = 0; - while (reserved_names_.find(name) != reserved_names_.end()) { - name = cfg.identifier("gemm" + std::to_string(++name_counter)); - } - if (has_gemm_.find(name) == has_gemm_.end()) { - auto f = generate_gemm(cfg, tiling_, core_cfg_, name, to_clir_address_space(a->addrspace()), - to_clir_address_space(b->addrspace()), - to_clir_address_space(c->addrspace())); - prog_builder_.add(std::move(f)); - } - has_gemm_.emplace(name); - return {clir::expression_statement(clir::call( - std::move(name), {cdv.shape(0), cdv.shape(1), adv.shape(ak), val(g.alpha()), val(g.A()), - adv.stride(0), adv.stride(1), val(g.B()), bdv.stride(0), bdv.stride(1), - val(g.beta()), val(g.C()), cdv.stride(0), cdv.stride(1)}))}; -} - -std::vector convert_to_opencl_pass::operator()(gemv_inst const &g) { - auto a = get_memref_type(g.A()); - auto b = get_memref_type(g.B()); - auto c = get_memref_type(g.C()); - auto &adv = get_dope_vector(g.A()); - auto &bdv = get_dope_vector(g.B()); - auto &cdv = get_dope_vector(g.C()); - - auto const M = c->shape(0); - auto const ak = g.tA() == transpose::T ? 0 : 1; - auto const K = a->shape(ak); - constexpr auto N = 1; - - auto gemm_ty = gemm_scalar_type{get_scalar_type(g.alpha()), a->element_ty(), b->element_ty(), - get_scalar_type(g.beta()), c->element_ty()}; - auto cfg = gemm_configuration{std::move(gemm_ty), - g.tA(), - transpose::N, - M, - N, - K, - {a->stride(0), a->stride(1)}, - {b->stride(0), 0}, - {c->stride(0), 0}, - std::nullopt, - std::nullopt, - g.atomic()}; - auto name = cfg.identifier("gemv"); - int name_counter = 0; - while (reserved_names_.find(name) != reserved_names_.end()) { - name = cfg.identifier("gemv" + std::to_string(++name_counter)); - } - if (has_gemm_.find(name) == has_gemm_.end()) { - auto f = generate_gemm(cfg, tiling_, core_cfg_, name, to_clir_address_space(a->addrspace()), - to_clir_address_space(b->addrspace()), - to_clir_address_space(c->addrspace())); - prog_builder_.add(std::move(f)); - } - has_gemm_.emplace(name); - return {clir::expression_statement( - clir::call(std::move(name), {cdv.shape(0), 1, adv.shape(ak), val(g.alpha()), val(g.A()), - adv.stride(0), adv.stride(1), val(g.B()), bdv.stride(0), 0, - val(g.beta()), val(g.C()), cdv.stride(0), 0}))}; -} - -std::vector convert_to_opencl_pass::operator()(ger_inst const &g) { - auto at = get_memref_type(g.A()); - auto bt = get_memref_type(g.B()); - auto ct = get_memref_type(g.C()); - auto &adv = get_dope_vector(g.A()); - auto &bdv = get_dope_vector(g.B()); - auto &cdv = get_dope_vector(g.C()); - - auto alpha = val(g.alpha()); - auto beta = val(g.beta()); - auto alpha_ty = get_scalar_type(g.alpha()); - auto beta_ty = get_scalar_type(g.beta()); - - auto A = val(g.A()); - auto B = val(g.B()); - auto C = val(g.C()); - - auto bb = clir::block_builder{}; - auto sg_n = bb.declare_assign(clir::generic_uint(), "sg_n", - clir::get_sub_group_id() / tiling_.m_tiles()); - auto sg_m = bb.declare_assign(clir::generic_uint(), "sg_m", - clir::get_sub_group_id() % tiling_.m_tiles()); - tile_loop_uniformly( - bb, cdv.shape(1), core_cfg_.subgroup_size, tiling_.n_tiles(), std::move(sg_n), - [&](clir::block_builder &bb, clir::expr block, clir::expr trip_count) { - auto n = clir::var("n"); - bb.add(clir::for_loop_builder(clir::declaration_assignment(clir::generic_int(), n, 0), - n < std::move(trip_count), ++n) - .body([&](clir::block_builder &bb) { - auto b = bb.declare_assign(to_clir_ty(bt->element_ty()), "b", - B[(block + n) * bdv.stride(0)]); - auto Cb = bb.declare_assign(this->operator()(*ct), "Cb", - C + (block + n) * cdv.stride(1)); - auto m = bb.declare_assign(clir::generic_uint(), "m", - clir::get_sub_group_local_id()); - tile_loop_by_sgs( - bb, cdv.shape(0), core_cfg_.subgroup_size, tiling_.m_tiles(), sg_m, - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto a = A[(block + m) * adv.stride(0)]; - auto c = bb.declare_assign((*this)(*ct), "c", - Cb + (block + m) * cdv.stride(0)); - auto ab = bb.declare_assign( - to_clir_ty(ct->element_ty()), "ab", - multiply(at->element_ty(), bt->element_ty(), - std::move(a), b)); - const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), - alpha, std::move(ab)); - store_helper(bb, g.atomic(), c, ct->element_ty(), - to_clir_address_space(ct->addrspace()), - std::move(ab_scaled), beta_ty, beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder( - m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - }) - .get_product()); - }); - return {bb.get_product()}; -} - -std::vector convert_to_opencl_pass::operator()(for_inst const &in) { - auto clinst = std::vector{}; - - yielded_vars_.push_back(std::vector{}); - for (std::int64_t i = 0; i < in.num_results(); ++i) { - auto lhs_ty = visit(*this, *in.result(i).ty()); - auto lhs = declare(in.result(i)); - - // Link the iteration variable to the result variable - uintptr_t u = std::bit_cast(&in.result(i)); - uintptr_t v = std::bit_cast(&in.iter_arg(i)); - declared_vars_.back()[v] = declared_vars_.back()[u]; - - auto iinit = val(in.iter_init(i)); - if (auto ct = dyn_cast(in.result(i).ty()); ct) { - clinst.emplace_back(declaration(std::move(lhs_ty), lhs)); - auto const len = ct->length(core_cfg_.subgroup_size); - for (std::int64_t j = 0; j < len; ++j) { - clinst.emplace_back(expression_statement(assignment(lhs[j], iinit[j]))); - } - } else { - clinst.emplace_back(clir::declaration_assignment(lhs_ty, lhs, iinit)); - } - yielded_vars_.back().emplace_back(std::move(lhs)); - } - - auto lv = declare(in.loop_var()); - auto lv_ty = visit(*this, *in.loop_var().ty()); - auto start = clir::declaration_assignment(std::move(lv_ty), lv, val(in.from())); - auto condition = lv < val(in.to()); - auto step = in.has_step() ? clir::add_into(lv, val(in.step())) : ++lv; - auto body = run_on_region(in.body()); - clinst.emplace_back(clir::stmt(std::make_shared( - std::move(start), std::move(condition), std::move(step), std::move(body)))); - - yielded_vars_.pop_back(); - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(foreach_inst const &p) { - throw compilation_error(p.loc(), status::not_implemented); - /*auto lv = declare(p.loop_var()); - auto lv_ty = visit(*this, *p.loop_var().ty()); - auto from = val(p.from()); - auto to = val(p.to()); - auto bb = clir::block_builder{}; - auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); - auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - auto trip_count = bb.declare_assign(lv_ty, "trip_count", to - from); - tile_loop_by_sgs( - bb, trip_count, core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), - std::move(sg), [&](clir::block_builder &bb, clir::expr block, bool, clir::expr) { - bb.add(clir::declaration_assignment(lv_ty, lv, std::move(block) + m + from)); - bb.add(run_on_region(p.body())); - }); - return {bb.get_product()};*/ -} - -std::vector convert_to_opencl_pass::operator()(hadamard_inst const &g) { - auto at = get_memref_type(g.A()); - auto bt = get_memref_type(g.B()); - auto ct = get_memref_type(g.C()); - auto &adv = get_dope_vector(g.A()); - auto &bdv = get_dope_vector(g.B()); - auto &cdv = get_dope_vector(g.C()); - - auto alpha = val(g.alpha()); - auto beta = val(g.beta()); - auto alpha_ty = get_scalar_type(g.alpha()); - auto beta_ty = get_scalar_type(g.beta()); - - auto A = val(g.A()); - auto B = val(g.B()); - auto C = val(g.C()); - - auto bb = clir::block_builder{}; - auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); - auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - tile_loop_by_sgs( - bb, cdv.shape(0), core_cfg_.subgroup_size, tiling_.m_tiles() * tiling_.n_tiles(), - std::move(sg), - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto b = B[(block + m) * bdv.stride(0)]; - auto a = A[(block + m) * adv.stride(0)]; - - auto c = bb.declare_assign((*this)(*ct), "c", C + (block + m) * cdv.stride(0)); - auto ab = bb.declare_assign( - to_clir_ty(ct->element_ty()), "ab", - multiply(at->element_ty(), bt->element_ty(), std::move(a), b)); - const auto ab_scaled = multiply(alpha_ty, ct->element_ty(), alpha, std::move(ab)); - store_helper(bb, g.atomic(), c, ct->element_ty(), - to_clir_address_space(ct->addrspace()), std::move(ab_scaled), beta_ty, - beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - return {bb.get_product()}; -} - -std::vector convert_to_opencl_pass::operator()(if_inst const &in) { - auto clinst = std::vector{}; - yielded_vars_.push_back(std::vector{}); - for (auto const &r : in.results()) { - auto v = declare(r); - clinst.emplace_back(clir::declaration(visit(*this, *r.ty()), v)); - yielded_vars_.back().emplace_back(std::move(v)); - } - auto ib = clir::if_selection_builder(val(in.condition())); - ib.set_then(run_on_region(in.then())); - if (!in.is_otherwise_empty()) { - ib.set_otherwise(run_on_region(in.otherwise())); - } - yielded_vars_.pop_back(); - clinst.emplace_back(ib.get_product()); - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(parallel_inst const &p) { - return {run_on_region(p.body())}; -} - -std::vector convert_to_opencl_pass::operator()(size_inst const &s) { - auto v = declare(*s.result()); - auto &dv = get_dope_vector(s.operand()); - - return {clir::declaration_assignment(visit(*this, *s.result()->ty()), std::move(v), - dv.shape(s.mode()))}; -} - -std::vector convert_to_opencl_pass::operator()(subview_inst const &s) { - auto result_var = declare(*s.result()); - auto t = get_memref_type(s.operand()); - auto &dv = get_dope_vector(s.operand()); - - auto rhs = val(s.operand()); - int j = 0; - auto shape_out = std::vector{}; - auto stride_out = std::vector{}; - shape_out.reserve(t->dim()); - stride_out.reserve(t->dim()); - auto dyn_offsets = s.offsets(); - auto dyn_sizes = s.sizes(); - for (std::int64_t i = 0, joffset = 0, jsize = 0; i < t->dim(); ++i) { - auto offset = s.static_offsets()[i]; - - auto offset_cl = clir::expr{}; - if (is_dynamic_value(offset)) { - offset_cl = val(dyn_offsets[joffset++]); - } else { - offset_cl = - clir::expr(offset, static_cast(tinytc::size(scalar_type::index) * 8)); - } - rhs = rhs + offset_cl * dv.stride(j); - - auto size = s.static_sizes()[i]; - if (size > 0 || is_dynamic_value(size)) { - auto size_cl = clir::expr{}; - if (is_dynamic_value(size)) { - size_cl = val(dyn_sizes[jsize++]); - } else { - size_cl = - clir::expr(size, static_cast(tinytc::size(scalar_type::index) * 8)); - } - shape_out.emplace_back(size_cl); - stride_out.emplace_back(dv.stride(j)); - } - - ++j; - } - - auto clinst = std::vector{}; - clinst.emplace_back( - clir::declaration_assignment(this->operator()(*t), std::move(result_var), std::move(rhs))); - - set_dope_vector(*s.result(), - dope_vector::from_value(*s.result(), [&](clir::data_type a, clir::var b, - dope_vector::type t, std::int64_t j) { - auto init = t == dope_vector::type::stride ? stride_out[j] : shape_out[j]; - clinst.emplace_back(clir::declaration_assignment(std::move(a), std::move(b), - std::move(init))); - })); - return clinst; -} - -std::vector convert_to_opencl_pass::operator()(store_inst const &s) { - auto ot = get_memref_type(s.operand()); - - if (static_cast(s.index_list().size()) != ot->dim()) { - throw compilation_error(s.loc(), status::ir_invalid_number_of_indices); - } - - auto lhs = val(s.operand()); - auto &dv = get_dope_vector(s.operand()); - for (std::int64_t i = 0; i < ot->dim(); ++i) { - lhs = lhs + val(s.index_list()[i]) * dv.stride(i); - } - - auto rhs = val(s.val()); - return atomic_store_helper_new(s.flag(), ot, std::move(lhs), std::move(rhs)); -} - -std::vector convert_to_opencl_pass::operator()(sum_inst const &inst) { - auto at = get_memref_type(inst.A()); - auto bt = get_memref_type(inst.B()); - auto &adv = get_dope_vector(inst.A()); - auto &bdv = get_dope_vector(inst.B()); - - auto alpha = val(inst.alpha()); - auto beta = val(inst.beta()); - auto alpha_ty = get_scalar_type(inst.alpha()); - auto beta_ty = get_scalar_type(inst.beta()); - - auto zero = clir::expr(0.0, static_cast(size(at->element_ty()) * 8)); - - auto A = val(inst.A()); - auto B = val(inst.B()); - auto bb = clir::block_builder{}; - auto acc = bb.declare_assign(to_clir_ty(at->element_ty()), "acc", std::move(zero)); - auto sg = bb.declare_assign(clir::generic_uint(), "sg", clir::get_sub_group_id()); - auto m = bb.declare_assign(clir::generic_uint(), "m", clir::get_sub_group_local_id()); - if (bt->dim() == 0) { - tile_loop_by_sgs(bb, adv.shape(0), core_cfg_.subgroup_size, - tiling_.n_tiles() * tiling_.m_tiles(), std::move(sg), - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto const inner_loop = [&](clir::block_builder &bb) { - auto a = A[(block + m) * adv.stride(0)]; - bb.add(add_into(acc, std::move(a))); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - auto sum = bb.declare_assign(to_clir_ty(bt->element_ty()), "sum", - clir::work_group_reduce_add(acc)); - bb.add(clir::if_selection_builder(clir::get_sub_group_id() == 0 && - clir::get_sub_group_local_id() == 0) - .then([&](clir::block_builder &bb) { - const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, sum); - store_helper(bb, inst.atomic(), B, bt->element_ty(), - to_clir_address_space(bt->addrspace()), std::move(sum_scaled), - beta_ty, beta); - }) - .get_product()); - } else if (bt->dim() == 1) { - auto ak = inst.tA() == transpose::T ? 0 : 1; - tile_loop_by_sgs( - bb, adv.shape(0), core_cfg_.subgroup_size, tiling_.n_tiles() * tiling_.m_tiles(), - std::move(sg), - [&](clir::block_builder &bb, clir::expr block, bool is_remainder, - clir::expr inner_trip_count) { - auto n = clir::var("n"); - auto const inner_loop = [&](clir::block_builder &bb) { - bb.add(clir::for_loop_builder( - clir::declaration_assignment(clir::generic_int(), n, 0), - n < adv.shape(ak), ++n) - .body([&](clir::block_builder &bb) { - auto a = - A[(block + m) * adv.stride(1 - ak) + n * adv.stride(ak)]; - bb.add(add_into(acc, std::move(a))); - }) - .get_product()); - auto b = bb.declare_assign((*this)(*bt), "b", B + (block + m) * bdv.stride(0)); - const auto sum_scaled = multiply(alpha_ty, at->element_ty(), alpha, acc); - store_helper(bb, inst.atomic(), b, bt->element_ty(), - to_clir_address_space(bt->addrspace()), std::move(sum_scaled), - beta_ty, beta); - }; - if (is_remainder) { - bb.add(clir::if_selection_builder(m < std::move(inner_trip_count)) - .then(inner_loop) - .get_product()); - } else { - inner_loop(bb); - } - }); - } - return {bb.get_product()}; -} - -std::vector convert_to_opencl_pass::operator()(work_group_inst const &in) { - auto const make = [](work_group_operation operation, clir::expr operand, - scalar_type sty) -> clir::expr { - switch (operation) { - case work_group_operation::reduce_add: - if (is_complex_type(sty)) { - return init_vector(to_clir_ty(sty), {clir::work_group_reduce_add(operand.s(0)), - clir::work_group_reduce_add(operand.s(1))}); - } - return clir::work_group_reduce_add(operand); - } - return {}; - }; - - auto lhs = declare(in.result(0)); - auto lhs_ty = visit(*this, *in.result()->ty()); - auto sty = get_scalar_type(in.operand()); - return {declaration_assignment(std::move(lhs_ty), std::move(lhs), - make(in.operation(), val(in.operand()), sty))}; -} - -std::vector convert_to_opencl_pass::operator()(yield_inst const &in) { - if (yielded_vars_.empty()) { - throw compilation_error(in.loc(), status::ir_unexpected_yield); - } - if (static_cast(yielded_vars_.back().size()) != in.num_operands()) { - throw compilation_error(in.loc(), status::ir_yield_mismatch); - } - std::vector clinst; - for (std::int64_t i = 0; i < in.num_operands(); ++i) { - auto &yielded_var = yielded_vars_.back()[i]; - auto ov = val(in.op(i)); - if (auto ct = dyn_cast(in.op(i).ty()); ct) { - auto const len = ct->length(core_cfg_.subgroup_size); - for (std::int64_t j = 0; j < len; ++j) { - clinst.push_back(expression_statement(assignment(yielded_var[j], ov[j]))); - } - } else { - clinst.push_back(expression_statement(assignment(yielded_var, ov))); - } - } - return clinst; -} - -/* Region nodes */ -clir::stmt convert_to_opencl_pass::run_on_region(region_node const ®) { - declared_vars_.push_back({}); - auto bb = clir::block_builder{}; - for (auto &s : reg.insts()) { - for (auto &cs : visit(*this, s)) { - bb.add(cs); - } - } - declared_vars_.pop_back(); - return bb.get_product(); -} - -/* Function nodes */ -auto convert_to_opencl_pass::run_on_function(function_node const &fn) -> clir::func { - stack_high_water_mark_ = 0; - auto const subgroup_size = fn.subgroup_size(); - try { - core_cfg_ = info_->get_core_config(subgroup_size); - } catch (std::out_of_range const &e) { - throw compilation_error(fn.loc(), status::unsupported_subgroup_size); - } - auto const work_group_size = fn.work_group_size(); - tiling_[0] = work_group_size[0] / subgroup_size; - tiling_[1] = work_group_size[1]; - - stack_ = clir::var("stack"); - - // Create prototype - auto fb = clir::kernel_builder(std::string(fn.name())); - for (auto const &v : fn.params()) { - fb.argument(visit(*this, *v.ty()), declare(v)); - auto dv = visit( - overloaded{[&fb, &v](memref_data_type const &) -> std::optional { - return std::make_optional(dope_vector::from_value( - v, [&](clir::data_type a, clir::var b, dope_vector::type, - std::int64_t) { fb.argument(std::move(a), std::move(b)); })); - }, - [&fb, &v](group_data_type const &) -> std::optional { - return std::make_optional(dope_vector::from_value( - v, [&](clir::data_type a, clir::var b, dope_vector::type, - std::int64_t) { fb.argument(std::move(a), std::move(b)); })); - }, - [](auto const &) { return std::nullopt; }}, - *v.ty()); - if (dv) { - set_dope_vector(v, std::move(*dv)); - } - } - - fb.attribute(clir::reqd_work_group_size(work_group_size[0], work_group_size[1], 1)); - fb.attribute(clir::intel_reqd_sub_group_size(subgroup_size)); - - auto body = run_on_region(fn.body()); - - if (stack_high_water_mark_ > 0) { - auto bb = dynamic_cast(body.get()); - if (bb == nullptr) { - throw compilation_error(fn.loc(), status::internal_compiler_error, - "Expected clir basic block"); - } - bb->stmts().insert(bb->stmts().begin(), - declaration(clir::array_of(clir::data_type(clir::builtin_type::uchar_t, - clir::address_space::local_t), - stack_high_water_mark_), - stack_, {clir::aligned(size(scalar_type::f64) * 8)})); - } - return clir::function(fb.get_product(), std::move(body)); -} - -/* Program nodes */ -auto convert_to_opencl_pass::run_on_program(program_node const &p) -> clir::prog { - reserved_names_.clear(); - for (auto const &fn : p) { - reserved_names_.insert(std::string(fn.name())); - } - - prog_builder_ = clir::program_builder{}; - for (auto const &fn : p) { - prog_builder_.add(run_on_function(fn)); - } - return prog_builder_.get_product(); -} - -} // namespace tinytc diff --git a/src/pass/convert_to_opencl.hpp b/src/pass/convert_to_opencl.hpp deleted file mode 100644 index 599a1636..00000000 --- a/src/pass/convert_to_opencl.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef CONVERT_TO_OPENCL_20240913_HPP -#define CONVERT_TO_OPENCL_20240913_HPP - -#include "device_info.hpp" -#include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" -#include "node/value_node.hpp" -#include "tiling.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -class dope_vector { - public: - enum class type { shape, stride, offset }; - using decl_fun_t = std::function; - static dope_vector from_value(value_node const &v, decl_fun_t declare); - - inline dope_vector() {} - inline dope_vector(std::vector shape, std::vector stride) - : shape_(std::move(shape)), stride_(std::move(stride)) {} - - inline auto shape(std::int64_t i) { return shape_[i]; } - inline auto stride(std::int64_t i) { return stride_[i]; } - inline auto offset() { return offset_; } - inline auto offset(clir::expr offset) { offset_ = std::move(offset); } - - private: - static dope_vector from_memref_type(std::string const &prefix, memref_data_type const &m, - clir::data_type dt, decl_fun_t declare); - std::vector shape_, stride_; - clir::expr offset_ = clir::expr(std::int64_t(0)); -}; - -class convert_to_opencl_pass { - public: - convert_to_opencl_pass(::tinytc_core_info const *info); - - /* Data type nodes */ - clir::data_type operator()(void_data_type const &); - clir::data_type operator()(boolean_data_type const &ct); - clir::data_type operator()(coopmatrix_data_type const &ct); - clir::data_type operator()(group_data_type const &g); - clir::data_type operator()(memref_data_type const &m); - clir::data_type operator()(scalar_data_type const &s); - - /* Inst nodes */ - std::vector operator()(alloca_inst const &a); - std::vector operator()(axpby_inst const &a); - std::vector operator()(barrier_inst const &b); - std::vector operator()(builtin_inst const &b); - std::vector operator()(arith_inst const &a); - std::vector operator()(arith_unary_inst const &a); - std::vector operator()(cast_inst const &c); - std::vector operator()(compare_inst const &c); - std::vector operator()(constant_inst const &c); - std::vector operator()(cooperative_matrix_load_inst const &c); - std::vector operator()(cooperative_matrix_mul_add_inst const &c); - std::vector operator()(cooperative_matrix_scale_inst const &c); - std::vector operator()(cooperative_matrix_store_inst const &c); - std::vector operator()(expand_inst const &e); - std::vector operator()(fuse_inst const &f); - std::vector operator()(load_inst const &e); - std::vector operator()(lifetime_stop_inst const &l); - std::vector operator()(gemm_inst const &g); - std::vector operator()(gemv_inst const &g); - std::vector operator()(ger_inst const &g); - std::vector operator()(for_inst const &p); - std::vector operator()(foreach_inst const &in); - std::vector operator()(hadamard_inst const &g); - std::vector operator()(if_inst const &in); - std::vector operator()(parallel_inst const &p); - std::vector operator()(size_inst const &s); - std::vector operator()(subview_inst const &s); - std::vector operator()(store_inst const &s); - std::vector operator()(sum_inst const &s); - std::vector operator()(work_group_inst const &in); - std::vector operator()(yield_inst const &in); - - auto run_on_program(program_node const &p) -> clir::prog; - - private: - auto run_on_region(region_node const ®) -> clir::stmt; - auto run_on_function(function_node const &fn) -> clir::func; - auto val(value_node const &v) -> clir::expr; - - auto get_dope_vector(value_node const &v) -> dope_vector &; - void set_dope_vector(value_node const &v, dope_vector dv); - clir::var declare(value_node const &v); - auto get_coopmatrix_type(value_node const &v) const -> const coopmatrix_data_type *; - auto get_memref_type(value_node const &v) const -> const memref_data_type *; - static auto get_scalar_type(value_node const &v) -> scalar_type; - - ::tinytc_core_info const *info_; - clir::program_builder prog_builder_; - std::vector> declared_vars_; - std::vector> yielded_vars_; - std::unordered_map dope_vector_; - std::unordered_set reserved_names_; - std::unordered_set has_gemm_; - clir::var stack_; - std::size_t stack_high_water_mark_ = 0; - local_tiling tiling_ = {}; - core_config core_cfg_ = {}; -}; - -} // namespace tinytc - -#endif // CONVERT_TO_OPENCL_20240913_HPP diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp index 11b40563..a445c0b4 100644 --- a/src/pass/lower_foreach.cpp +++ b/src/pass/lower_foreach.cpp @@ -25,7 +25,7 @@ void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, auto sg_lid = bb.add(make_cast(sg_lid_i32, ity, loc)); auto size = bb.add(make_arith(arithmetic::sub, to, from, ity, loc)); auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity, loc)); - tile_loop_by_sgs_new( + tile_loop_by_sgs( bb, size, sgs, num_tiles, sg_id, [&](region_builder &bb, value block, bool is_remainder, value trip_count) { auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset, ity, loc)); @@ -90,7 +90,7 @@ auto foreach_generator::operator()(foreach_inst &in) -> inst { auto sg_id0 = bb.add(make_arith(arithmetic::rem, sg_id, c_m_tiles, sg_id->ty(), in.loc())); auto size1 = bb.add(make_arith(arithmetic::sub, &to[1], &from[1], ity, in.loc())); - tile_loop_uniformly_new( + tile_loop_uniformly( bb, size1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_id1, [&](region_builder &bb, value block, value trip_count1) { auto from1 = bb.add(make_arith(arithmetic::add, &from[1], block, ity, in.loc())); diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp index eb2323f7..ef5216b3 100644 --- a/src/pass/lower_linalg.cpp +++ b/src/pass/lower_linalg.cpp @@ -296,37 +296,37 @@ void linalg_generator::operator()(gemm_inst &in) { const auto block_size1 = max_cols; if (const_shape1) { - tile_loop_uniformly_new( + tile_loop_uniformly( bb, c_shape1, block_size1, tiling_.n_tiles(), sg_n, [&](region_builder &bb, value n_block, value trip_count) { auto const_trip_count = get_int_constant(trip_count); if (!const_trip_count) { throw compilation_error(in.loc(), status::internal_compiler_error); } - tile_loop_by_sgs_new(bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, - [&](region_builder &bb, value m_block, bool m_check, value) { - gemm_microkernel( - bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), - &in.A(), &in.B(), &in.beta(), &in.C(), K, m_block, - block_size0, m_check, n_block, *const_trip_count, - false, at->element_data_ty(), bt->element_data_ty(), - ct->element_data_ty(), in.loc()); - }); + tile_loop_by_sgs(bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), + &in.alpha(), &in.A(), &in.B(), &in.beta(), + &in.C(), K, m_block, block_size0, m_check, + n_block, *const_trip_count, false, + at->element_data_ty(), bt->element_data_ty(), + ct->element_data_ty(), in.loc()); + }); }); } else { - tile_loop_by_sgs_new(bb, c_shape1, block_size1, tiling_.n_tiles(), sg_n, - [&](region_builder &bb, value n_block, bool n_check, value) { - tile_loop_by_sgs_new( - bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, - [&](region_builder &bb, value m_block, bool m_check, value) { - gemm_microkernel( - bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), - &in.A(), &in.B(), &in.beta(), &in.C(), K, m_block, - block_size0, m_check, n_block, block_size1, n_check, - at->element_data_ty(), bt->element_data_ty(), - ct->element_data_ty(), in.loc()); - }); - }); + tile_loop_by_sgs(bb, c_shape1, block_size1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value n_block, bool n_check, value) { + tile_loop_by_sgs( + bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), + &in.alpha(), &in.A(), &in.B(), &in.beta(), + &in.C(), K, m_block, block_size0, m_check, + n_block, block_size1, n_check, + at->element_data_ty(), bt->element_data_ty(), + ct->element_data_ty(), in.loc()); + }); + }); } add(std::move(parallel)); diff --git a/src/required_extensions.cpp b/src/required_extensions.cpp deleted file mode 100644 index ea4d3edc..00000000 --- a/src/required_extensions.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "required_extensions.hpp" - -#include -#include - -#include - -namespace tinytc { - -auto ext_list(std::vector const &ext) -> std::vector { - auto result = std::vector{}; - result.reserve(ext.size() + 1); - for (auto const &e : ext) { - result.emplace_back(clir::to_string(e)); - } - result.emplace_back("cl_khr_fp64"); - return result; -} - -auto required_extensions(clir::func f) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(f))); -} -auto required_extensions(clir::prog p) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(p))); -} - -} // namespace tinytc diff --git a/src/required_extensions.hpp b/src/required_extensions.hpp deleted file mode 100644 index e4c5a23a..00000000 --- a/src/required_extensions.hpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef REQUIRED_EXTENSIONS_20240416_HPP -#define REQUIRED_EXTENSIONS_20240416_HPP - -#include - -#include -#include - -namespace tinytc { - -auto required_extensions(clir::func f) -> std::vector; -auto required_extensions(clir::prog p) -> std::vector; - -} // namespace tinytc - -#endif // REQUIRED_EXTENSIONS_20240416_HPP diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index 374202a0..f189cf47 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -70,87 +70,6 @@ std::int32_t alignment(scalar_type ty, component_count count) { return scale * tinytc_scalar_type_size(static_cast(ty)); } -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return to_clir_ty(ty, 1, as, q); -} - -clir::data_type to_clir_ty(scalar_type ty, short size, clir::address_space as, - clir::type_qualifier q) { - const auto base_type = [](scalar_type ty) { - switch (ty) { - case scalar_type::i8: - return clir::builtin_type::char_t; - case scalar_type::i16: - return clir::builtin_type::short_t; - case scalar_type::i32: - return clir::builtin_type::int_t; - case scalar_type::i64: - return clir::builtin_type::long_t; - case scalar_type::index: - return clir::builtin_type::long_t; - case scalar_type::f32: - case scalar_type::c32: - return clir::builtin_type::float_t; - case scalar_type::f64: - case scalar_type::c64: - return clir::builtin_type::double_t; - } - return clir::builtin_type::void_t; - }; - const auto components = [](scalar_type ty) -> short { - switch (ty) { - case scalar_type::i8: - case scalar_type::i16: - case scalar_type::i32: - case scalar_type::i64: - case scalar_type::index: - case scalar_type::f32: - case scalar_type::f64: - return 1; - case scalar_type::c32: - case scalar_type::c64: - return 2; - } - return 0; - }; - size *= components(ty); - if (size == 1) { - return clir::data_type(base_type(ty), as, q); - } - return clir::data_type(base_type(ty), size, as, q); -} - -clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - auto const base_type = [](scalar_type ty) { - switch (ty) { - case scalar_type::i32: - return clir::builtin_type::atomic_int_t; - case scalar_type::i64: - return clir::builtin_type::atomic_long_t; - case scalar_type::index: - return clir::builtin_type::atomic_long_t; - case scalar_type::f32: - return clir::builtin_type::atomic_float_t; - case scalar_type::f64: - return clir::builtin_type::atomic_double_t; - default: - break; - } - return clir::builtin_type::void_t; - }; - return clir::data_type(base_type(ty), as, q); -} - -clir::address_space to_clir_address_space(address_space as) { - switch (as) { - case address_space::global: - return clir::address_space::global_t; - case address_space::local: - return clir::address_space::local_t; - } - return clir::address_space::global_t; -} - } // namespace tinytc char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index cca70e64..c72192a6 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -6,9 +6,6 @@ #include "tinytc/types.hpp" -#include -#include - #include namespace tinytc { @@ -23,15 +20,6 @@ bool is_integer_type(scalar_type ty); scalar_type element_type(scalar_type ty); scalar_type compatible_type(scalar_type a_ty, scalar_type b_ty); std::int32_t alignment(scalar_type ty, component_count count = component_count::v1); -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); -clir::data_type to_clir_ty(scalar_type ty, short size, - clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); -clir::data_type to_clir_atomic_ty(scalar_type ty, - clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); -clir::address_space to_clir_address_space(address_space as); } // namespace tinytc diff --git a/test/generator.cpp b/test/generator.cpp index a0705c73..369e2fe0 100644 --- a/test/generator.cpp +++ b/test/generator.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause #include "device_info.hpp" -#include "gemm_generator.hpp" #include "gemm_tools.hpp" #include "reference_counted.hpp" #include "scalar_type.hpp" @@ -48,22 +47,6 @@ TEST_CASE("suggest work group size") { check(dynamic, dynamic, 16, 4, 8); } -TEST_CASE("routine names") { - auto cfg = gemm_configuration{gemm_scalar_type{scalar_type::f32, scalar_type::f64}, - transpose::N, - transpose::T, - 16, - 32, - 48, - {1, 20}, - {1, 40}, - {1, 50}, - 3.14, - std::nullopt}; - CHECK(cfg.identifier("gemm") == "gemm_f32f32f32f64f64_An_Bt_M16_N32_K48_Astride1_20_Bstride1_" - "40_Cstride1_50_alpha40091eb851eb851f_betad"); -} - TEST_CASE("max register block") { auto s1 = max_register_block_gemm(4, 16, 8192); CHECK(s1.first == 32); From 1313922bc2074439111120deeaf15ab19bede1dd Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 12:08:41 +0100 Subject: [PATCH 130/297] Remove or salvage OpenCL-C lit tests Signed-off-by: Carsten Uphoff --- test/codegen/atomic.ir | 27 --- test/codegen/axpby1.ir | 34 ---- test/codegen/cast.ir | 46 ----- test/codegen/coopmatrix_basic.ir | 56 ------ test/codegen/coopmatrix_load.ir | 181 ------------------ test/codegen/coopmatrix_mul_add.ir | 92 --------- test/codegen/coopmatrix_store.ir | 83 -------- test/codegen/dope_vector0.ir | 18 -- test/codegen/dope_vector_group0.ir | 22 --- test/codegen/expand.ir | 134 ------------- test/codegen/for.ir | 40 ---- test/codegen/func_attributes.ir | 27 --- test/codegen/fuse.ir | 32 ---- test/codegen/if.ir | 98 ---------- test/codegen/load.ir | 21 -- test/codegen/scalar_arithmetic.ir | 106 ---------- test/codegen/size.ir | 16 -- test/codegen/store.ir | 12 -- test/codegen/subgroup.ir | 16 -- test/codegen/work_group.ir | 16 -- test/{codegen => opt/check-ir}/axpby0.ir | 2 +- test/opt/check-ir/expand.ir | 3 +- test/opt/check-ir/fuse.ir | 21 ++ .../check-ir}/scalar_arithmetic_error.ir | 2 +- .../check-ir}/syntax_error0.ir | 2 +- .../check-ir}/syntax_error1.ir | 2 +- .../check-ir}/type_mismatch0.ir | 2 +- .../check-ir}/type_mismatch1.ir | 2 +- test/spv/func_attributes.ir | 22 +++ 29 files changed, 50 insertions(+), 1085 deletions(-) delete mode 100644 test/codegen/atomic.ir delete mode 100644 test/codegen/axpby1.ir delete mode 100644 test/codegen/cast.ir delete mode 100644 test/codegen/coopmatrix_basic.ir delete mode 100644 test/codegen/coopmatrix_load.ir delete mode 100644 test/codegen/coopmatrix_mul_add.ir delete mode 100644 test/codegen/coopmatrix_store.ir delete mode 100644 test/codegen/dope_vector0.ir delete mode 100644 test/codegen/dope_vector_group0.ir delete mode 100644 test/codegen/expand.ir delete mode 100644 test/codegen/for.ir delete mode 100644 test/codegen/func_attributes.ir delete mode 100644 test/codegen/fuse.ir delete mode 100644 test/codegen/if.ir delete mode 100644 test/codegen/load.ir delete mode 100644 test/codegen/scalar_arithmetic.ir delete mode 100644 test/codegen/size.ir delete mode 100644 test/codegen/store.ir delete mode 100644 test/codegen/subgroup.ir delete mode 100644 test/codegen/work_group.ir rename test/{codegen => opt/check-ir}/axpby0.ir (81%) create mode 100644 test/opt/check-ir/fuse.ir rename test/{codegen => opt/check-ir}/scalar_arithmetic_error.ir (83%) rename test/{codegen => opt/check-ir}/syntax_error0.ir (75%) rename test/{codegen => opt/check-ir}/syntax_error1.ir (72%) rename test/{codegen => opt/check-ir}/type_mismatch0.ir (76%) rename test/{codegen => opt/check-ir}/type_mismatch1.ir (85%) create mode 100644 test/spv/func_attributes.ir diff --git a/test/codegen/atomic.ir b/test/codegen/atomic.ir deleted file mode 100644 index 49f5d186..00000000 --- a/test/codegen/atomic.ir +++ /dev/null @@ -1,27 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s | filecheck %s -func @atomic_store(%A: memref) { - %f0 = constant 0.0 : f64 - %i0 = constant 0 : index - store.atomic %f0, %A[%i0] - store.atomic_add %f0, %A[%i0] -; CHECK-LABEL: void atomic_store({{.*}} -; CHECK: atomic_store_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); -; CHECK: atomic_fetch_add_explicit((global volatile atomic_double*) (A + i0 * 1), f0, memory_order_relaxed, memory_scope_work_group); -} - -func @atomic_store_c64(%A: memref) { - %f0 = constant [0.0, 0.0] : c64 - %i0 = constant 0 : index - store.atomic %f0, %A[%i0] - store.atomic_add %f0, %A[%i0] -; CHECK-LABEL: void atomic_store_c64({{.*}} -; CHECK: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).x, f0, memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: atomic_store_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).y, f0, memory_order_relaxed, memory_scope_work_group); -; CHECK: atomic_fetch_add_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).x, f0, memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: atomic_fetch_add_explicit((atomic_double*global volatile) &(*(A + i0 * 1)).y, f0, memory_order_relaxed, memory_scope_work_group); -} - - diff --git a/test/codegen/axpby1.ir b/test/codegen/axpby1.ir deleted file mode 100644 index 0986ad5c..00000000 --- a/test/codegen/axpby1.ir +++ /dev/null @@ -1,34 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s -func @axpby0(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 : f32 - axpby.n %alpha, %A, %z, %B -} - -func @axpby1(%alpha: f32, %A: memref>, %B: memref) { - %z = constant 0.0 : f32 - axpby.n %alpha, %A, %z, %B -} - -func @axpby2(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 : f32 - axpby.n %alpha, %A, %z, %B -} - -func @axpby3(%alpha: f32, %A: memref, %B: memref) { - %z = constant 0.0 : f32 - %lb = constant 0 : index - %ub = constant 5 : index - for %i=%lb,%ub { - %A0 = subview %A[0:48,0:48,0:4,%i] : memref - %B0 = subview %B[0:48,0:48,0:4,%i] : memref - %ub1 = constant 4 : index - for %j=%lb,%ub1 { - %A1 = subview %A0[0:48,0:48,%j] : memref - %B1 = subview %B0[0:48,0:48,%j] : memref - axpby.t %alpha, %A1, %z, %B1 - } - } -} diff --git a/test/codegen/cast.ir b/test/codegen/cast.ir deleted file mode 100644 index caa336c8..00000000 --- a/test/codegen/cast.ir +++ /dev/null @@ -1,46 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @cast_ii() { - %0 = constant 2 : index - %1 = cast %0 : i32 -; CHECK-LABEL: void cast_ii() { -; CHECK: int x1 = (int) x; -} -func @cast_ff() { - %0 = constant 2.0 : f32 - %1 = cast %0 : f64 -; CHECK-LABEL: void cast_ff() { -; CHECK: double x1 = (double) x; -} -func @cast_cc() { - %0 = constant [2.0, 0.0] : c32 - %1 = cast %0 : c64 -; CHECK-LABEL: void cast_cc() { -; CHECK: double2 x1 = convert_double2(x); -} -func @cast_if() { - %0 = constant 2 : i32 - %1 = cast %0 : f32 -; CHECK-LABEL: void cast_if() { -; CHECK: float x1 = (float) x; -} -func @cast_fi() { - %0 = constant 2.0 : f32 - %1 = cast %0 : i16 -; CHECK-LABEL: void cast_fi() { -; CHECK: short x1 = (short) x; -} -func @cast_ic() { - %0 = constant 2 : i8 - %1 = cast %0 : c32 -; CHECK-LABEL: void cast_ic() { -; CHECK: float2 x1 = (float2) (x, 0); -} -func @cast_fc() { - %0 = constant 2.0 : f64 - %1 = cast %0 : c32 -; CHECK-LABEL: void cast_fc() { -; CHECK: float2 x1 = (float2) (x, 0); -} diff --git a/test/codegen/coopmatrix_basic.ir b/test/codegen/coopmatrix_basic.ir deleted file mode 100644 index 5b135727..00000000 --- a/test/codegen/coopmatrix_basic.ir +++ /dev/null @@ -1,56 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @coopmatrix_constant() { - %0 = constant 1.0 : coopmatrix -; CHECK-LABEL: void coopmatrix_constant({{.*}} -; CHECK: double x[5]; -; CHECK-NEXT: x[0] = 0x1p+0; -; CHECK-NEXT: x[1] = 0x1p+0; -; CHECK-NEXT: x[2] = 0x1p+0; -; CHECK-NEXT: x[3] = 0x1p+0; -; CHECK-NEXT: x[4] = 0x1p+0; -} - -func @coopmatrix_add() { - %0 = constant 1.0 : coopmatrix - %1 = constant 1.0 : coopmatrix - %2 = arith.add %0, %1 : coopmatrix -; CHECK-LABEL: void coopmatrix_add({{.*}} -; CHECK: double x2[4]; -; CHECK-NEXT: x2[0] = x[0] + x1[0]; -; CHECK-NEXT: x2[1] = x[1] + x1[1]; -; CHECK-NEXT: x2[2] = x[2] + x1[2]; -; CHECK-NEXT: x2[3] = x[3] + x1[3]; -} - -func @coopmatrix_neg() subgroup_size(16) { - %0 = constant 1.0 : coopmatrix - %1 = arith.neg %0 : coopmatrix -; CHECK-LABEL: void coopmatrix_neg({{.*}} -; CHECK: double x1[8]; -; CHECK-NEXT: x1[0] = -x[0]; -; CHECK-NEXT: x1[1] = -x[1]; -; CHECK-NEXT: x1[2] = -x[2]; -; CHECK-NEXT: x1[3] = -x[3]; -; CHECK-NEXT: x1[4] = -x[4]; -; CHECK-NEXT: x1[5] = -x[5]; -; CHECK-NEXT: x1[6] = -x[6]; -; CHECK-NEXT: x1[7] = -x[7]; -} - -func @coopmatrix_cast() subgroup_size(16) { - %0 = constant 1 : coopmatrix - %1 = cast %0 : coopmatrix -; CHECK-LABEL: void coopmatrix_cast({{.*}} -; CHECK: float2 x1[8]; -; CHECK-NEXT: x1[0] = (float2) (x[0], 0); -; CHECK-NEXT: x1[1] = (float2) (x[1], 0); -; CHECK-NEXT: x1[2] = (float2) (x[2], 0); -; CHECK-NEXT: x1[3] = (float2) (x[3], 0); -; CHECK-NEXT: x1[4] = (float2) (x[4], 0); -; CHECK-NEXT: x1[5] = (float2) (x[5], 0); -; CHECK-NEXT: x1[6] = (float2) (x[6], 0); -; CHECK-NEXT: x1[7] = (float2) (x[7], 0); -} diff --git a/test/codegen/coopmatrix_load.ir b/test/codegen/coopmatrix_load.ir deleted file mode 100644 index 3e64cdfb..00000000 --- a/test/codegen/coopmatrix_load.ir +++ /dev/null @@ -1,181 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @coopmatrix_a_load_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_n({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0 * 64))); -; CHECK-NEXT: x1[1] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 1 * 64))); -; CHECK-NEXT: x1[2] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 2 * 64))); -; CHECK-NEXT: x1[3] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 3 * 64))); -; CHECK-NEXT: x1[4] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 4 * 64))); -; CHECK-NEXT: x1[5] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 5 * 64))); -; CHECK-NEXT: x1[6] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 6 * 64))); -; CHECK-NEXT: x1[7] = as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 7 * 64))); -} - -func @coopmatrix_a_load_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.rows_checked %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_n_rows_checked({{.*}} -; CHECK: float x1[4]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; -; CHECK-NEXT: x1[1] = x5 ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; -; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; -; CHECK-NEXT: x1[2] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 0 * 64] : 0; -; CHECK-NEXT: x1[3] = x6 ? x2[1 * (get_sub_group_local_id() + 16) + 1 * 64] : 0; -} - -func @coopmatrix_a_load_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.cols_checked %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_n_cols_checked({{.*}} -; CHECK: float x1[4]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: x1[0] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 0 * 64))) : 0; -; CHECK-NEXT: x1[1] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 0 + 1 * 64))) : 0; -; CHECK-NEXT: x1[2] = 0 >= -y && 0 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 0 * 64))) : 0; -; CHECK-NEXT: x1[3] = 1 >= -y && 1 < x4 ? as_float(intel_sub_group_block_read_ui((global uint*) (x2 + 16 + 1 * 64))) : 0; -} - -func @coopmatrix_a_load_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.both_checked %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_n_checked({{.*}} -; CHECK: float x1[16]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 2 * 64] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 3 * 64] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 4 * 64] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 5 * 64] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 6 * 64] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 7 * 64] : 0; -; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -x && get_sub_group_local_id() + 16 < x3; -; CHECK-NEXT: x1[8] = x6 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 0 * 64] : 0; -; CHECK-NEXT: x1[9] = x6 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 1 * 64] : 0; -; CHECK-NEXT: x1[10] = x6 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 2 * 64] : 0; -; CHECK-NEXT: x1[11] = x6 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 3 * 64] : 0; -; CHECK-NEXT: x1[12] = x6 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 4 * 64] : 0; -; CHECK-NEXT: x1[13] = x6 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 5 * 64] : 0; -; CHECK-NEXT: x1[14] = x6 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 6 * 64] : 0; -; CHECK-NEXT: x1[15] = x6 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 16) + 7 * 64] : 0; -} - -func @coopmatrix_a_load_t(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_t({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0 * 1]; -; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1 * 1]; -; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2 * 1]; -; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3 * 1]; -; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4 * 1]; -; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5 * 1]; -; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6 * 1]; -; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7 * 1]; -} - -func @coopmatrix_a_load_t_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.both_checked %A[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_a_load_t_checked({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = A + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0 * 1] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1 * 1] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2 * 1] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3 * 1] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4 * 1] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5 * 1] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6 * 1] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7 * 1] : 0; -} - -func @coopmatrix_b_load_n(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n %B[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_b_load_n({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[64 * (get_sub_group_local_id() + 0) + 0 * 1]; -; CHECK-NEXT: x1[1] = x2[64 * (get_sub_group_local_id() + 0) + 1 * 1]; -; CHECK-NEXT: x1[2] = x2[64 * (get_sub_group_local_id() + 0) + 2 * 1]; -; CHECK-NEXT: x1[3] = x2[64 * (get_sub_group_local_id() + 0) + 3 * 1]; -; CHECK-NEXT: x1[4] = x2[64 * (get_sub_group_local_id() + 0) + 4 * 1]; -; CHECK-NEXT: x1[5] = x2[64 * (get_sub_group_local_id() + 0) + 5 * 1]; -; CHECK-NEXT: x1[6] = x2[64 * (get_sub_group_local_id() + 0) + 6 * 1]; -; CHECK-NEXT: x1[7] = x2[64 * (get_sub_group_local_id() + 0) + 7 * 1]; -} - -func @coopmatrix_b_load_n_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.n.both_checked %B[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_b_load_n_checked({{.*}} -; CHECK: float x1[16]; -; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -y && get_sub_group_local_id() + 0 < x4; -; CHECK-NEXT: x1[0] = x5 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 0 * 1] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 1 * 1] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 2 * 1] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 3 * 1] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 4 * 1] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 5 * 1] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 6 * 1] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 0) + 7 * 1] : 0; -; CHECK-NEXT: bool x6 = get_sub_group_local_id() + 16 >= -y && get_sub_group_local_id() + 16 < x4; -; CHECK-NEXT: x1[8] = x6 && (0 >= -x && 0 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 0 * 1] : 0; -; CHECK-NEXT: x1[9] = x6 && (1 >= -x && 1 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 1 * 1] : 0; -; CHECK-NEXT: x1[10] = x6 && (2 >= -x && 2 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 2 * 1] : 0; -; CHECK-NEXT: x1[11] = x6 && (3 >= -x && 3 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 3 * 1] : 0; -; CHECK-NEXT: x1[12] = x6 && (4 >= -x && 4 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 4 * 1] : 0; -; CHECK-NEXT: x1[13] = x6 && (5 >= -x && 5 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 5 * 1] : 0; -; CHECK-NEXT: x1[14] = x6 && (6 >= -x && 6 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 6 * 1] : 0; -; CHECK-NEXT: x1[15] = x6 && (7 >= -x && 7 < x3) ? x2[64 * (get_sub_group_local_id() + 16) + 7 * 1] : 0; -} - -func @coopmatrix_b_load_t(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t %B[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_b_load_t({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: x1[0] = x2[1 * (get_sub_group_local_id() + 0) + 0 * 64]; -; CHECK-NEXT: x1[1] = x2[1 * (get_sub_group_local_id() + 0) + 1 * 64]; -; CHECK-NEXT: x1[2] = x2[1 * (get_sub_group_local_id() + 0) + 2 * 64]; -; CHECK-NEXT: x1[3] = x2[1 * (get_sub_group_local_id() + 0) + 3 * 64]; -; CHECK-NEXT: x1[4] = x2[1 * (get_sub_group_local_id() + 0) + 4 * 64]; -; CHECK-NEXT: x1[5] = x2[1 * (get_sub_group_local_id() + 0) + 5 * 64]; -; CHECK-NEXT: x1[6] = x2[1 * (get_sub_group_local_id() + 0) + 6 * 64]; -; CHECK-NEXT: x1[7] = x2[1 * (get_sub_group_local_id() + 0) + 7 * 64]; -} - -func @coopmatrix_b_load_t_checked(%B: memref, %x: index, %y: index) subgroup_size(16) { - %0 = cooperative_matrix_load.t.both_checked %B[%x,%y] : coopmatrix -; CHECK-LABEL: void coopmatrix_b_load_t_checked({{.*}} -; CHECK: float x1[8]; -; CHECK-NEXT: global float* x2 = B + x * 1 + y * 64; -; CHECK-NEXT: long x3 = 64 - x; -; CHECK-NEXT: long x4 = 48 - y; -; CHECK-NEXT: bool x5 = get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x3; -; CHECK-NEXT: x1[0] = x5 && (0 >= -y && 0 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 0 * 64] : 0; -; CHECK-NEXT: x1[1] = x5 && (1 >= -y && 1 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 1 * 64] : 0; -; CHECK-NEXT: x1[2] = x5 && (2 >= -y && 2 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 2 * 64] : 0; -; CHECK-NEXT: x1[3] = x5 && (3 >= -y && 3 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 3 * 64] : 0; -; CHECK-NEXT: x1[4] = x5 && (4 >= -y && 4 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 4 * 64] : 0; -; CHECK-NEXT: x1[5] = x5 && (5 >= -y && 5 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 5 * 64] : 0; -; CHECK-NEXT: x1[6] = x5 && (6 >= -y && 6 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 6 * 64] : 0; -; CHECK-NEXT: x1[7] = x5 && (7 >= -y && 7 < x4) ? x2[1 * (get_sub_group_local_id() + 0) + 7 * 64] : 0; -} diff --git a/test/codegen/coopmatrix_mul_add.ir b/test/codegen/coopmatrix_mul_add.ir deleted file mode 100644 index 8ea8218f..00000000 --- a/test/codegen/coopmatrix_mul_add.ir +++ /dev/null @@ -1,92 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @coopmatrix_mul_add_ff() subgroup_size(16) { - %a = constant 1.0 : coopmatrix - %b = constant 1.0 : coopmatrix - %c = constant 1.0 : coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix -; CHECK-LABEL: void coopmatrix_mul_add_ff({{.*}} -; CHECK: float c_next[4]; -; CHECK-NEXT: c_next[0] = fma(a[0], sub_group_broadcast(b[0], 0), c[0]); -; CHECK-NEXT: c_next[1] = fma(a[0], sub_group_broadcast(b[0], 1), c[1]); -; CHECK-NEXT: c_next[2] = fma(a[0], sub_group_broadcast(b[0], 2), c[2]); -; CHECK-NEXT: c_next[3] = fma(a[0], sub_group_broadcast(b[0], 3), c[3]); -; CHECK-NEXT: c_next[0] = fma(a[1], sub_group_broadcast(b[1], 0), c_next[0]); -; CHECK-NEXT: c_next[1] = fma(a[1], sub_group_broadcast(b[1], 1), c_next[1]); -; CHECK-NEXT: c_next[2] = fma(a[1], sub_group_broadcast(b[1], 2), c_next[2]); -; CHECK-NEXT: c_next[3] = fma(a[1], sub_group_broadcast(b[1], 3), c_next[3]); -} - -func @coopmatrix_mul_add_cf() subgroup_size(16) { - %a = constant [1.0, 0.0] : coopmatrix - %b = constant 1.0 : coopmatrix - %c = constant [1.0, 0.0] : coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix -; CHECK-LABEL: void coopmatrix_mul_add_cf({{.*}} -; CHECK: float2 c_next[4]; -; CHECK-NEXT: c_next[0] = c[0] + a[0] * sub_group_broadcast(b[0], 0); -; CHECK-NEXT: c_next[1] = c[1] + a[0] * sub_group_broadcast(b[0], 1); -; CHECK-NEXT: c_next[2] = c[2] + a[0] * sub_group_broadcast(b[0], 2); -; CHECK-NEXT: c_next[3] = c[3] + a[0] * sub_group_broadcast(b[0], 3); -; CHECK-NEXT: c_next[0] = c_next[0] + a[1] * sub_group_broadcast(b[1], 0); -; CHECK-NEXT: c_next[1] = c_next[1] + a[1] * sub_group_broadcast(b[1], 1); -; CHECK-NEXT: c_next[2] = c_next[2] + a[1] * sub_group_broadcast(b[1], 2); -; CHECK-NEXT: c_next[3] = c_next[3] + a[1] * sub_group_broadcast(b[1], 3); -} - -func @coopmatrix_mul_add_fc() subgroup_size(16) { - %a = constant 1.0 : coopmatrix - %b = constant [1.0, 0.0] : coopmatrix - %c = constant [1.0, 0.0] : coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix -; CHECK-LABEL: void coopmatrix_mul_add_fc({{.*}} -; CHECK: float2 c_next[4]; -; CHECK-NEXT: c_next[0].x = c[0].x + a[0] * sub_group_broadcast(b[0].x, 0); -; CHECK-NEXT: c_next[0].y = c[0].y + a[0] * sub_group_broadcast(b[0].y, 0); -; CHECK-NEXT: c_next[1].x = c[1].x + a[0] * sub_group_broadcast(b[0].x, 1); -; CHECK-NEXT: c_next[1].y = c[1].y + a[0] * sub_group_broadcast(b[0].y, 1); -; CHECK-NEXT: c_next[2].x = c[2].x + a[0] * sub_group_broadcast(b[0].x, 2); -; CHECK-NEXT: c_next[2].y = c[2].y + a[0] * sub_group_broadcast(b[0].y, 2); -; CHECK-NEXT: c_next[3].x = c[3].x + a[0] * sub_group_broadcast(b[0].x, 3); -; CHECK-NEXT: c_next[3].y = c[3].y + a[0] * sub_group_broadcast(b[0].y, 3); -; CHECK-NEXT: c_next[0].x = c_next[0].x + a[1] * sub_group_broadcast(b[1].x, 0); -; CHECK-NEXT: c_next[0].y = c_next[0].y + a[1] * sub_group_broadcast(b[1].y, 0); -; CHECK-NEXT: c_next[1].x = c_next[1].x + a[1] * sub_group_broadcast(b[1].x, 1); -; CHECK-NEXT: c_next[1].y = c_next[1].y + a[1] * sub_group_broadcast(b[1].y, 1); -; CHECK-NEXT: c_next[2].x = c_next[2].x + a[1] * sub_group_broadcast(b[1].x, 2); -; CHECK-NEXT: c_next[2].y = c_next[2].y + a[1] * sub_group_broadcast(b[1].y, 2); -; CHECK-NEXT: c_next[3].x = c_next[3].x + a[1] * sub_group_broadcast(b[1].x, 3); -; CHECK-NEXT: c_next[3].y = c_next[3].y + a[1] * sub_group_broadcast(b[1].y, 3); -} - -func @coopmatrix_mul_add_cc() subgroup_size(16) { - %a = constant [1.0, 0.0] : coopmatrix - %b = constant [1.0, 0.0] : coopmatrix - %c = constant [1.0, 0.0] : coopmatrix - %c_next = cooperative_matrix_mul_add %a, %b, %c : coopmatrix -; CHECK-LABEL: void coopmatrix_mul_add_cc({{.*}} -; CHECK: float2 c_next[4]; -; CHECK-NEXT: float2 x[4]; -; CHECK-NEXT: c_next[0] = c[0] + a[0] * sub_group_broadcast(b[0].x, 0); -; CHECK-NEXT: x[0] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 0); -; CHECK-NEXT: c_next[1] = c[1] + a[0] * sub_group_broadcast(b[0].x, 1); -; CHECK-NEXT: x[1] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 1); -; CHECK-NEXT: c_next[2] = c[2] + a[0] * sub_group_broadcast(b[0].x, 2); -; CHECK-NEXT: x[2] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 2); -; CHECK-NEXT: c_next[3] = c[3] + a[0] * sub_group_broadcast(b[0].x, 3); -; CHECK-NEXT: x[3] = (float2) (0, 0) + a[0] * sub_group_broadcast(b[0].y, 3); -; CHECK-NEXT: c_next[0] = c_next[0] + a[1] * sub_group_broadcast(b[1].x, 0); -; CHECK-NEXT: x[0] = x[0] + a[1] * sub_group_broadcast(b[1].y, 0); -; CHECK-NEXT: c_next[1] = c_next[1] + a[1] * sub_group_broadcast(b[1].x, 1); -; CHECK-NEXT: x[1] = x[1] + a[1] * sub_group_broadcast(b[1].y, 1); -; CHECK-NEXT: c_next[2] = c_next[2] + a[1] * sub_group_broadcast(b[1].x, 2); -; CHECK-NEXT: x[2] = x[2] + a[1] * sub_group_broadcast(b[1].y, 2); -; CHECK-NEXT: c_next[3] = c_next[3] + a[1] * sub_group_broadcast(b[1].x, 3); -; CHECK-NEXT: x[3] = x[3] + a[1] * sub_group_broadcast(b[1].y, 3); -; CHECK-NEXT: c_next[0] += (float2) (-x[0].y, x[0].x); -; CHECK-NEXT: c_next[1] += (float2) (-x[1].y, x[1].x); -; CHECK-NEXT: c_next[2] += (float2) (-x[2].y, x[2].x); -; CHECK-NEXT: c_next[3] += (float2) (-x[3].y, x[3].x); -} diff --git a/test/codegen/coopmatrix_store.ir b/test/codegen/coopmatrix_store.ir deleted file mode 100644 index 692f7059..00000000 --- a/test/codegen/coopmatrix_store.ir +++ /dev/null @@ -1,83 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @coopmatrix_a_store_n(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_n({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; -} - -func @coopmatrix_a_store_n_rows_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store.rows_checked %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_n_rows_checked({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: long x2 = 64 - x; -; CHECK-NEXT: long x3 = 48 - y; -; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; -; CHECK-NEXT: } -} - -func @coopmatrix_a_store_n_cols_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store.cols_checked %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_n_cols_checked({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: long x2 = 64 - x; -; CHECK-NEXT: long x3 = 48 - y; -; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; -; CHECK-NEXT: } -; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; -; CHECK-NEXT: } -} - -func @coopmatrix_a_store_n_checked(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store.both_checked %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_n_checked({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: long x2 = 64 - x; -; CHECK-NEXT: long x3 = 48 - y; -; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { -; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64) = c0[0]; -; CHECK-NEXT: } -; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: *(x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64) = c0[1]; -; CHECK-NEXT: } -; CHECK-NEXT: } -} - -func @coopmatrix_a_store_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store.atomic_add %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_atomic_add({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); -} - -func @coopmatrix_a_store_checked_atomic_add(%A: memref, %x: index, %y: index) subgroup_size(16) { - %c0 = constant 1.0 : coopmatrix - cooperative_matrix_store.both_checked.atomic_add %c0, %A[%x,%y] -; CHECK-LABEL: void coopmatrix_a_store_checked_atomic_add({{.*}} -; CHECK: global float* x1 = A + x * 1 + y * 64; -; CHECK-NEXT: long x2 = 64 - x; -; CHECK-NEXT: long x3 = 48 - y; -; CHECK-NEXT: if (get_sub_group_local_id() + 0 >= -x && get_sub_group_local_id() + 0 < x2) { -; CHECK-NEXT: if (0 >= -y && 0 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 0 * 64), c0[0], memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: } -; CHECK-NEXT: if (1 >= -y && 1 < x3) { -; CHECK-NEXT: atomic_fetch_add_explicit((global volatile atomic_float*) (x1 + 1 * (get_sub_group_local_id() + 0) + 1 * 64), c0[1], memory_order_relaxed, memory_scope_work_group); -; CHECK-NEXT: } -; CHECK-NEXT: } -} diff --git a/test/codegen/dope_vector0.ir b/test/codegen/dope_vector0.ir deleted file mode 100644 index 53cd8c33..00000000 --- a/test/codegen/dope_vector0.ir +++ /dev/null @@ -1,18 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @kernel(%K0: memref, %offset: index, %size: index) { - %0 = subview %K0[4:%size, %offset] : memref -; CHECK: void kernel({{.*}} -; CHECK-NEXT: global float* x = K0 + 4ll * 1 + offset * K0_stride1; -; CHECK-NEXT: long x_shape0 = size; -} - -func @kernel2(%K0: memref, %offset: index, %size: index) { - %0 = subview %K0[%offset, 4:%size] : memref> -; CHECK: void kernel2({{.*}} -; CHECK-NEXT: global float* x = K0 + offset * 1 + 4ll * K0_stride1; -; CHECK-NEXT: long x_shape0 = size; -; CHECK-NEXT: long x_stride0 = K0_stride1; -} diff --git a/test/codegen/dope_vector_group0.ir b/test/codegen/dope_vector_group0.ir deleted file mode 100644 index 86bd8348..00000000 --- a/test/codegen/dope_vector_group0.ir +++ /dev/null @@ -1,22 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @kernel1(%in: group>) { -; CHECK: void kernel1(global float*global* in, global long* in_shape1, global long* in_stride2) - %c5 = constant 5 : index - %0 = load %in[%c5] : memref - ; CHECK-NEXT: long c5 = 5ll; - ; CHECK-NEXT: global float* x = *(in + c5) + 0; - ; CHECK-NEXT: long x_shape1 = in_shape1[c5]; - ; CHECK-NEXT: long x_stride2 = in_stride2[c5]; -} - -func @kernel2(%in: group, offset: ?>) { -; CHECK: void kernel2(global float*global* in, global long* in_shape0, long in_offset) - %c5 = constant 5 : index - %0 = load %in[%c5] : memref - ; CHECK-NEXT: long c5 = 5ll; - ; CHECK-NEXT: global float* x = *(in + c5) + in_offset; - ; CHECK-NEXT: long x_shape0 = in_shape0[c5]; -} diff --git a/test/codegen/expand.ir b/test/codegen/expand.ir deleted file mode 100644 index e7873b92..00000000 --- a/test/codegen/expand.ir +++ /dev/null @@ -1,134 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1(%0: memref) { - %z = constant 0 : index - %1 = expand %0[1->2x8] : memref - %2 = load %1[%z,%z,%z,%z] : f32 -; CHECK-LABEL: void t1( -; CHECK: global float* x1 = x; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 512); -} -func @t2(%0: memref) { - %z = constant 0 : index - %1 = expand %0[1->2x2x2x2] : memref - %2 = load %1[%z,%z,%z,%z,%z,%z] : f32 -; CHECK-LABEL: void t2( -; CHECK: global float* x1 = x; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * 64 + z * 128 + z * 256 + z * 512); -} -func @t3(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t3( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape11 = x1; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); -} -func @t4(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->2 x %1] : memref - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t4( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape2 = x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * 64); -} -func @t5(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t5( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape1 = x1; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); -} -func @t6(%0: memref, %1: index) { - %z = constant 0 : index - %2 = expand %0[1->%1 x 2] : memref - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t6( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape11 = x1; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: float x3 = *(x2 + z * 1 + z * 32 + z * x_stride2); -} -func @t7(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2 x 2] : memref - %4 = load %3[%z,%z,%z,%z] : f32 -; CHECK-LABEL: void t7( -; CHECK: global float* x3 = x; -; CHECK-NEXT: long x_shape1 = x1; -; CHECK-NEXT: long x_shape2 = x2; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: long x_stride3 = 32 * x1 * x2; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); -} -func @t8(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%2 x 2 x %1] : memref - %4 = load %3[%z,%z,%z,%z] : f32 -; CHECK-LABEL: void t8( -; CHECK: global float* x3 = x; -; CHECK-NEXT: long x_shape1 = x2; -; CHECK-NEXT: long x_stride2 = 32 * x2; -; CHECK-NEXT: long x_shape3 = x1; -; CHECK-NEXT: long x_stride3 = 32 * x2 * 2ll; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2 + z * x_stride3); -} -func @t9(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : f32 -; CHECK-LABEL: void t9( -; CHECK: global float* x3 = x; -; CHECK-NEXT: long x_shape11 = x1; -; CHECK-NEXT: long x_shape2 = x2; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); -} -func @t10(%0: memref, %1: index, %2: index) { - %z = constant 0 : index - %3 = expand %0[1->%1 x %2] : memref - %4 = load %3[%z,%z,%z] : f32 -; CHECK-LABEL: void t10( -; CHECK: global float* x3 = x; -; CHECK-NEXT: long x_shape1 = x1; -; CHECK-NEXT: long x_shape2 = x2; -; CHECK-NEXT: long x_stride2 = 32 * x1; -; CHECK-NEXT: float x4 = *(x3 + z * 1 + z * 32 + z * x_stride2); -} -func @t11(%0: memref>) { - %z = constant 0 : index - %1 = expand %0[0->4 x 8] : memref> - %2 = load %1[%z,%z,%z] : f32 -; CHECK-LABEL: void t11( -; CHECK: global float* x1 = x; -; CHECK-NEXT: float x2 = *(x1 + z * 2 + z * 8 + z * 64); -} -func @t12(%0: memref>, %1: index) { - %z = constant 0 : index - %2 = expand %0[0->%1 x 4] : memref> - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t12( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape01 = x1; -; CHECK-NEXT: long x_stride11 = 2 * x1; -; CHECK-NEXT: long x_stride2 = x_stride1; -; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * x_stride11 + z * x_stride2); -} -func @t13(%0: memref>, %1: index) { - %z = constant 0 : index - %2 = expand %0[0->4 x %1] : memref> - %3 = load %2[%z,%z,%z] : f32 -; CHECK-LABEL: void t13( -; CHECK: global float* x2 = x; -; CHECK-NEXT: long x_shape1 = x1; -; CHECK-NEXT: long x_stride2 = x_stride1; -; CHECK-NEXT: float x3 = *(x2 + z * 2 + z * 8 + z * x_stride2); -} diff --git a/test/codegen/for.ir b/test/codegen/for.ir deleted file mode 100644 index cbf47dad..00000000 --- a/test/codegen/for.ir +++ /dev/null @@ -1,40 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s | filecheck %s -func @for1() { - %lb0 = constant 0 : index - %ub0 = constant 10 : index - for %0 = %lb0,%ub0 { - } - %lb1 = constant -2 : i16 - %ub1 = constant 2 : i16 - for %1:i16 = %lb1,%ub1 { - } -; CHECK-LABEL: void for1({{.*}} -; CHECK: for (long x = lb0; x < ub0; ++x) -; CHECK: for (short x = lb1; x < ub1; ++x) -} - -func @for2(%fib: memref) { - %from = constant 2 : i32 - %to = constant 6 : i32 - %f0 = constant 0 : i64 - %f1 = constant 1 : i64 - %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { - %fn = arith.add %fn_2, %fn_1 : i64 - yield (%fn_1, %fn) - } - store %fn, %fib[] -; CHECK-LABEL: void for2({{.*}} -; CHECK: long f0 = 0ll; -; CHECK-NEXT: long f1 = 1ll; -; CHECK-NEXT: long fn_1 = f0; -; CHECK-NEXT: long fn = f1; -; CHECK-NEXT: for (int n = from; n < to; ++n) { -; CHECK-NEXT: long fn1 = fn_1 + fn; -; CHECK-NEXT: fn_1 = fn; -; CHECK-NEXT: fn = fn1; -; CHECK-NEXT: } -; CHECK-NEXT: *fib = fn; -} diff --git a/test/codegen/func_attributes.ir b/test/codegen/func_attributes.ir deleted file mode 100644 index fadadd80..00000000 --- a/test/codegen/func_attributes.ir +++ /dev/null @@ -1,27 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s | filecheck %s -func @attr1() work_group_size(128, 4) { -} -; CHECK: kernel -; CHECK-NEXT: __attribute__((reqd_work_group_size(128,4,1))) -; CHECK-NEXT: __attribute__((intel_reqd_sub_group_size(32))) -; CHECK-NEXT: void attr1() { -; CHECK-NEXT: } - -func @attr2() work_group_size(128, 4) subgroup_size(16) { -} -; CHECK: kernel -; CHECK-NEXT: __attribute__((reqd_work_group_size(128,4,1))) -; CHECK-NEXT: __attribute__((intel_reqd_sub_group_size(16))) -; CHECK-NEXT: void attr2() { -; CHECK-NEXT: } - -func @attr3() subgroup_size(32) { -} -; CHECK: kernel -; CHECK-NEXT: __attribute__((reqd_work_group_size(32,1,1))) -; CHECK-NEXT: __attribute__((intel_reqd_sub_group_size(32))) -; CHECK-NEXT: void attr3() { -; CHECK-NEXT: } diff --git a/test/codegen/fuse.ir b/test/codegen/fuse.ir deleted file mode 100644 index 72e5a9fe..00000000 --- a/test/codegen/fuse.ir +++ /dev/null @@ -1,32 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1(%0: memref) { - %z = constant 0 : index - %1 = fuse %0[1,3] : memref - %2 = load %1[%z,%z,%z] : f32 -; CHECK: float x2 = *(x1 + z * 1 + z * 32 + z * 16384); -} -func @t2(%0: memref) { - %z = constant 0 : index - %1 = fuse %0[1,3] : memref> - %2 = load %1[%z,%z,%z] : f32 -; CHECK: long x_shape1 = 16 * x_shape2 * 4; -; CHECK-NEXT: long x_stride2 = x_stride4; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * 32 + z * x_stride2); -} -func @t3(%0: memref>) { - %z = constant 0 : index - %1 = fuse %0[1,2] : memref> - %2 = load %1[%z,%z,%z] : f32 -; CHECK: float x2 = *(x1 + z * 1 + z * 48 + z * 1536); -} -func @t4(%0: memref>) { - %z = constant 0 : index - %1 = fuse %0[0,1] : memref> - %2 = load %1[%z,%z] : f32 -; CHECK: long x_shape0 = 8 * x_shape1; -; CHECK-NEXT: long x_stride11 = x_stride2; -; CHECK-NEXT: float x2 = *(x1 + z * 1 + z * x_stride11); -} diff --git a/test/codegen/if.ir b/test/codegen/if.ir deleted file mode 100644 index 3d45f546..00000000 --- a/test/codegen/if.ir +++ /dev/null @@ -1,98 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s | filecheck %s -func @if0(%0: i32) { - %c16 = constant 16 : i32 - %c0 = constant 0 : i32 - %1 = cmp.lt %0, %c16 : bool - %2 = cmp.ge %0, %c0 : bool - %3 = arith.and %1, %2 : bool - if %3 { - } else { - } -; CHECK: bool x1 = x < c16; -; CHECK: bool x2 = x >= c0; -; CHECK: bool x3 = x1 && x2; -; CHECK: if (x3) { -; CHECK-NEXT: } -} - -func @if1(%0: i32) { - %c16 = constant 16 : i32 - %1 = cmp.lt %0, %c16 : bool - if %1 { - } else { - } -; CHECK: if (x1) { -; CHECK-NEXT: } -} - -func @if2(%0: i32) { - %c16 = constant 16 : i32 - %1 = cmp.lt %0, %c16 : bool - if %1 -> () { - yield () - } else { - yield () - } -; CHECK: if (x1) { -; CHECK-NEXT: } else { -; CHECK-NEXT: } -} - -func @if3(%0: i32) { - %c16 = constant 16 : i32 - %1 = cmp.lt %0, %c16 : bool - %x = if %1 -> (i32) { - yield (%0) - } else { - yield (%c16) - } -; CHECK: int x2; -; CHECK-NEXT: if (x1) { -; CHECK-NEXT: x2 = x; -; CHECK-NEXT: } else { -; CHECK-NEXT: x2 = c16; -; CHECK-NEXT: } -} - -func @if4(%0: i32) { - %c16 = constant 16 : i32 - %1 = cmp.lt %0, %c16 : bool - %x, %y = if %1 -> (i32, f32) { - if %1 { - } - %one = constant 1.0 : f32 - yield (%0, %one) - } else { - %z = if %1 -> (f32) { - %one = constant 1.0 : f32 - yield (%one) - } else { - %zero = constant 0.0 : f32 - yield (%zero) - } - yield (%c16, %z) - } -; CHECK: int x2; -; CHECK-NEXT: float y; -; CHECK-NEXT: if (x1) { -; CHECK-NEXT: if (x1) { -; CHECK-NEXT: } -; CHECK-NEXT: float one = 0x1p+0f; -; CHECK-NEXT: x2 = x; -; CHECK-NEXT: y = one; -; CHECK-NEXT: } else { -; CHECK-NEXT: float z; -; CHECK-NEXT: if (x1) { -; CHECK-NEXT: float one = 0x1p+0f; -; CHECK-NEXT: z = one; -; CHECK-NEXT: } else { -; CHECK-NEXT: float zero = 0x0p+0f; -; CHECK-NEXT: z = zero; -; CHECK-NEXT: } -; CHECK-NEXT: x2 = c16; -; CHECK-NEXT: y = z; -; CHECK-NEXT: } -} diff --git a/test/codegen/load.ir b/test/codegen/load.ir deleted file mode 100644 index f485201a..00000000 --- a/test/codegen/load.ir +++ /dev/null @@ -1,21 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @kernel1(%a: memref, %b: memref, %c: group>) { - %c5 = constant 5 : index - %0 = load %a[] : f32 - %1 = builtin.group_id : index - %2 = load %b[%c5, %1] : f32 - %3 = load %c[%1] : memref - ; CHECK: float x = *a; - ; CHECK-NEXT: long x1 = get_global_id(2); - ; CHECK-NEXT: float x2 = *(b + c5 * 1 + x1 * 10); - ; CHECK-NEXT: global float* x3 = *(c + x1) + 0; -} - -func @kernel2(%c: group, offset: 21>) { - %0 = builtin.group_id : index - %1 = load %c[%0] : memref - ; CHECK: global float* x1 = *(c + x) + 21; -} diff --git a/test/codegen/scalar_arithmetic.ir b/test/codegen/scalar_arithmetic.ir deleted file mode 100644 index da8faa6c..00000000 --- a/test/codegen/scalar_arithmetic.ir +++ /dev/null @@ -1,106 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1(%a: i32, %b: i32, %a1: bool, %b1: bool) { - %1 = arith.add %a, %b : i32 - %2 = arith.sub %a, %b : i32 - %3 = arith.mul %a, %b : i32 - %4 = arith.div %a, %b : i32 - %5 = arith.rem %a, %b : i32 - %6 = arith.shl %a, %b : i32 - %7 = arith.shr %a, %b : i32 - %8 = arith.and %a, %b : i32 - %9 = arith.and %a1, %b1 : bool - %10 = arith.or %a, %b : i32 - %11 = arith.or %a1, %b1 : bool - %12 = arith.xor %a, %b : i32 - %13 = arith.neg %a : i32 - %14 = arith.not %a : i32 - %15 = arith.not %a1 : bool - %16 = arith.abs %a : i32 -; CHECK: int x = a + b; -; CHECK-NEXT: int x1 = a - b; -; CHECK-NEXT: int x2 = a * b; -; CHECK-NEXT: int x3 = a / b; -; CHECK-NEXT: int x4 = a % b; -; CHECK-NEXT: int x5 = a << b; -; CHECK-NEXT: int x6 = a >> b; -; CHECK-NEXT: int x7 = a & b; -; CHECK-NEXT: bool x8 = a1 && b1; -; CHECK-NEXT: int x9 = a | b; -; CHECK-NEXT: bool x10 = a1 || b1; -; CHECK-NEXT: int x11 = a ^ b; -; CHECK-NEXT: int x12 = -a; -; CHECK-NEXT: int x13 = ~a; -; CHECK-NEXT: bool x14 = !a1; -; CHECK-NEXT: int x15 = abs(a); -} -func @t2(%a: i32, %b: i32) { - %1 = cmp.eq %a, %b : bool - %2 = cmp.ne %a, %b : bool - %3 = cmp.gt %a, %b : bool - %4 = cmp.ge %a, %b : bool - %5 = cmp.lt %a, %b : bool - %6 = cmp.le %a, %b : bool -; CHECK: bool x = a == b; -; CHECK-NEXT: bool x1 = a != b; -; CHECK-NEXT: bool x2 = a > b; -; CHECK-NEXT: bool x3 = a >= b; -; CHECK-NEXT: bool x4 = a < b; -; CHECK-NEXT: bool x5 = a <= b; -} -func @t3(%a: f32, %b: f32) { - %1 = arith.add %a, %b : f32 - %2 = arith.sub %a, %b : f32 - %3 = arith.mul %a, %b : f32 - %4 = arith.div %a, %b : f32 - %5 = arith.rem %a, %b : f32 - %6 = arith.neg %a : f32 - %7 = arith.abs %a : f32 -; CHECK: float x = a + b; -; CHECK-NEXT: float x1 = a - b; -; CHECK-NEXT: float x2 = a * b; -; CHECK-NEXT: float x3 = a / b; -; CHECK-NEXT: float x4 = fmod(a, b); -; CHECK-NEXT: float x5 = -a; -; CHECK-NEXT: float x6 = fabs(a); -} -func @t4(%a: f32, %b: f32) { - %1 = cmp.eq %a, %b : bool - %2 = cmp.ne %a, %b : bool - %3 = cmp.gt %a, %b : bool - %4 = cmp.ge %a, %b : bool - %5 = cmp.lt %a, %b : bool - %6 = cmp.le %a, %b : bool -; CHECK: bool x = a == b; -; CHECK-NEXT: bool x1 = a != b; -; CHECK-NEXT: bool x2 = a > b; -; CHECK-NEXT: bool x3 = a >= b; -; CHECK-NEXT: bool x4 = a < b; -; CHECK-NEXT: bool x5 = a <= b; -} -func @t5(%a: i32) { - %b = cast %a : index -; CHECK: long b = (long) a; -} -func @t6(%a: c32, %b: c32) { - %0 = arith.add %a, %b : c32 - %1 = arith.sub %a, %b : c32 - %2 = arith.mul %a, %b : c32 - %3 = arith.div %a, %b : c32 - %4 = arith.neg %a : c32 - %5 = arith.abs %a : f32 - %6 = arith.conj %a : c32 - %7 = arith.im %a : f32 - %8 = arith.re %a : f32 -; CHECK: float2 x = a + b; -; CHECK-NEXT: float2 x1 = a - b; -; CHECK-NEXT: float2 x2 = a * b.x + (float2) (-a.y, a.x) * b.y; -; CHECK-NEXT: float2 x3 = (a * b.x - (float2) (-a.y, a.x) * b.y) / (b.x * b.x + b.y * b.y); -; CHECK-NEXT: float2 x4 = -a; -; CHECK-NEXT: float x5 = sqrt(a.x * a.x + a.y * a.y); -; CHECK-NEXT: float2 x6 = (float2) (a.x, -a.y); -; CHECK-NEXT: float x7 = a.y; -; CHECK-NEXT: float x8 = a.x; -} diff --git a/test/codegen/size.ir b/test/codegen/size.ir deleted file mode 100644 index 02874f33..00000000 --- a/test/codegen/size.ir +++ /dev/null @@ -1,16 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1(%0: memref) { - %1 = size %0[0] : index - %2 = size %0[1] : index -; CHECK: long x1 = 32; -; CHECK-NEXT: long x2 = 16; -} -func @t2(%0: memref) { - %1 = size %0[0] : index - %2 = size %0[1] : index -; CHECK: long x1 = x_shape0; -; CHECK-NEXT: long x2 = x_shape1; -} diff --git a/test/codegen/store.ir b/test/codegen/store.ir deleted file mode 100644 index 358b1e4c..00000000 --- a/test/codegen/store.ir +++ /dev/null @@ -1,12 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc < %s | filecheck %s -func @kernel(%a: memref, %b: memref, %c: f32) { - %c5 = constant 5 : index - %1 = builtin.group_id : index - store %c, %a[] - store %c, %b[%c5, %1] - ; CHECK: *a = c; - ; CHECK-NEXT: *(b + c5 * 1 + x * 10) = c; -} diff --git a/test/codegen/subgroup.ir b/test/codegen/subgroup.ir deleted file mode 100644 index 0881b43e..00000000 --- a/test/codegen/subgroup.ir +++ /dev/null @@ -1,16 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1() { - parallel { - %0 = builtin.num_subgroups : i32 - %1 = builtin.subgroup_id : i32 - %2 = builtin.subgroup_local_id : i32 - %3 = builtin.subgroup_size : i32 - } -; CHECK: int x = get_num_sub_groups(); -; CHECK-NEXT: int x1 = get_sub_group_id(); -; CHECK-NEXT: int x2 = get_sub_group_local_id(); -; CHECK-NEXT: int x3 = get_sub_group_size(); -} diff --git a/test/codegen/work_group.ir b/test/codegen/work_group.ir deleted file mode 100644 index eed318bd..00000000 --- a/test/codegen/work_group.ir +++ /dev/null @@ -1,16 +0,0 @@ -; Copyright (C) 2024 Intel Corporation -; SPDX-License-Identifier: BSD-3-Clause - -; RUN: %tinytc-oc -O0 < %s | filecheck %s -func @t1() { - %0 = constant 1.0 : f32 - %1 = work_group.reduce_add %0 : f32 -; CHECK-LABEL: void t1({{.*}} -; CHECK: float x1 = work_group_reduce_add(x); -} -func @t2() { - %0 = constant [1.0, 0.0] : c32 - %1 = work_group.reduce_add %0 : c32 -; CHECK-LABEL: void t2({{.*}} -; CHECK: float2 x1 = (float2) (work_group_reduce_add(x.x), work_group_reduce_add(x.y)); -} diff --git a/test/codegen/axpby0.ir b/test/opt/check-ir/axpby0.ir similarity index 81% rename from test/codegen/axpby0.ir rename to test/opt/check-ir/axpby0.ir index b8591f97..51d087fb 100644 --- a/test/codegen/axpby0.ir +++ b/test/opt/check-ir/axpby0.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @axpby(%alpha: f32, %A: memref, %B: memref) { %zero = constant 0.0 : f32 axpby.n %alpha, %A, %zero, %B diff --git a/test/opt/check-ir/expand.ir b/test/opt/check-ir/expand.ir index 34180db0..ba1eaa5d 100644 --- a/test/opt/check-ir/expand.ir +++ b/test/opt/check-ir/expand.ir @@ -3,8 +3,7 @@ ; RUN: %tinytc-opt -pcheck-ir -O0 < %s | filecheck %s -; No real checks needed, just check that it does not crash, that is, -; the types put in load match those returned by expand +; No real checks needed, just check that it does not crash ; CHECK: func @t1({{.*}} func @t1(%0: memref) { diff --git a/test/opt/check-ir/fuse.ir b/test/opt/check-ir/fuse.ir new file mode 100644 index 00000000..9de14e6c --- /dev/null +++ b/test/opt/check-ir/fuse.ir @@ -0,0 +1,21 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-opt -pcheck-ir -O0 < %s | filecheck %s + +; No real checks needed, just check that it does not crash +; CHECK: func @t1({{.*}} + +func @t1(%0: memref) { + %1 = fuse %0[1,3] : memref +} +func @t2(%0: memref) { + %1 = fuse %0[1,3] : memref +} +func @t3(%0: memref>) { + %1 = fuse %0[1,2] : memref> + %2 = fuse %0[1,2] : memref> +} +func @t4(%0: memref>) { + %2 = fuse %0[0,1] : memref> +} diff --git a/test/codegen/scalar_arithmetic_error.ir b/test/opt/check-ir/scalar_arithmetic_error.ir similarity index 83% rename from test/codegen/scalar_arithmetic_error.ir rename to test/opt/check-ir/scalar_arithmetic_error.ir index 6cef12e9..29915edc 100644 --- a/test/codegen/scalar_arithmetic_error.ir +++ b/test/opt/check-ir/scalar_arithmetic_error.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @t1(%a: f32, %b: f32) { %1 = arith.and %a, %b : f32 ; CHECK: :6.8-29: Floating point type unsupported by instruction diff --git a/test/codegen/syntax_error0.ir b/test/opt/check-ir/syntax_error0.ir similarity index 75% rename from test/codegen/syntax_error0.ir rename to test/opt/check-ir/syntax_error0.ir index 61b32d27..12e37e8c 100644 --- a/test/codegen/syntax_error0.ir +++ b/test/opt/check-ir/syntax_error0.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @axpby(%B: memreff32x8x8>) ; CHECK: 5.23-25: syntax error, unexpected FLOATING_TYPE, expecting < } diff --git a/test/codegen/syntax_error1.ir b/test/opt/check-ir/syntax_error1.ir similarity index 72% rename from test/codegen/syntax_error1.ir rename to test/opt/check-ir/syntax_error1.ir index 2c6dd1df..b72334fb 100644 --- a/test/codegen/syntax_error1.ir +++ b/test/opt/check-ir/syntax_error1.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @axpby(%B: memref) { axby ; CHECK: 6.5: Unknown token diff --git a/test/codegen/type_mismatch0.ir b/test/opt/check-ir/type_mismatch0.ir similarity index 76% rename from test/codegen/type_mismatch0.ir rename to test/opt/check-ir/type_mismatch0.ir index 7cddb539..709043af 100644 --- a/test/codegen/type_mismatch0.ir +++ b/test/opt/check-ir/type_mismatch0.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @kernel(%K0: memref) { %0 = load %K0[] : f64 ; CHECK: 6.8-23: Type of operand must match return type diff --git a/test/codegen/type_mismatch1.ir b/test/opt/check-ir/type_mismatch1.ir similarity index 85% rename from test/codegen/type_mismatch1.ir rename to test/opt/check-ir/type_mismatch1.ir index c7e05a22..eb3415bb 100644 --- a/test/codegen/type_mismatch1.ir +++ b/test/opt/check-ir/type_mismatch1.ir @@ -1,7 +1,7 @@ ; Copyright (C) 2024 Intel Corporation ; SPDX-License-Identifier: BSD-3-Clause -; RUN: not %tinytc-oc < %s 2>&1 | filecheck %s +; RUN: not %tinytc-opt -pcheck-ir < %s 2>&1 | filecheck %s func @kernel(%K0: memref, %x: index, %y: index) { %z = constant 0 : index %0 = subview %K0[0:%x] : memref diff --git a/test/spv/func_attributes.ir b/test/spv/func_attributes.ir new file mode 100644 index 00000000..f74fb791 --- /dev/null +++ b/test/spv/func_attributes.ir @@ -0,0 +1,22 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -S < %s | filecheck %s + +; CHECK: OpEntryPoint Kernel %[[#ATTR1:]] "attr1" +; CHECK: OpEntryPoint Kernel %[[#ATTR2:]] "attr2" +; CHECK: OpEntryPoint Kernel %[[#ATTR3:]] "attr3" + +; CHECK: OpExecutionMode %[[#ATTR1]] LocalSize 128 4 1 +; CHECK: OpExecutionMode %[[#ATTR2]] LocalSize 128 4 1 +; CHECK: OpExecutionMode %[[#ATTR2]] SubgroupSize 16 +; CHECK: OpExecutionMode %[[#ATTR3]] SubgroupSize 32 + +func @attr1() work_group_size(128, 4) { +} + +func @attr2() work_group_size(128, 4) subgroup_size(16) { +} + +func @attr3() subgroup_size(32) { +} From 28b42b79f379f4785e4de78f33c27ef2479982d2 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Fri, 22 Nov 2024 15:35:50 +0100 Subject: [PATCH 131/297] Started with fp16 support Signed-off-by: Carsten Uphoff --- include/tinytc/tinytc.h | 22 ++++++++ include/tinytc/types.h | 11 ++-- include/tinytc/types.hpp | 1 + src/CMakeLists.txt | 1 + src/half.cpp | 113 +++++++++++++++++++++++++++++++++++++++ src/inst.cpp | 2 + src/parser/lexer.re | 3 +- src/scalar_type.cpp | 4 ++ src/spv/converter.cpp | 10 ++++ src/spv/uniquifier.cpp | 1 + test/CMakeLists.txt | 5 ++ test/math.cpp | 14 +++++ test/spv/arith.ir | 15 ++++++ test/spv/arith_unary.ir | 8 +++ test/spv/cast.ir | 11 ++++ test/spv/compare.ir | 15 ++++++ 16 files changed, 230 insertions(+), 6 deletions(-) create mode 100644 src/half.cpp create mode 100644 test/math.cpp diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index eede8224..2388d1a4 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -36,6 +36,28 @@ extern "C" { */ TINYTC_EXPORT char const *tinytc_error_string(tinytc_status_t status); +//////////////////////////// +////////// FP math ///////// +//////////////////////////// + +/** + * @brief Convert f32 number to f16 number + * + * @param x f32 number + * + * @return f16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_f16(float x); + +/** + * @brief Convert f16 number (represented as ushort) to f32 number + * + * @param x f16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_f16_to_f32(uint16_t x); + //////////////////////////// //////// Scalar type /////// //////////////////////////// diff --git a/include/tinytc/types.h b/include/tinytc/types.h index 248f64c5..9a639f06 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -267,12 +267,13 @@ typedef enum { tinytc_scalar_type_i32 = 2, ///< Signed 32 bit integer tinytc_scalar_type_i64 = 3, ///< Signed 64 bit integer tinytc_scalar_type_index = 4, ///< Integer type for indices - tinytc_scalar_type_f32 = 5, ///< Single precision floating point (32 bit) - tinytc_scalar_type_f64 = 6, ///< Double precision floating point (64 bit) - tinytc_scalar_type_c32 = 7, ///< Single precision complex (2x32 bit) - tinytc_scalar_type_c64 = 8 ///< Double precision complex (2x64 bit) + tinytc_scalar_type_f16 = 5, ///< Half precision floating point (16 bit) + tinytc_scalar_type_f32 = 6, ///< Single precision floating point (32 bit) + tinytc_scalar_type_f64 = 7, ///< Double precision floating point (64 bit) + tinytc_scalar_type_c32 = 8, ///< Single precision complex (2x32 bit) + tinytc_scalar_type_c64 = 9 ///< Double precision complex (2x64 bit) } tinytc_scalar_type_t; -#define TINYTC_NUMBER_OF_SCALAR_TYPES 9 // @todo Keep up to date with tinytc_scalar_type_t +#define TINYTC_NUMBER_OF_SCALAR_TYPES 10 // @todo Keep up to date with tinytc_scalar_type_t //! Arithmetic operations typedef enum { diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 9444fc50..76931c98 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -245,6 +245,7 @@ enum class scalar_type { i32 = tinytc_scalar_type_i32, ///< Signed 32 bit integer i64 = tinytc_scalar_type_i64, ///< Signed 64 bit integer index = tinytc_scalar_type_index, ///< Unsigned Integer type for indices + f16 = tinytc_scalar_type_f16, ///< Half precision floating point (16 bit) f32 = tinytc_scalar_type_f32, ///< Single precision floating point (32 bit) f64 = tinytc_scalar_type_f64, ///< Double precision floating point (64 bit) c32 = tinytc_scalar_type_c32, ///< Single precision complex (2x32 bit) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ce47bf6c..317c30c4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,6 +32,7 @@ set(SOURCES func.cpp gemm_tools.cpp inst.cpp + half.cpp location.cpp node/data_type_node.cpp node/inst_node.cpp diff --git a/src/half.cpp b/src/half.cpp new file mode 100644 index 00000000..b88f2a63 --- /dev/null +++ b/src/half.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include +#include + +template struct fp_info { + constexpr static uint32_t exponent_bits = ExponentBits; + constexpr static uint32_t mantissa_bits = MantissaBits; + constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; + constexpr static uint32_t bias = (1 << (exponent_bits - 1)) - 1; + constexpr static uint32_t max_biased_exponent = (1 << exponent_bits) - 1; + constexpr static uint32_t sign_mask = 1 << (num_bits - 1); + constexpr static uint32_t exponent_mask = max_biased_exponent << mantissa_bits; + constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; +}; + +using f16i = fp_info<5, 10>; +using f32i = fp_info<8, 23>; + +extern "C" { + +uint16_t tinytc_f32_to_f16(float x) { + const uint32_t y = std::bit_cast(x); + const uint16_t sign = (y & f32i::sign_mask) >> (f32i::num_bits - f16i::num_bits); + uint32_t exponent = (y & f32i::exponent_mask) >> f32i::mantissa_bits; + uint32_t mantissa = y & f32i::mantissa_mask; + + if (exponent > f32i::bias + f16i::bias) { + // Large numbers are mapped to inf + exponent = f16i::max_biased_exponent; + mantissa = 0; + } else if (exponent > f32i::bias - f16i::bias) { + // Normal numbers + + // convert bias + // E_{32} = e + f32i::bias + // E_{16} = e + f16i::bias + // = E_{32} - f32i::bias + f16i::bias + // = E_{32} - (f32i::bias - f16i::bias) + exponent -= f32i::bias - f16i::bias; + + constexpr uint32_t num_shift_bits = f32i::mantissa_bits - f16i::mantissa_bits; + constexpr uint32_t midpoint = (1 << num_shift_bits) / 2; + constexpr uint32_t low_bit_mask = (1 << num_shift_bits) - 1; + const uint32_t truncated_bits = mantissa & low_bit_mask; + + // shift mantissa and round correctly + mantissa >>= num_shift_bits; + if (truncated_bits > midpoint) { + mantissa += 1; + } else if (truncated_bits == midpoint) { + // when there is a tie round to nearest even + mantissa += mantissa & 1; + } + // We had an overflow during rounding + if ((mantissa & (1 << f16i::mantissa_bits)) != 0) { + ++exponent; + if (exponent > f16i::max_biased_exponent) { + // Overflow to infinity + exponent = f16i::max_biased_exponent; + mantissa = 0; + } else { + mantissa &= f16i::mantissa_mask; + } + } + } else { + // @todo + } + + exponent <<= f16i::exponent_bits; + + return sign | static_cast(exponent) | static_cast(mantissa); +} + +float tinytc_f16_to_f32(uint16_t x) { + const uint32_t sign = (x & f16i::sign_mask) << (f32i::num_bits - f16i::num_bits); + uint32_t exponent = (x & f16i::exponent_mask) >> f16i::mantissa_bits; + uint32_t mantissa = x & f16i::mantissa_mask; + + if (exponent == f16i::max_biased_exponent) { + // Inf and NaN + exponent = (f32i::max_biased_exponent << f32i::mantissa_bits); + } else if (exponent != 0) { + // convert bias + // E_{16} = e + f16i::bias + // E_{32} = e + f32i::bias + // = E_{16} - f16i::bias + f32i::bias + // = E_{16} + (f32i::bias - f16i::bias) + exponent += f32i::bias - f16i::bias; + + // shift exponent + exponent <<= f32i::mantissa_bits; + } + + // Subnormal f16 numbers must be represented as f32 normal numbers + if (exponent == 0 && mantissa != 0) { + uint8_t shift_count = 1; + do { + mantissa <<= 1; + ++shift_count; + } while ((mantissa & (1 << f16i::mantissa_bits)) != (1 << f16i::mantissa_bits)); + mantissa &= f16i::mantissa_mask; + exponent = (f32i::bias - f16i::bias + 1) - shift_count; + } + + // shift mantissa + mantissa <<= f32i::mantissa_bits - f16i::mantissa_bits; + + const uint32_t y = sign | exponent | mantissa; + return std::bit_cast(y); +} +} diff --git a/src/inst.cpp b/src/inst.cpp index 8e2b9202..6dd203e4 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -283,6 +283,7 @@ tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_dat *instr = std::make_unique(std::int64_t{1}, ty, get_optional(loc)).release(); break; + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: *instr = std::make_unique(double{1}, ty, get_optional(loc)).release(); @@ -327,6 +328,7 @@ tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_da *instr = std::make_unique(std::int64_t{0}, ty, get_optional(loc)).release(); break; + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: *instr = std::make_unique(double{0}, ty, get_optional(loc)).release(); diff --git a/src/parser/lexer.re b/src/parser/lexer.re index 38cb6091..ef9b7118 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -42,7 +42,7 @@ lex: global_identifier = "@" (unnamed_identifier | named_identifier); integer_type = "i" ("8" | "16" | "32" | "64") | "index"; - floating_type = ("f" | "c") ("32" | "64"); + floating_type = ("f" | "c") ("16" | "32" | "64"); digit = [0-9]; hexdigit = [0-9a-fA-F]; @@ -301,6 +301,7 @@ scalar_type lexer::lex_floating_type(char const *s, char const *) { re2c:yyfill:enable = 0; re2c:define:YYCURSOR = s; + "f16" { return scalar_type::f16; } "f32" { return scalar_type::f32; } "f64" { return scalar_type::f64; } "c32" { return scalar_type::c32; } diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index f189cf47..9fdfe92b 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -14,6 +14,7 @@ namespace tinytc { bool is_floating_type(scalar_type ty) { switch (ty) { + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return true; @@ -84,6 +85,8 @@ char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { return "i64"; case tinytc_scalar_type_index: return "index"; + case tinytc_scalar_type_f16: + return "f16"; case tinytc_scalar_type_f32: return "f32"; case tinytc_scalar_type_f64: @@ -100,6 +103,7 @@ size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { case tinytc_scalar_type_i8: return 1; case tinytc_scalar_type_i16: + case tinytc_scalar_type_f16: return 2; case tinytc_scalar_type_i32: case tinytc_scalar_type_f32: diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index b3676e91..177a976b 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -245,6 +245,7 @@ auto inst_converter::make_binary_op(scalar_type sty, arithmetic op, spv_inst *ty case scalar_type::i64: case scalar_type::index: return make_int(op, ty, a, b); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return make_float(op, ty, a, b); @@ -280,6 +281,7 @@ auto inst_converter::make_cast(scalar_type to_ty, scalar_type a_ty, spv_inst *sp case scalar_type::i64: case scalar_type::index: return mod_->add(spv_to_ty, a); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return mod_->add(spv_to_ty, a); @@ -302,6 +304,7 @@ auto inst_converter::make_cast(scalar_type to_ty, scalar_type a_ty, spv_inst *sp case scalar_type::i64: case scalar_type::index: return mod_->add(spv_to_ty, a); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return mod_->add(spv_to_ty, a); @@ -335,6 +338,7 @@ auto inst_converter::make_cast(scalar_type to_ty, scalar_type a_ty, spv_inst *sp case scalar_type::i64: case scalar_type::index: return cast_from_int(to_ty, spv_to_ty, a); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return cast_from_float(to_ty, spv_to_ty, a); @@ -558,6 +562,7 @@ void inst_converter::make_store(store_flag flag, scalar_type sty, address_space case scalar_type::index: mod_->add(result_ty, pointer, scope, semantics, value); break; + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: mod_->add(result_ty, pointer, scope, semantics, value); @@ -738,6 +743,7 @@ void inst_converter::operator()(arith_unary_inst const &in) { case scalar_type::i64: case scalar_type::index: return make_int(op, ty, a); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return make_float(op, ty, a); @@ -903,6 +909,7 @@ void inst_converter::operator()(compare_inst const &in) { case scalar_type::i64: case scalar_type::index: return compare_int(cond, spv_to_ty, a, b); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return compare_float(cond, spv_to_ty, a, b); @@ -997,6 +1004,8 @@ void inst_converter::operator()(cooperative_matrix_load_inst const &in) { return mod_->add(spv_ty, value); }; switch (ot->element_ty()) { + case scalar_type::f16: + return cast_load_cast(scalar_type::i16); case scalar_type::f32: return cast_load_cast(scalar_type::i32); case scalar_type::f64: @@ -1733,6 +1742,7 @@ void inst_converter::operator()(work_group_inst const &in) { case scalar_type::i64: case scalar_type::index: return mod_->add(spv_ty, scope, GroupOperation::Reduce, operand); + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: case scalar_type::c32: diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp index cc818b25..cf56258d 100644 --- a/src/spv/uniquifier.cpp +++ b/src/spv/uniquifier.cpp @@ -229,6 +229,7 @@ auto uniquifier::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { } return spv_ty(scalar_data_type::get(mod_->context(), scalar_type::i32)); } + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return mod_->add_to(section::type_const_var, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 05cd0e6f..cbe0afb6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,6 +18,11 @@ set_cxx_common_options(test-lib) #target_link_libraries(test-lexer PRIVATE test-lib) #doctest_discover_tests(test-lexer) +add_executable(test-math math.cpp) +target_link_libraries(test-math PRIVATE test-lib) +doctest_discover_tests(test-math) +set_cxx_common_options(test-math) + add_executable(test-generator generator.cpp) target_link_libraries(test-generator PRIVATE test-lib) doctest_discover_tests(test-generator) diff --git a/test/math.cpp b/test/math.cpp new file mode 100644 index 00000000..10beba1f --- /dev/null +++ b/test/math.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "tinytc/tinytc.h" + +#include + +#include + +TEST_CASE("fp32 -> fp16 -> fp32 round trip") { + CHECK(tinytc_f32_to_f16(1.0f) == 1.0f); + CHECK(tinytc_f16_to_f32(tinytc_f32_to_f16(1.0f)) == 1.0f); + CHECK(tinytc_f16_to_f32(tinytc_f32_to_f16(5.0f)) == 5.0f); +} diff --git a/test/spv/arith.ir b/test/spv/arith.ir index cf3878d2..5f696227 100644 --- a/test/spv/arith.ir +++ b/test/spv/arith.ir @@ -7,6 +7,7 @@ ; CHECK: %[[#BOOL:]] = OpTypeBool ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F16:]] = OpTypeFloat 16 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 func @tbool(%a: bool, %b: bool) { @@ -57,6 +58,20 @@ func @tfloat(%a: f32, %b: f32) { ; CHECK-NEXT: %[[#]] = OpFRem %[[#F32]] %[[#]] %[[#]] } +func @thalf(%a: f16, %b: f16) { + %0 = arith.add %a, %b : f16 + %1 = arith.sub %a, %b : f16 + %2 = arith.mul %a, %b : f16 + %3 = arith.div %a, %b : f16 + %4 = arith.rem %a, %b : f16 +; CHECK-LABEL: %[[#]] = OpFunction {{.*}} +; CHECK: %[[#]] = OpFAdd %[[#F16]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFSub %[[#F16]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFMul %[[#F16]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFDiv %[[#F16]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFRem %[[#F16]] %[[#]] %[[#]] +} + func @tcomplex(%a: c32, %b: c32) { %0 = arith.add %a, %b : c32 %1 = arith.sub %a, %b : c32 diff --git a/test/spv/arith_unary.ir b/test/spv/arith_unary.ir index 82a766fd..216deeaa 100644 --- a/test/spv/arith_unary.ir +++ b/test/spv/arith_unary.ir @@ -8,6 +8,7 @@ ; CHECK: %[[#BOOL:]] = OpTypeBool ; CHECK: %[[#I64:]] = OpTypeInt 64 0 ; CHECK: %[[#F32:]] = OpTypeFloat 32 +; CHECK: %[[#F16:]] = OpTypeFloat 16 ; CHECK: %[[#C32:]] = OpTypeVector %[[#F32]] 2 func @tbool(%a: bool) { @@ -31,6 +32,13 @@ func @tfloat(%a: f32) { ; CHECK-NEXT: OpFNegate %[[#F32]] %[[#]] } +func @thalf(%a: f16) { + %0 = arith.abs %a : f16 + %1 = arith.neg %a : f16 +; CHECK: OpExtInst %[[#F16]] %[[#EXT]] fabs %[[#]] +; CHECK-NEXT: OpFNegate %[[#F16]] %[[#]] +} + func @tcomplex(%a: c32) { %0 = arith.abs %a : f32 %1 = arith.neg %a : c32 diff --git a/test/spv/cast.ir b/test/spv/cast.ir index 984014b8..454eea0d 100644 --- a/test/spv/cast.ir +++ b/test/spv/cast.ir @@ -11,6 +11,7 @@ ; CHECK: %[[#F64:]] = OpTypeFloat 64 ; CHECK: %[[#C64:]] = OpTypeVector %[[#F64]] 2 ; CHECK: %[[#C64_NULL:]] = OpConstantNull %[[#C64]] +; CHECK: %[[#F16:]] = OpTypeFloat 16 func @tint(%a: i64) { %0 = cast %a : i64 @@ -34,6 +35,16 @@ func @tfloat(%a: f32) { ; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F32_TO_F64]] %[[#C64_NULL]] 0 } +func @thalf(%a: f16) { + %1 = cast %a : i8 + %2 = cast %a : f64 + %3 = cast %a : c64 +; CHECK: %[[#]] = OpConvertFToS %[[#I8]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFConvert %[[#F64]] %[[#]] +; CHECK-NEXT: %[[#F16_TO_F64:]] = OpFConvert %[[#F64]] %[[#]] +; CHECK-NEXT: %[[#]] = OpCompositeInsert %[[#C64]] %[[#F16_TO_F64]] %[[#C64_NULL]] 0 +} + func @tcomplex(%a: c32) { %1 = cast %a : c64 ; CHECK: %[[#]] = OpFConvert %[[#C64]] %[[#]] diff --git a/test/spv/compare.ir b/test/spv/compare.ir index 2447cbc4..ef88597b 100644 --- a/test/spv/compare.ir +++ b/test/spv/compare.ir @@ -36,6 +36,21 @@ func @tfloat(%a: f32, %b: f32) { ; CHECK-NEXT: %[[#]] = OpFOrdLessThanEqual %[[#BOOL]] %[[#]] %[[#]] } +func @thalf(%a: f16, %b: f16) { + %0 = cmp.eq %a, %b : bool + %1 = cmp.ne %a, %b : bool + %2 = cmp.gt %a, %b : bool + %3 = cmp.ge %a, %b : bool + %4 = cmp.lt %a, %b : bool + %5 = cmp.le %a, %b : bool +; CHECK: %[[#]] = OpFOrdEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFUnordNotEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdGreaterThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdGreaterThanEqual %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdLessThan %[[#BOOL]] %[[#]] %[[#]] +; CHECK-NEXT: %[[#]] = OpFOrdLessThanEqual %[[#BOOL]] %[[#]] %[[#]] +} + func @tcomplex(%a: c32, %b: c32) { %0 = cmp.eq %a, %b : bool %1 = cmp.ne %a, %b : bool From 5b3216f689f3ca6cfeb772236fe34b0a1185447d Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 25 Nov 2024 16:32:07 +0100 Subject: [PATCH 132/297] Fp16 ctd. Signed-off-by: Carsten Uphoff --- docs/api/core_capi.rst | 36 ++++++ docs/api/core_capi.yaml | 6 + docs/api/core_cxxapi.rst | 34 ++++++ docs/api/core_cxxapi.yaml | 6 + examples/benchmark/main.cpp | 8 +- examples/gemm_common.hpp | 4 +- examples/tall_and_skinny/main.cpp | 7 +- include/tinytc/tinytc.h | 24 +++- include/tinytc/tinytc.hpp | 126 +++++++++++++++++++++ src/CMakeLists.txt | 2 +- src/half.cpp | 181 +++++++++++++++++------------- src/pass/constant_folding.cpp | 4 + src/pass/constant_folding.hpp | 17 ++- src/recipe.cpp | 2 + src/spv/capex_util.hpp | 6 +- src/spv/converter.cpp | 2 + src/spv/defs.hpp | 9 +- src/spv/enums.hpp | 6 +- src/spv/instructions.hpp | 6 +- src/spv/names.hpp | 6 +- src/spv/pass/dump_asm.cpp | 5 + src/spv/visit.hpp | 6 +- test/CMakeLists.txt | 1 + test/math.cpp | 115 ++++++++++++++++++- test/opt/constant-propagation.ir | 24 ++++ test/spv/constant.ir | 35 ++++++ tools/spirvgen/spirvgen.py | 4 +- 27 files changed, 564 insertions(+), 118 deletions(-) create mode 100644 test/spv/constant.ir diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 9afb93f1..d9401216 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -481,6 +481,42 @@ tinytc_core_feature_flags_t .. doxygentypedef:: tinytc_core_feature_flags_t +FP math +======= + +* Functions + + * :ref:`tinytc_f32_to_bf16_as_ui16` + + * :ref:`tinytc_f32_to_f16_as_ui16` + + * :ref:`tinytc_f16_as_ui16_to_f32` + + * :ref:`tinytc_bf16_as_ui16_to_f32` + +FP math Functions +----------------- + +tinytc_f32_to_bf16_as_ui16 +.......................... + +.. doxygenfunction:: tinytc_f32_to_bf16_as_ui16 + +tinytc_f32_to_f16_as_ui16 +......................... + +.. doxygenfunction:: tinytc_f32_to_f16_as_ui16 + +tinytc_f16_as_ui16_to_f32 +......................... + +.. doxygenfunction:: tinytc_f16_as_ui16_to_f32 + +tinytc_bf16_as_ui16_to_f32 +.......................... + +.. doxygenfunction:: tinytc_bf16_as_ui16_to_f32 + Parser ====== diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 1983792e..323bf56c 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -75,6 +75,12 @@ Core C-API: - tinytc_core_info_retain typedef: - tinytc_core_feature_flags_t + FP math: + function: + - tinytc_f32_to_bf16_as_ui16 + - tinytc_f32_to_f16_as_ui16 + - tinytc_f16_as_ui16_to_f32 + - tinytc_bf16_as_ui16_to_f32 Parser: function: - tinytc_parse_file diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 6c1128a4..94dfc555 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -288,6 +288,40 @@ core_info .. doxygenclass:: tinytc::core_info +FP math +======= + +* Classes + + * :ref:`lp_float` + +* Typedefs + + * :ref:`bfloat16` + + * :ref:`half` + +FP math Classes +--------------- + +lp_float +........ + +.. doxygenclass:: tinytc::lp_float + +FP math Typedefs +---------------- + +bfloat16 +........ + +.. doxygentypedef:: tinytc::bfloat16 + +half +.... + +.. doxygentypedef:: tinytc::half + Parser ====== diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index e49bae81..5453897f 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -48,6 +48,12 @@ Core C++-API: - tinytc::make_core_info_intel_from_name class: - tinytc::core_info + FP math: + class: + - tinytc::lp_float + typedef: + - tinytc::bfloat16 + - tinytc::half Parser: function: - tinytc::parse_file diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 6f01a236..985b8c0c 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -216,10 +216,9 @@ template void test(queue q, args &a) { } double min_exec_time_ns = 0.0; - constexpr auto element_ty = to_scalar_type_v; try { auto src = gemm_kernel_with_inner_repetition( - element_ty, a.trans_a ? transpose::T : transpose::N, + a.ty, a.trans_a ? transpose::T : transpose::N, a.trans_b ? transpose::T : transpose::N, a.atomic, c.m, c.n, c.k, {1, a.trans_a ? c.k : c.m}, {1, a.trans_b ? c.n : c.k}, a.update, {1, c.m}, a.internal_repetitions, a.dump, q); @@ -257,7 +256,7 @@ template void test(queue q, args &a) { std::min(512 * 32 * 1.6e9, a.internal_repetitions * 2 * c.m * c.n * c.k / (sizeof(T) * (na + nb + nc) / 1.1e12)) / 1e9; - std::cout << to_string(element_ty) << "," << c.m << "," << c.n << "," << c.k << "," + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," << howmany << "," << min_exec_time_ns / 1e9 << "," << gflops << "," << roofline_gflops << "," << std::round(gflops / roofline_gflops * 100) << "%," << a.internal_repetitions << std::endl; @@ -325,6 +324,9 @@ int main(int argc, char **argv) { << std::endl; try { switch (a.ty) { + case scalar_type::f16: + test(std::move(q), a); + break; case scalar_type::f32: test(std::move(q), a); break; diff --git a/examples/gemm_common.hpp b/examples/gemm_common.hpp index b0efcdaf..2050ce55 100644 --- a/examples/gemm_common.hpp +++ b/examples/gemm_common.hpp @@ -19,7 +19,9 @@ struct test_case { }; inline auto convert_data_type(char const *str, scalar_type &val) -> cmd::parser_status { - if (std::strcmp(str, "f32") == 0) { + if (std::strcmp(str, "f16") == 0) { + val = scalar_type::f16; + } else if (std::strcmp(str, "f32") == 0) { val = scalar_type::f32; } else if (std::strcmp(str, "f64") == 0) { val = scalar_type::f64; diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 96d79943..757fc6ea 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -128,7 +128,9 @@ template void test(queue q, args &a) { } auto tas = make_recipe_handler(q, r); - tall_and_skinny::set_args(tas, c.m, T{1}, A, c.m, B, c.k, beta, C, c.m); + tall_and_skinny::set_args(tas, c.m, T{1}, mem(A, mem_type::usm_pointer), c.m, + mem(B, mem_type::usm_pointer), c.k, beta, + mem(C, mem_type::usm_pointer), c.m); tas.submit(q).wait(); if (a.verify) { check(c.m, c.n); @@ -203,6 +205,9 @@ int main(int argc, char **argv) { std::cout << "precision,m,n,k,update,time,bandwidth,gflops" << std::endl; try { switch (a.ty) { + case scalar_type::f16: + test(std::move(q), a); + break; case scalar_type::f32: test(std::move(q), a); break; diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index 2388d1a4..bc2e6959 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -41,13 +41,31 @@ TINYTC_EXPORT char const *tinytc_error_string(tinytc_status_t status); //////////////////////////// /** - * @brief Convert f32 number to f16 number + * @brief Convert f32 number to bf16 number (represented as ushort) + * + * @param x f32 number + * + * @return bf16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_bf16_as_ui16(float x); + +/** + * @brief Convert bf16 number (represented as ushort) to f32 number + * + * @param x bf16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_bf16_as_ui16_to_f32(uint16_t x); + +/** + * @brief Convert f32 number to f16 number (represented as ushort) * * @param x f32 number * * @return f16 number */ -TINYTC_EXPORT uint16_t tinytc_f32_to_f16(float x); +TINYTC_EXPORT uint16_t tinytc_f32_to_f16_as_ui16(float x); /** * @brief Convert f16 number (represented as ushort) to f32 number @@ -56,7 +74,7 @@ TINYTC_EXPORT uint16_t tinytc_f32_to_f16(float x); * * @return f32 number */ -TINYTC_EXPORT float tinytc_f16_to_f32(uint16_t x); +TINYTC_EXPORT float tinytc_f16_as_ui16_to_f32(uint16_t x); //////////////////////////// //////// Scalar type /////// diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 92ac5b95..a5e6b235 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -60,6 +60,119 @@ inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { } } +//////////////////////////// +////////// FP math ///////// +//////////////////////////// + +/** + * @brief Low precision float type + * + * For all operations, low precision floats are converted single precision, the operation is done in + * single precision, and then the result is stored in the low precision type + * + * @tparam T storage type + * @tparam (*Extend)(T) Float widening function + * @tparam (*Truncate)(float) Float narrowing function + */ +template class lp_float { + public: + lp_float() = default; + + constexpr lp_float(lp_float const &) = default; + constexpr lp_float(lp_float &&) = default; + constexpr lp_float &operator=(lp_float const &) = default; + constexpr lp_float &operator=(lp_float &&) = default; + + //! Construct from float + lp_float(float const &rhs) : data_{Truncate(rhs)} {} + //! assign float + auto operator=(float const &rhs) -> lp_float & { + data_ = Truncate(rhs); + return *this; + } + //! implicit conversion to float + operator float() const { return Extend(data_); } + + //! add + auto operator+(lp_float const &rhs) const -> lp_float { + return operator float() + static_cast(rhs); + } + //! add to + auto operator+=(lp_float const &rhs) -> lp_float & { return *this = *this + rhs; } + //! subtract + auto operator-(lp_float const &rhs) const -> lp_float { + return operator float() - static_cast(rhs); + } + //! subtract from + auto operator-=(lp_float const &rhs) -> lp_float & { return *this = *this - rhs; } + //! multiply + auto operator*(lp_float const &rhs) const -> lp_float { + return operator float() * static_cast(rhs); + } + //! multiply with + auto operator*=(lp_float const &rhs) -> lp_float & { return *this = *this * rhs; } + //! divide + auto operator/(lp_float const &rhs) const -> lp_float { + return operator float() / static_cast(rhs); + } + //! divide with + auto operator/=(lp_float const &rhs) -> lp_float & { return *this = *this / rhs; } + //! unary minus + auto operator-() -> lp_float { return -operator float(); } + //! pre-increase by 1 + auto operator++() -> lp_float & { return *this = operator float() + 1.0f; } + //! post-increase by 1 + auto operator++(int) -> lp_float { + lp_float tmp = *this; + operator++(); + return tmp; + } + //! pre-decrease by 1 + auto operator--() -> lp_float & { return *this = operator float() - 1.0f; } + //! post-decrease by 1 + auto operator--(int) -> lp_float { + lp_float tmp = *this; + operator--(); + return tmp; + } + //! equal + auto operator==(lp_float const &rhs) const -> bool { + return operator float() == static_cast(rhs); + } + //! not equal + auto operator!=(lp_float const &rhs) const -> bool { + return operator float() == static_cast(rhs); + } + //! greater than + auto operator>(lp_float const &rhs) const -> bool { + return operator float() > static_cast(rhs); + } + //! greater than or equal + auto operator>=(lp_float const &rhs) const -> bool { + return operator float() >= static_cast(rhs); + } + //! less than + auto operator<(lp_float const &rhs) const -> bool { + return operator float() < static_cast(rhs); + } + //! less than or equal + auto operator<=(lp_float const &rhs) const -> bool { + return operator float() <= static_cast(rhs); + } + + private: + T data_; +}; + +/** + * @brief fp16 host emulation type + */ +using half = lp_float; +/** + * @brief bf16 host emulation type + */ +using bfloat16 = lp_float; + //////////////////////////// //////// Scalar type /////// //////////////////////////// @@ -97,6 +210,10 @@ template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::i64; ///< value }; //! to_scalar_type specialization +template <> struct to_scalar_type { + static constexpr scalar_type value = scalar_type::f16; ///< value +}; +//! to_scalar_type specialization template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::f32; ///< value }; @@ -2583,4 +2700,13 @@ inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type } // namespace tinytc +namespace std { +template <> struct hash { + size_t operator()(tinytc::half const &val) const noexcept { + return hash{}(static_cast(val)); + } +}; + +} // namespace std + #endif // TINYTC_20240403_HPP diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 317c30c4..c141c46e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,8 +31,8 @@ set(SOURCES error.cpp func.cpp gemm_tools.cpp - inst.cpp half.cpp + inst.cpp location.cpp node/data_type_node.cpp node/inst_node.cpp diff --git a/src/half.cpp b/src/half.cpp index b88f2a63..9dd906e1 100644 --- a/src/half.cpp +++ b/src/half.cpp @@ -4,7 +4,9 @@ #include #include -template struct fp_info { +namespace tinytc { + +template struct ieee754_info { constexpr static uint32_t exponent_bits = ExponentBits; constexpr static uint32_t mantissa_bits = MantissaBits; constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; @@ -15,99 +17,118 @@ template struct fp_info { constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; }; -using f16i = fp_info<5, 10>; -using f32i = fp_info<8, 23>; - -extern "C" { +using bf16i = ieee754_info<8, 7>; +using f16i = ieee754_info<5, 10>; +using f32i = ieee754_info<8, 23>; + +template +auto ieee754_truncate(UI x) -> UITrunc { + constexpr UI num_shift_bits = F32i::mantissa_bits - F16i::mantissa_bits; + auto const round_nearest_even_and_truncate = [](UI mantissa32) { + constexpr UI midpoint = (1 << num_shift_bits) / 2; + const UI bias = ((mantissa32 >> num_shift_bits) & 0x1) + (midpoint - 1); + return (mantissa32 + bias) >> num_shift_bits; + }; + + const UITrunc sign = (x & F32i::sign_mask) >> (F32i::num_bits - F16i::num_bits); + const UI exponent32 = (x & F32i::exponent_mask) >> F32i::mantissa_bits; + const UI mantissa32 = x & F32i::mantissa_mask; + + UITrunc exponent16 = 0; + UITrunc mantissa16 = 0; + if (exponent32 > F32i::bias + F16i::bias) { + exponent16 = F16i::max_biased_exponent; + // Map numbers except NaN to inf + if (exponent32 < F32i::max_biased_exponent) { + mantissa16 = 0; + } else { + // Need to ceil to make sure that NaN is not truncated to inf + mantissa16 = 1 + ((mantissa32 - 1) >> num_shift_bits); + } + } else if (F32i::bias == F16i::bias || exponent32 > F32i::bias - F16i::bias) { + // convert bias + // E_{32} = e + F32i::bias + // E_{16} = e + F16i::bias + // = E_{32} - F32i::bias + F16i::bias + // = E_{32} - (F32i::bias - F16i::bias) + exponent16 = exponent32 - (F32i::bias - F16i::bias); + mantissa16 = round_nearest_even_and_truncate(mantissa32); + } else if (exponent32 >= F32i::bias + 1 - F16i::bias - F16i::mantissa_bits) { + exponent16 = 0; + mantissa16 = round_nearest_even_and_truncate((mantissa32 | (1 << F32i::mantissa_bits)) >> + ((F32i::bias + 1 - F16i::bias) - exponent32)); + } -uint16_t tinytc_f32_to_f16(float x) { - const uint32_t y = std::bit_cast(x); - const uint16_t sign = (y & f32i::sign_mask) >> (f32i::num_bits - f16i::num_bits); - uint32_t exponent = (y & f32i::exponent_mask) >> f32i::mantissa_bits; - uint32_t mantissa = y & f32i::mantissa_mask; + exponent16 <<= F16i::mantissa_bits; - if (exponent > f32i::bias + f16i::bias) { - // Large numbers are mapped to inf - exponent = f16i::max_biased_exponent; - mantissa = 0; - } else if (exponent > f32i::bias - f16i::bias) { - // Normal numbers + // Need to add mantissa as it might overflow during rounding and then we need to increase the + // exponent by 1 + return (sign | exponent16) + mantissa16; +} - // convert bias - // E_{32} = e + f32i::bias - // E_{16} = e + f16i::bias - // = E_{32} - f32i::bias + f16i::bias - // = E_{32} - (f32i::bias - f16i::bias) - exponent -= f32i::bias - f16i::bias; - - constexpr uint32_t num_shift_bits = f32i::mantissa_bits - f16i::mantissa_bits; - constexpr uint32_t midpoint = (1 << num_shift_bits) / 2; - constexpr uint32_t low_bit_mask = (1 << num_shift_bits) - 1; - const uint32_t truncated_bits = mantissa & low_bit_mask; - - // shift mantissa and round correctly - mantissa >>= num_shift_bits; - if (truncated_bits > midpoint) { - mantissa += 1; - } else if (truncated_bits == midpoint) { - // when there is a tie round to nearest even - mantissa += mantissa & 1; +template +auto ieee754_extend(UI x) -> UIExt { + const UIExt sign = (x & F16i::sign_mask) << (F32i::num_bits - F16i::num_bits); + const UIExt exponent16 = (x & F16i::exponent_mask) >> F16i::mantissa_bits; + const UIExt mantissa16 = x & F16i::mantissa_mask; + + UIExt exponent32 = exponent16; + UIExt mantissa32 = mantissa16; + if (F32i::exponent_bits != F16i::exponent_bits) { + if (exponent16 == F16i::max_biased_exponent) { + // Inf and NaN + exponent32 = F32i::max_biased_exponent; + } else if (exponent16 != 0) { + // convert bias + // E_{16} = e + F16i::bias + // E_{32} = e + F32i::bias + // = E_{16} - F16i::bias + F32i::bias + // = E_{16} + (F32i::bias - F16i::bias) + exponent32 += F32i::bias - F16i::bias; } - // We had an overflow during rounding - if ((mantissa & (1 << f16i::mantissa_bits)) != 0) { - ++exponent; - if (exponent > f16i::max_biased_exponent) { - // Overflow to infinity - exponent = f16i::max_biased_exponent; - mantissa = 0; - } else { - mantissa &= f16i::mantissa_mask; - } + + // Subnormal f16 numbers must be represented as f32 normal numbers + if (exponent16 == 0 && mantissa16 != 0) { + UIExt shift_count = 0; + do { + mantissa32 <<= 1; + ++shift_count; + } while ((mantissa32 & (1 << F16i::mantissa_bits)) != (1 << F16i::mantissa_bits)); + mantissa32 &= F16i::mantissa_mask; + exponent32 = F32i::bias + 1 - F16i::bias - shift_count; } - } else { - // @todo } - exponent <<= f16i::exponent_bits; + // shift mantissa + mantissa32 <<= F32i::mantissa_bits - F16i::mantissa_bits; + + // shift exponent + exponent32 <<= F32i::mantissa_bits; - return sign | static_cast(exponent) | static_cast(mantissa); + return sign | exponent32 | mantissa32; } -float tinytc_f16_to_f32(uint16_t x) { - const uint32_t sign = (x & f16i::sign_mask) << (f32i::num_bits - f16i::num_bits); - uint32_t exponent = (x & f16i::exponent_mask) >> f16i::mantissa_bits; - uint32_t mantissa = x & f16i::mantissa_mask; +} // namespace tinytc - if (exponent == f16i::max_biased_exponent) { - // Inf and NaN - exponent = (f32i::max_biased_exponent << f32i::mantissa_bits); - } else if (exponent != 0) { - // convert bias - // E_{16} = e + f16i::bias - // E_{32} = e + f32i::bias - // = E_{16} - f16i::bias + f32i::bias - // = E_{16} + (f32i::bias - f16i::bias) - exponent += f32i::bias - f16i::bias; - - // shift exponent - exponent <<= f32i::mantissa_bits; - } +using namespace tinytc; - // Subnormal f16 numbers must be represented as f32 normal numbers - if (exponent == 0 && mantissa != 0) { - uint8_t shift_count = 1; - do { - mantissa <<= 1; - ++shift_count; - } while ((mantissa & (1 << f16i::mantissa_bits)) != (1 << f16i::mantissa_bits)); - mantissa &= f16i::mantissa_mask; - exponent = (f32i::bias - f16i::bias + 1) - shift_count; - } +extern "C" { - // shift mantissa - mantissa <<= f32i::mantissa_bits - f16i::mantissa_bits; +uint16_t tinytc_f32_to_f16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} + +float tinytc_f16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); + return std::bit_cast(y); +} + +uint16_t tinytc_f32_to_bf16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} - const uint32_t y = sign | exponent | mantissa; +float tinytc_bf16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); return std::bit_cast(y); } } diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp index 1c453c12..575cdcd9 100644 --- a/src/pass/constant_folding.cpp +++ b/src/pass/constant_folding.cpp @@ -49,6 +49,8 @@ template class unary_op_dispatcher { } auto operator()(double const &A) -> fold_result { switch (switch_ty) { + case scalar_type::f16: + return computer.template operator()(A); case scalar_type::f32: return computer.template operator()(A); case scalar_type::f64: @@ -99,6 +101,8 @@ template class binary_op_dispatcher { } auto operator()(double const &A, double const &B) -> fold_result { switch (switch_ty) { + case scalar_type::f16: + return computer.template operator()(A, B); case scalar_type::f32: return computer.template operator()(A, B); case scalar_type::f64: diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp index 69d303dd..cddeeeb0 100644 --- a/src/pass/constant_folding.hpp +++ b/src/pass/constant_folding.hpp @@ -48,6 +48,13 @@ requires(std::is_floating_point_v) struct is_complex> : public std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; +template struct is_floating_point_or_lp_float : public std::false_type {}; +template +requires(std::is_floating_point_v || std::is_same_v || std::is_same_v) +struct is_floating_point_or_lp_float : public std::true_type {}; +template +inline constexpr bool is_floating_point_or_lp_float_v = is_floating_point_or_lp_float::value; + struct compute_unary_op { arithmetic_unary operation; data_type ty; @@ -86,7 +93,7 @@ struct compute_unary_op { } template - requires(std::is_floating_point_v) + requires(is_floating_point_or_lp_float_v) auto operator()(T a) -> fold_result { T val = 0; switch (operation) { @@ -237,14 +244,14 @@ struct compute_binary_op { val = a / b; break; case arithmetic::rem: - if constexpr (!std::is_floating_point_v) { + if constexpr (is_complex_v) { throw compilation_error(loc, status::ir_complex_unsupported); } else { val = std::fmod(a, b); } break; default: - if constexpr (!std::is_floating_point_v) { + if constexpr (is_complex_v) { throw compilation_error(loc, status::ir_complex_unsupported); } throw compilation_error(loc, status::ir_fp_unsupported); @@ -378,7 +385,7 @@ struct compute_compare { location const &loc; template - requires(std::is_integral_v || std::is_floating_point_v) + requires(std::is_integral_v || is_floating_point_or_lp_float_v) auto operator()(T a, T b) -> fold_result { bool val = false; switch (cond) { @@ -463,6 +470,8 @@ auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> fold_res return make_constant(value_cast(A), to_ty, loc); case scalar_type::index: return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f16: + return make_constant(value_cast(A), to_ty, loc); case scalar_type::f32: return make_constant(value_cast(A), to_ty, loc); case scalar_type::f64: diff --git a/src/recipe.cpp b/src/recipe.cpp index 0727ab3d..2dc3c442 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -29,6 +29,8 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va return is_argument_zero(arg_size, arg_value); case scalar_type::i64: return is_argument_zero(arg_size, arg_value); + case scalar_type::f16: + return is_argument_zero(arg_size, arg_value); case scalar_type::f32: return is_argument_zero(arg_size, arg_value); case scalar_type::f64: diff --git a/src/spv/capex_util.hpp b/src/spv/capex_util.hpp index 9396f285..1aeb4a96 100644 --- a/src/spv/capex_util.hpp +++ b/src/spv/capex_util.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_CAPEX_UTIL_20241113_HPP -#define GENERATED_CAPEX_UTIL_20241113_HPP +#ifndef GENERATED_CAPEX_UTIL_20241125_HPP +#define GENERATED_CAPEX_UTIL_20241125_HPP #include "enums.hpp" #include "tinytc/tinytc.hpp" @@ -23,4 +23,4 @@ auto extensions(ExecutionMode op) -> array_view; } // namespace tinytc::spv -#endif // GENERATED_CAPEX_UTIL_20241113_HPP +#endif // GENERATED_CAPEX_UTIL_20241125_HPP diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp index 177a976b..8e120a47 100644 --- a/src/spv/converter.cpp +++ b/src/spv/converter.cpp @@ -412,6 +412,8 @@ auto inst_converter::make_constant(scalar_type sty, spv_inst *spv_ty, }, [&](double d) -> spv_inst * { switch (sty) { + case scalar_type::f16: + return unique_.constant(half{static_cast(d)}); case scalar_type::f32: return unique_.constant(static_cast(d)); case scalar_type::f64: diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp index bc6f8f5d..49ec8965 100644 --- a/src/spv/defs.hpp +++ b/src/spv/defs.hpp @@ -4,11 +4,12 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_DEFS_20241113_HPP -#define GENERATED_DEFS_20241113_HPP +#ifndef GENERATED_DEFS_20241125_HPP +#define GENERATED_DEFS_20241125_HPP #include "enums.hpp" #include "support/ilist_base.hpp" +#include "tinytc/tinytc.hpp" #include #include @@ -46,7 +47,7 @@ class spv_inst : public ilist_node { using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber = - std::variant; + std::variant; using LiteralString = std::string; using LiteralInteger = std::int32_t; using LiteralExtInstInteger = std::int32_t; @@ -62,4 +63,4 @@ using PairIdRefLiteralInteger = std::pair; } // namespace tinytc::spv -#endif // GENERATED_DEFS_20241113_HPP +#endif // GENERATED_DEFS_20241125_HPP diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp index 90ca6a81..00138228 100644 --- a/src/spv/enums.hpp +++ b/src/spv/enums.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_ENUMS_20241113_HPP -#define GENERATED_ENUMS_20241113_HPP +#ifndef GENERATED_ENUMS_20241125_HPP +#define GENERATED_ENUMS_20241125_HPP #include @@ -1431,4 +1431,4 @@ enum class FPEncoding {}; } // namespace tinytc::spv -#endif // GENERATED_ENUMS_20241113_HPP +#endif // GENERATED_ENUMS_20241125_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp index 59276d0e..3879d9eb 100644 --- a/src/spv/instructions.hpp +++ b/src/spv/instructions.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_INSTRUCTIONS_20241113_HPP -#define GENERATED_INSTRUCTIONS_20241113_HPP +#ifndef GENERATED_INSTRUCTIONS_20241125_HPP +#define GENERATED_INSTRUCTIONS_20241125_HPP #include "defs.hpp" #include "enums.hpp" @@ -6771,4 +6771,4 @@ class OpAtomicFAddEXT : public spv_inst { } // namespace tinytc::spv -#endif // GENERATED_INSTRUCTIONS_20241113_HPP +#endif // GENERATED_INSTRUCTIONS_20241125_HPP diff --git a/src/spv/names.hpp b/src/spv/names.hpp index c84df1e8..c8759ad2 100644 --- a/src/spv/names.hpp +++ b/src/spv/names.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_NAMES_20241113_HPP -#define GENERATED_NAMES_20241113_HPP +#ifndef GENERATED_NAMES_20241125_HPP +#define GENERATED_NAMES_20241125_HPP #include "enums.hpp" @@ -68,4 +68,4 @@ auto to_string(FPEncoding e) -> char const *; } // namespace tinytc::spv -#endif // GENERATED_NAMES_20241113_HPP +#endif // GENERATED_NAMES_20241125_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp index 70d0e0be..393e4625 100644 --- a/src/spv/pass/dump_asm.cpp +++ b/src/spv/pass/dump_asm.cpp @@ -78,6 +78,11 @@ void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { auto flags = os_->flags(); *os_ << " " << std::hexfloat << l; os_->flags(flags); + }, + [&](half const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); }}, l); } diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp index 4a1268b4..311d57b4 100644 --- a/src/spv/visit.hpp +++ b/src/spv/visit.hpp @@ -4,8 +4,8 @@ // This file is generated // Do not edit manually -#ifndef GENERATED_VISIT_20241113_HPP -#define GENERATED_VISIT_20241113_HPP +#ifndef GENERATED_VISIT_20241125_HPP +#define GENERATED_VISIT_20241125_HPP #include "defs.hpp" #include "enums.hpp" @@ -4437,4 +4437,4 @@ template class default_visitor { } // namespace tinytc::spv -#endif // GENERATED_VISIT_20241113_HPP +#endif // GENERATED_VISIT_20241125_HPP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cbe0afb6..0916a550 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -68,6 +68,7 @@ if(SPIRVTools_FOUND) spv/cooperative_matrix_mul_add.ir spv/cooperative_matrix_scale.ir spv/cooperative_matrix_store.ir + spv/constant.ir spv/expand.ir spv/for.ir spv/fuse.ir diff --git a/test/math.cpp b/test/math.cpp index 10beba1f..ee340268 100644 --- a/test/math.cpp +++ b/test/math.cpp @@ -5,10 +5,117 @@ #include +#include #include -TEST_CASE("fp32 -> fp16 -> fp32 round trip") { - CHECK(tinytc_f32_to_f16(1.0f) == 1.0f); - CHECK(tinytc_f16_to_f32(tinytc_f32_to_f16(1.0f)) == 1.0f); - CHECK(tinytc_f16_to_f32(tinytc_f32_to_f16(5.0f)) == 5.0f); +TEST_CASE("f16 -> f32") { + // Regular numbers + CHECK(tinytc_f16_as_ui16_to_f32(0x0000) == 0.0f); + CHECK(tinytc_f16_as_ui16_to_f32(0x3c00) == std::bit_cast(0x3f800000)); // 1.0f + CHECK(tinytc_f16_as_ui16_to_f32(0x5148) == std::bit_cast(0x42290000)); // 42.25f + CHECK(tinytc_f16_as_ui16_to_f32(0xd148) == std::bit_cast(0xc2290000)); // -42.25f + CHECK(tinytc_f16_as_ui16_to_f32(0xfbff) == std::bit_cast(0xc77fe000)); // -65504.0f + + // Subnormals + CHECK(tinytc_f16_as_ui16_to_f32(0x0001) == std::bit_cast(0x33800000)); // 2^-24 + CHECK(tinytc_f16_as_ui16_to_f32(0x03ff) == + std::bit_cast(0x387fc000)); // 1.111111111 * 2^-15 + CHECK(tinytc_f16_as_ui16_to_f32(0x0021) == + std::bit_cast(0x36040000)); // 1.966953277587890625e-6f); + + // Inf and NaN + CHECK(tinytc_f16_as_ui16_to_f32(0x7c00) == std::bit_cast(0x7f800000)); // inf + CHECK(tinytc_f16_as_ui16_to_f32(0xfc00) == std::bit_cast(0xff800000)); // -inf + CHECK(std::bit_cast(tinytc_f16_as_ui16_to_f32(0x7c01)) == 0x7f802000); // nan + CHECK(std::bit_cast(tinytc_f16_as_ui16_to_f32(0xfc01)) == 0xff802000); // -nan +} + +TEST_CASE("f32 -> f16") { + // Lossless conversion + CHECK(tinytc_f32_to_f16_as_ui16(0.0f) == 0x0000); + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x3f800000)) == 0x3c00); // 1.0f + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x42290000)) == 0x5148); // 42.25f + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0xc2290000)) == 0xd148); // -42.25f + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0xc77fe000)) == 0xfbff); // -65504.0f + + // Big number -> inf + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x7c010840)) == 0x7c00); // inf + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0xfc010840)) == 0xfc00); // -inf + + // Round to nearest even + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa0000)) == 0x4fd0); // round down + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa1fff)) == 0x4fd1); // round up + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa0fff)) == 0x4fd0); // round down + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa1001)) == 0x4fd1); // round up + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa1000)) == 0x4fd0); // tie + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x41fa3000)) == 0x4fd2); // tie + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x46ffffff)) == + 0x7800); // 32767.998 -> 2^15 + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x477fffff)) == + 0x7c00); // 65535.996 -> inf + + // Subnormals + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x33800000)) == 0x0001); // 2^-24 + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x387fc000)) == + 0x03ff); // 1.111111111 * 2^-15 + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x36040000)) == + 0x0021); // 1.966953277587890625e-6f); + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x3607ffff)) == + 0x0022); // 1.966953277587890625e-6f); + + // Inf and NaN + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x7f800000)) == 0x7c00); // inf + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0xff800000)) == 0xfc00); // -inf + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0x7f802000)) == 0x7c01); // nan + CHECK(tinytc_f32_to_f16_as_ui16(std::bit_cast(0xff802000)) == 0xfc01); // -nan +} + +TEST_CASE("bf16 -> f32") { + // Regular numbers + CHECK(tinytc_bf16_as_ui16_to_f32(0x0000) == 0.0f); + CHECK(tinytc_bf16_as_ui16_to_f32(0x3f80) == std::bit_cast(0x3f800000)); // 1.0f + CHECK(tinytc_bf16_as_ui16_to_f32(0x4229) == std::bit_cast(0x42290000)); // 42.25f + CHECK(tinytc_bf16_as_ui16_to_f32(0xc229) == std::bit_cast(0xc2290000)); // -42.25f + CHECK(tinytc_bf16_as_ui16_to_f32(0xc77f) == std::bit_cast(0xc77f0000)); // -65280.0f + + // Subnormals + CHECK(tinytc_bf16_as_ui16_to_f32(0x0001) == std::bit_cast(0x00010000)); // 2^-133 + CHECK(tinytc_bf16_as_ui16_to_f32(0x03ff) == std::bit_cast(0x03ff0000)); + CHECK(tinytc_bf16_as_ui16_to_f32(0x0021) == std::bit_cast(0x00210000)); + + // Inf and NaN + CHECK(tinytc_bf16_as_ui16_to_f32(0x7f80) == std::bit_cast(0x7f800000)); // inf + CHECK(tinytc_bf16_as_ui16_to_f32(0xff80) == std::bit_cast(0xff800000)); // -inf + CHECK(std::bit_cast(tinytc_bf16_as_ui16_to_f32(0x7f81)) == 0x7f810000); // nan + CHECK(std::bit_cast(tinytc_bf16_as_ui16_to_f32(0xff81)) == 0xff810000); // -nan +} + +TEST_CASE("f32 -> bf16") { + // Lossless conversion + CHECK(tinytc_f32_to_bf16_as_ui16(0.0f) == 0x0000); + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x3f800000)) == 0x3f80); // 1.0f + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x42290000)) == 0x4229); // 42.25f + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0xc2290000)) == 0xc229); // -42.25f + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0xc77f0000)) == 0xc77f); // -65280.0f + + // Round to nearest even + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41fa0000)) == 0x41fa); // round down + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41faffff)) == 0x41fb); // round up + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41fa0fff)) == 0x41fa); // round down + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41fa8001)) == 0x41fb); // round up + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41fa8000)) == 0x41fa); // tie + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x41fb8000)) == 0x41fc); // tie + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x46ffffff)) == + 0x4700); // 32767.998 -> 2^15 + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x7f7fffff)) == + 0x7f80); // 65535.996 -> inf + + // Subnormals + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x00010000)) == 0x0001); // 2^-24 + + // Inf and NaN + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x7f800000)) == 0x7f80); // inf + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0xff800000)) == 0xff80); // -inf + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0x7f802000)) == 0x7f81); // nan + CHECK(tinytc_f32_to_bf16_as_ui16(std::bit_cast(0xff802000)) == 0xff81); // -nan } diff --git a/test/opt/constant-propagation.ir b/test/opt/constant-propagation.ir index fd2ff580..0884fdf3 100644 --- a/test/opt/constant-propagation.ir +++ b/test/opt/constant-propagation.ir @@ -155,3 +155,27 @@ func @known_arith_complex() { ; CHECK-NEXT: %16 = constant 0x1.8p+1 : f32 ; CHECK-NEXT: %17 = arith.re %a : f32 } + +func @known_arith_f16() { + %a = constant 3.0 : f16 + %b = constant -1.0 : f16 + %0 = arith.add %a, %b : f16 + %1 = arith.sub %a, %b : f16 + %2 = arith.mul %a, %b : f16 + %3 = arith.div %a, %b : f16 + %4 = arith.neg %a : f16 + %5 = arith.abs %b : f16 +; CHECK-LABEL: func @known_arith_f16({{.*}} +; CHECK: %0 = constant 0x1p+1 : f16 +; CHECK-NEXT: %1 = arith.add %a, %b : f16 +; CHECK-NEXT: %2 = constant 0x1p+2 : f16 +; CHECK-NEXT: %3 = arith.sub %a, %b : f16 +; CHECK-NEXT: %4 = constant -0x1.8p+1 : f16 +; CHECK-NEXT: %5 = arith.mul %a, %b : f16 +; CHECK-NEXT: %6 = constant -0x1.8p+1 : f16 +; CHECK-NEXT: %7 = arith.div %a, %b : f16 +; CHECK-NEXT: %8 = constant -0x1.8p+1 : f16 +; CHECK-NEXT: %9 = arith.neg %a : f16 +; CHECK-NEXT: %10 = constant 0x1p+0 : f16 +; CHECK-NEXT: %11 = arith.abs %b : f16 +} diff --git a/test/spv/constant.ir b/test/spv/constant.ir new file mode 100644 index 00000000..13e242d6 --- /dev/null +++ b/test/spv/constant.ir @@ -0,0 +1,35 @@ +; Copyright (C) 2024 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +; RUN: %tinytc-oc -S -O0 < %s | filecheck %s + +func @t1() { + %0 = constant true : bool + %1 = constant 42 : i8 + %2 = constant -42 : i8 + %3 = constant 42 : i16 + %4 = constant -42 : i16 + %5 = constant 42 : i32 + %6 = constant -42 : i32 + %7 = constant 42 : i64 + %8 = constant -42 : i64 + %9 = constant 42.42424242 : f16 + %10 = constant 42.42424242 : f32 + %11 = constant 42.42424242 : f64 + %12 = constant [42.42424242, 0.0] : c32 + %13 = constant [42.42424242, 0.0] : c64 +; CHECK: %[[#]] = OpConstant %[[#]] 42 +; CHECK: %[[#]] = OpConstant %[[#]] 214 +; CHECK: %[[#]] = OpConstant %[[#]] 42 +; CHECK: %[[#]] = OpConstant %[[#]] 65494 +; CHECK: %[[#]] = OpConstant %[[#]] 42 +; CHECK: %[[#]] = OpConstant %[[#]] 4294967254 +; CHECK: %[[#]] = OpConstant %[[#]] 42 +; CHECK: %[[#]] = OpConstant %[[#]] 18446744073709551574 +; CHECK: %[[#]] = OpConstant %[[#]] 0x1.538p+5 +; CHECK: %[[#CST32:]] = OpConstant %[[#]] 0x1.5364dap+5 +; CHECK: %[[#CST64:]] = OpConstant %[[#]] 0x1.5364d935bbceap+5 +; CHECK: %[[#]] = OpConstantComposite %[[#]] %[[#CST32]] %[[#]] +; CHECK: %[[#]] = OpConstantComposite %[[#]] %[[#CST64]] %[[#]] +} + diff --git a/tools/spirvgen/spirvgen.py b/tools/spirvgen/spirvgen.py index 34605fed..221ba2df 100755 --- a/tools/spirvgen/spirvgen.py +++ b/tools/spirvgen/spirvgen.py @@ -28,7 +28,7 @@ spv_names_cpp_includes = [spv_names, spv_enums] spv_defs = 'defs.hpp' spv_defs_includes = [ - spv_enums, 'support/ilist_base.hpp', None, '', '', + spv_enums, 'support/ilist_base.hpp', 'tinytc/tinytc.hpp', None, '', '', '', '', '' ] spv_ops = 'instructions.hpp' @@ -228,7 +228,7 @@ class spv_inst : public ilist_node { using DecorationAttr = std::variant>; using ExecutionModeAttr = std::variant>; using LiteralContextDependentNumber - = std::variant; + = std::variant; using LiteralString = std::string; using LiteralInteger = std::int32_t; using LiteralExtInstInteger = std::int32_t; From 51bc1255e720135c5d572cf55371b8feb75b5ec1 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 26 Nov 2024 10:56:26 +0100 Subject: [PATCH 133/297] Make lp_float constexpr if compiled with C++20 Signed-off-by: Carsten Uphoff --- docs/api/core_cxxapi.rst | 52 ++++++ docs/api/core_cxxapi.yaml | 8 + examples/benchmark/main.cpp | 2 +- examples/gemm_common.hpp | 19 ++ examples/tall_and_skinny/main.cpp | 2 +- include/tinytc/tinytc.hpp | 276 ++++++++++++++++++++++++++---- src/half.cpp | 117 +------------ src/pass/constant_folding.hpp | 14 +- src/support/fp_util.hpp | 36 ++++ 9 files changed, 362 insertions(+), 164 deletions(-) create mode 100644 src/support/fp_util.hpp diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index 94dfc555..482d40bc 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -291,16 +291,45 @@ core_info FP math ======= +* Functions + + * :ref:`ieee754_extend` + + * :ref:`ieee754_truncate` + * Classes * :ref:`lp_float` +* Structures + + * :ref:`ieee754_format` + * Typedefs + * :ref:`bf16_format` + * :ref:`bfloat16` + * :ref:`f16_format` + + * :ref:`f32_format` + * :ref:`half` +FP math Functions +----------------- + +ieee754_extend +.............. + +.. doxygenfunction:: tinytc::ieee754_extend + +ieee754_truncate +................ + +.. doxygenfunction:: tinytc::ieee754_truncate + FP math Classes --------------- @@ -309,14 +338,37 @@ lp_float .. doxygenclass:: tinytc::lp_float +FP math Structures +------------------ + +ieee754_format +.............. + +.. doxygenstruct:: tinytc::ieee754_format + FP math Typedefs ---------------- +bf16_format +........... + +.. doxygentypedef:: tinytc::bf16_format + bfloat16 ........ .. doxygentypedef:: tinytc::bfloat16 +f16_format +.......... + +.. doxygentypedef:: tinytc::f16_format + +f32_format +.......... + +.. doxygentypedef:: tinytc::f32_format + half .... diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 5453897f..dff20fb7 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -49,10 +49,18 @@ Core C++-API: class: - tinytc::core_info FP math: + function: + - tinytc::ieee754_extend + - tinytc::ieee754_truncate class: - tinytc::lp_float + struct: + - tinytc::ieee754_format typedef: + - tinytc::bf16_format - tinytc::bfloat16 + - tinytc::f16_format + - tinytc::f32_format - tinytc::half Parser: function: diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index 985b8c0c..5f8aa8d7 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -161,7 +161,7 @@ template void test(queue q, args &a) { q.copy(C, C_host, total_reals).wait(); std::size_t num_err = 0; for (std::size_t i = 0; i < M * N * howmany; ++i) { - const auto err = std::abs(C_host[i] - C_ref_host[i]); + const auto err = examples::compute_error(C_host[i], C_ref_host[i]); if (err > 10.0 * std::numeric_limits::epsilon()) { if (num_err < 10) { std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] diff --git a/examples/gemm_common.hpp b/examples/gemm_common.hpp index 2050ce55..45d8789e 100644 --- a/examples/gemm_common.hpp +++ b/examples/gemm_common.hpp @@ -7,8 +7,13 @@ #include "argparser.hpp" #include "tinytc/types.hpp" +#include +#include #include #include +#include +#include +#include namespace tinytc::examples { @@ -58,6 +63,20 @@ inline auto validate_test_case(test_case const &tc) -> bool { return tc.m > 0 && tc.n > 0 && tc.k > 0; }; +template inline auto fabs(T x) { + if constexpr (std::is_same_v) { + return sycl::fabs(x); + } else { + return std::abs(x); + } +} + +template inline auto compute_error(T x, T x_ref) { + auto err = examples::fabs(x - x_ref); + const auto scale = examples::fabs(x_ref); + return scale > std::numeric_limits::epsilon() ? err / scale : err; +} + } // namespace tinytc::examples #endif // GEMM_COMMON_20241014_HPP diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 757fc6ea..db47e2ef 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -78,7 +78,7 @@ template void test(queue q, args &a) { q.copy(C, C_host.data(), M * N).wait(); std::size_t num_err = 0; for (std::int64_t i = 0; i < M * N; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); + auto err = examples::compute_error(C_host[i], C_ref_host[i]); if (err > 10.0 * std::numeric_limits::epsilon()) { if (num_err < 10) { std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index a5e6b235..878aa07c 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -18,6 +18,13 @@ #include #include +// For bit_cast, memcpy for C++ < 2020 +#if __cplusplus >= 202002L +#include +#else +#include +#endif + namespace tinytc { //////////////////////////// @@ -64,6 +71,155 @@ inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { ////////// FP math ///////// //////////////////////////// +/** + * @brief IEEE754 floating point format parameters + * + * @tparam ExponentBits Number of exponent bits + * @tparam MantissaBits Number of mantissa bits + */ +template struct ieee754_format { + constexpr static uint32_t exponent_bits = ExponentBits; ///< Number of exponent bits + constexpr static uint32_t mantissa_bits = MantissaBits; ///< Number of mantissa bits + //! Total number of bits + constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; + //! Bias + constexpr static uint32_t bias = (1 << (exponent_bits - 1)) - 1; + //! Max exponent when encoded with bias added + constexpr static uint32_t max_biased_exponent = (1 << exponent_bits) - 1; + //! Bit mask for sign bit + constexpr static uint32_t sign_mask = 1 << (num_bits - 1); + //! Bit mask for exponent bits + constexpr static uint32_t exponent_mask = max_biased_exponent << mantissa_bits; + //! Bit mask for exponent mantissa bits + constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; + //! Number of bytes + constexpr static uint32_t num_bytes = 1 + (num_bits - 1) / 8; + //! Unsigned integer type large enough to store bit pattern + using bits_type = std::conditional_t< + num_bytes == 1, std::uint8_t, + std::conditional_t< + num_bytes == 2, std::uint16_t, + std::conditional_t>>>; +}; + +//! Floating point format for bf16 (bfloat16) +using bf16_format = ieee754_format<8, 7>; +//! Floating point format for f16 (half) +using f16_format = ieee754_format<5, 10>; +//! Floating point format for f32 (float) +using f32_format = ieee754_format<8, 23>; + +/** + * @brief Truncate high precision floating point number and return low precision floating point + * number + * + * @tparam F16f low precision floating point format + * @tparam F32f high precision floating point format + * @param x bit pattern of high precision number + * + * @return bit pattern of low precision number + */ +template +constexpr auto ieee754_truncate(typename F32f::bits_type x) -> F16f::bits_type { + using UI = F32f::bits_type; + using UITrunc = F16f::bits_type; + constexpr UI num_shift_bits = F32f::mantissa_bits - F16f::mantissa_bits; + auto const round_nearest_even_and_truncate = [](UI mantissa32) { + constexpr UI midpoint = (1 << num_shift_bits) / 2; + const UI bias = ((mantissa32 >> num_shift_bits) & 0x1) + (midpoint - 1); + return (mantissa32 + bias) >> num_shift_bits; + }; + + const UITrunc sign = (x & F32f::sign_mask) >> (F32f::num_bits - F16f::num_bits); + const UI exponent32 = (x & F32f::exponent_mask) >> F32f::mantissa_bits; + const UI mantissa32 = x & F32f::mantissa_mask; + + UITrunc exponent16 = 0; + UITrunc mantissa16 = 0; + if (exponent32 > F32f::bias + F16f::bias) { + exponent16 = F16f::max_biased_exponent; + // Map numbers except NaN to inf + if (exponent32 < F32f::max_biased_exponent) { + mantissa16 = 0; + } else { + // Need to ceil to make sure that NaN is not truncated to inf + mantissa16 = 1 + ((mantissa32 - 1) >> num_shift_bits); + } + } else if (F32f::bias == F16f::bias || exponent32 > F32f::bias - F16f::bias) { + // convert bias + // E_{32} = e + F32f::bias + // E_{16} = e + F16f::bias + // = E_{32} - F32f::bias + F16f::bias + // = E_{32} - (F32f::bias - F16f::bias) + exponent16 = exponent32 - (F32f::bias - F16f::bias); + mantissa16 = round_nearest_even_and_truncate(mantissa32); + } else if (exponent32 >= F32f::bias + 1 - F16f::bias - F16f::mantissa_bits) { + exponent16 = 0; + mantissa16 = round_nearest_even_and_truncate((mantissa32 | (1 << F32f::mantissa_bits)) >> + ((F32f::bias + 1 - F16f::bias) - exponent32)); + } + + exponent16 <<= F16f::mantissa_bits; + + // Need to add mantissa as it might overflow during rounding and then we need to increase the + // exponent by 1 + return (sign | exponent16) + mantissa16; +} + +/** + * @brief Extend low precision floating point number and return high precision floating point + * number + * + * @tparam F32f high precision floating point format + * @tparam F16f low precision floating point format + * @param x bit pattern of low precision number + * + * @return bit pattern of high precision number + */ +template +constexpr auto ieee754_extend(typename F16f::bits_type x) -> F32f::bits_type { + using UIExt = F32f::bits_type; + const UIExt sign = (x & F16f::sign_mask) << (F32f::num_bits - F16f::num_bits); + const UIExt exponent16 = (x & F16f::exponent_mask) >> F16f::mantissa_bits; + const UIExt mantissa16 = x & F16f::mantissa_mask; + + UIExt exponent32 = exponent16; + UIExt mantissa32 = mantissa16; + if (F32f::exponent_bits != F16f::exponent_bits) { + if (exponent16 == F16f::max_biased_exponent) { + // Inf and NaN + exponent32 = F32f::max_biased_exponent; + } else if (exponent16 != 0) { + // convert bias + // E_{16} = e + F16f::bias + // E_{32} = e + F32f::bias + // = E_{16} - F16f::bias + F32f::bias + // = E_{16} + (F32f::bias - F16f::bias) + exponent32 += F32f::bias - F16f::bias; + } + + // Subnormal f16 numbers must be represented as f32 normal numbers + if (exponent16 == 0 && mantissa16 != 0) { + UIExt shift_count = 0; + do { + mantissa32 <<= 1; + ++shift_count; + } while ((mantissa32 & (1 << F16f::mantissa_bits)) != (1 << F16f::mantissa_bits)); + mantissa32 &= F16f::mantissa_mask; + exponent32 = F32f::bias + 1 - F16f::bias - shift_count; + } + } + + // shift mantissa + mantissa32 <<= F32f::mantissa_bits - F16f::mantissa_bits; + + // shift exponent + exponent32 <<= F32f::mantissa_bits; + + return sign | exponent32 | mantissa32; +} + /** * @brief Low precision float type * @@ -71,92 +227,135 @@ inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { * single precision, and then the result is stored in the low precision type * * @tparam T storage type - * @tparam (*Extend)(T) Float widening function - * @tparam (*Truncate)(float) Float narrowing function + * @tparam F16f low precision floating point format */ -template class lp_float { +template class lp_float { public: - lp_float() = default; + using lp_format = F16f; + + constexpr lp_float() = default; constexpr lp_float(lp_float const &) = default; constexpr lp_float(lp_float &&) = default; constexpr lp_float &operator=(lp_float const &) = default; constexpr lp_float &operator=(lp_float &&) = default; - //! Construct from float - lp_float(float const &rhs) : data_{Truncate(rhs)} {} +#if __cplusplus >= 202002L +#define TINYTC_LPFLOAT_CONSTEXPR constexpr + //! construct from float + constexpr lp_float(float const &val) + : data_{ieee754_truncate(std::bit_cast(val))} {} //! assign float - auto operator=(float const &rhs) -> lp_float & { - data_ = Truncate(rhs); - return *this; + constexpr auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } + //! implicit conversion to float + constexpr operator float() const { + auto bits = ieee754_extend(data_); + return std::bit_cast(bits); + } +#else +#define TINYTC_LPFLOAT_CONSTEXPR + //! construct from float + lp_float(float const &val) { + f32_format::bits_type bits; + memcpy(&bits, &val, sizeof(f32_format::bits_type)); + data_ = ieee754_truncate(bits); } + //! assign float + auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } //! implicit conversion to float - operator float() const { return Extend(data_); } + operator float() const { + auto bits = ieee754_extend(data_); + float number; + memcpy(&number, &bits, sizeof(f32_format::bits_type)); + return number; + } +#endif + + //! Get bit representation + TINYTC_LPFLOAT_CONSTEXPR auto bits() const -> T { return data_; } + //! Construct lp_float from bit representation + constexpr static auto from_bits(T const &val) -> lp_float { + auto r = lp_float{}; + r.data_ = val; + return r; + } //! add - auto operator+(lp_float const &rhs) const -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator+(lp_float const &rhs) const -> lp_float { return operator float() + static_cast(rhs); } //! add to - auto operator+=(lp_float const &rhs) -> lp_float & { return *this = *this + rhs; } + TINYTC_LPFLOAT_CONSTEXPR auto operator+=(lp_float const &rhs) -> lp_float & { + return *this = *this + rhs; + } //! subtract - auto operator-(lp_float const &rhs) const -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator-(lp_float const &rhs) const -> lp_float { return operator float() - static_cast(rhs); } //! subtract from - auto operator-=(lp_float const &rhs) -> lp_float & { return *this = *this - rhs; } + TINYTC_LPFLOAT_CONSTEXPR auto operator-=(lp_float const &rhs) -> lp_float & { + return *this = *this - rhs; + } //! multiply - auto operator*(lp_float const &rhs) const -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator*(lp_float const &rhs) const -> lp_float { return operator float() * static_cast(rhs); } //! multiply with - auto operator*=(lp_float const &rhs) -> lp_float & { return *this = *this * rhs; } + TINYTC_LPFLOAT_CONSTEXPR auto operator*=(lp_float const &rhs) -> lp_float & { + return *this = *this * rhs; + } //! divide - auto operator/(lp_float const &rhs) const -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator/(lp_float const &rhs) const -> lp_float { return operator float() / static_cast(rhs); } //! divide with - auto operator/=(lp_float const &rhs) -> lp_float & { return *this = *this / rhs; } + TINYTC_LPFLOAT_CONSTEXPR auto operator/=(lp_float const &rhs) -> lp_float & { + return *this = *this / rhs; + } //! unary minus - auto operator-() -> lp_float { return -operator float(); } + TINYTC_LPFLOAT_CONSTEXPR auto operator-() -> lp_float { return -operator float(); } //! pre-increase by 1 - auto operator++() -> lp_float & { return *this = operator float() + 1.0f; } + TINYTC_LPFLOAT_CONSTEXPR auto operator++() -> lp_float & { + return *this = operator float() + 1.0f; + } //! post-increase by 1 - auto operator++(int) -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator++(int) -> lp_float { lp_float tmp = *this; operator++(); return tmp; } //! pre-decrease by 1 - auto operator--() -> lp_float & { return *this = operator float() - 1.0f; } + TINYTC_LPFLOAT_CONSTEXPR auto operator--() -> lp_float & { + return *this = operator float() - 1.0f; + } //! post-decrease by 1 - auto operator--(int) -> lp_float { + TINYTC_LPFLOAT_CONSTEXPR auto operator--(int) -> lp_float { lp_float tmp = *this; operator--(); return tmp; } //! equal - auto operator==(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator==(lp_float const &rhs) const -> bool { return operator float() == static_cast(rhs); } //! not equal - auto operator!=(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator!=(lp_float const &rhs) const -> bool { return operator float() == static_cast(rhs); } //! greater than - auto operator>(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator>(lp_float const &rhs) const -> bool { return operator float() > static_cast(rhs); } //! greater than or equal - auto operator>=(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator>=(lp_float const &rhs) const -> bool { return operator float() >= static_cast(rhs); } //! less than - auto operator<(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator<(lp_float const &rhs) const -> bool { return operator float() < static_cast(rhs); } //! less than or equal - auto operator<=(lp_float const &rhs) const -> bool { + TINYTC_LPFLOAT_CONSTEXPR auto operator<=(lp_float const &rhs) const -> bool { return operator float() <= static_cast(rhs); } @@ -165,13 +364,13 @@ template class lp_float { }; /** - * @brief fp16 host emulation type + * @brief bf16 host emulation type */ -using half = lp_float; +using bfloat16 = lp_float; /** - * @brief bf16 host emulation type + * @brief fp16 host emulation type */ -using bfloat16 = lp_float; +using half = lp_float; //////////////////////////// //////// Scalar type /////// @@ -2036,8 +2235,8 @@ class region_builder { /** * @brief Build if with functor then(region_builder&) -> void * - * Note: If the if instruction returns values then we must have a "yield" instruction in both - * the "then" and the "else" branch. So to return values use the "ifelse" function. + * Note: If the if instruction returns values then we must have a "yield" instruction in + * both the "then" and the "else" branch. So to return values use the "ifelse" function. * * @tparam F Functor type * @param condition Condition value @@ -2701,9 +2900,10 @@ inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type } // namespace tinytc namespace std { -template <> struct hash { - size_t operator()(tinytc::half const &val) const noexcept { - return hash{}(static_cast(val)); +template struct hash> { + size_t operator()(tinytc::lp_float const &val) const noexcept { + using h = hash::lp_format::bits_type>; + return h{}(val.bits()); } }; diff --git a/src/half.cpp b/src/half.cpp index 9dd906e1..222be937 100644 --- a/src/half.cpp +++ b/src/half.cpp @@ -1,134 +1,29 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include -#include - -namespace tinytc { - -template struct ieee754_info { - constexpr static uint32_t exponent_bits = ExponentBits; - constexpr static uint32_t mantissa_bits = MantissaBits; - constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; - constexpr static uint32_t bias = (1 << (exponent_bits - 1)) - 1; - constexpr static uint32_t max_biased_exponent = (1 << exponent_bits) - 1; - constexpr static uint32_t sign_mask = 1 << (num_bits - 1); - constexpr static uint32_t exponent_mask = max_biased_exponent << mantissa_bits; - constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; -}; - -using bf16i = ieee754_info<8, 7>; -using f16i = ieee754_info<5, 10>; -using f32i = ieee754_info<8, 23>; - -template -auto ieee754_truncate(UI x) -> UITrunc { - constexpr UI num_shift_bits = F32i::mantissa_bits - F16i::mantissa_bits; - auto const round_nearest_even_and_truncate = [](UI mantissa32) { - constexpr UI midpoint = (1 << num_shift_bits) / 2; - const UI bias = ((mantissa32 >> num_shift_bits) & 0x1) + (midpoint - 1); - return (mantissa32 + bias) >> num_shift_bits; - }; - - const UITrunc sign = (x & F32i::sign_mask) >> (F32i::num_bits - F16i::num_bits); - const UI exponent32 = (x & F32i::exponent_mask) >> F32i::mantissa_bits; - const UI mantissa32 = x & F32i::mantissa_mask; - - UITrunc exponent16 = 0; - UITrunc mantissa16 = 0; - if (exponent32 > F32i::bias + F16i::bias) { - exponent16 = F16i::max_biased_exponent; - // Map numbers except NaN to inf - if (exponent32 < F32i::max_biased_exponent) { - mantissa16 = 0; - } else { - // Need to ceil to make sure that NaN is not truncated to inf - mantissa16 = 1 + ((mantissa32 - 1) >> num_shift_bits); - } - } else if (F32i::bias == F16i::bias || exponent32 > F32i::bias - F16i::bias) { - // convert bias - // E_{32} = e + F32i::bias - // E_{16} = e + F16i::bias - // = E_{32} - F32i::bias + F16i::bias - // = E_{32} - (F32i::bias - F16i::bias) - exponent16 = exponent32 - (F32i::bias - F16i::bias); - mantissa16 = round_nearest_even_and_truncate(mantissa32); - } else if (exponent32 >= F32i::bias + 1 - F16i::bias - F16i::mantissa_bits) { - exponent16 = 0; - mantissa16 = round_nearest_even_and_truncate((mantissa32 | (1 << F32i::mantissa_bits)) >> - ((F32i::bias + 1 - F16i::bias) - exponent32)); - } - - exponent16 <<= F16i::mantissa_bits; - - // Need to add mantissa as it might overflow during rounding and then we need to increase the - // exponent by 1 - return (sign | exponent16) + mantissa16; -} - -template -auto ieee754_extend(UI x) -> UIExt { - const UIExt sign = (x & F16i::sign_mask) << (F32i::num_bits - F16i::num_bits); - const UIExt exponent16 = (x & F16i::exponent_mask) >> F16i::mantissa_bits; - const UIExt mantissa16 = x & F16i::mantissa_mask; - - UIExt exponent32 = exponent16; - UIExt mantissa32 = mantissa16; - if (F32i::exponent_bits != F16i::exponent_bits) { - if (exponent16 == F16i::max_biased_exponent) { - // Inf and NaN - exponent32 = F32i::max_biased_exponent; - } else if (exponent16 != 0) { - // convert bias - // E_{16} = e + F16i::bias - // E_{32} = e + F32i::bias - // = E_{16} - F16i::bias + F32i::bias - // = E_{16} + (F32i::bias - F16i::bias) - exponent32 += F32i::bias - F16i::bias; - } - - // Subnormal f16 numbers must be represented as f32 normal numbers - if (exponent16 == 0 && mantissa16 != 0) { - UIExt shift_count = 0; - do { - mantissa32 <<= 1; - ++shift_count; - } while ((mantissa32 & (1 << F16i::mantissa_bits)) != (1 << F16i::mantissa_bits)); - mantissa32 &= F16i::mantissa_mask; - exponent32 = F32i::bias + 1 - F16i::bias - shift_count; - } - } - - // shift mantissa - mantissa32 <<= F32i::mantissa_bits - F16i::mantissa_bits; +#include "tinytc/tinytc.hpp" - // shift exponent - exponent32 <<= F32i::mantissa_bits; - - return sign | exponent32 | mantissa32; -} - -} // namespace tinytc +#include using namespace tinytc; extern "C" { uint16_t tinytc_f32_to_f16_as_ui16(float x) { - return ieee754_truncate(std::bit_cast(x)); + return ieee754_truncate(std::bit_cast(x)); } float tinytc_f16_as_ui16_to_f32(uint16_t x) { - const auto y = ieee754_extend(x); + const auto y = ieee754_extend(x); return std::bit_cast(y); } uint16_t tinytc_f32_to_bf16_as_ui16(float x) { - return ieee754_truncate(std::bit_cast(x)); + return ieee754_truncate(std::bit_cast(x)); } float tinytc_bf16_as_ui16_to_f32(uint16_t x) { - const auto y = ieee754_extend(x); + const auto y = ieee754_extend(x); return std::bit_cast(y); } } diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp index cddeeeb0..f41c2c8e 100644 --- a/src/pass/constant_folding.hpp +++ b/src/pass/constant_folding.hpp @@ -10,6 +10,7 @@ #include "node/value_node.hpp" #include "scalar_type.hpp" #include "support/casting.hpp" +#include "support/fp_util.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -42,19 +43,6 @@ class constant_folding { bool unsafe_fp_math_; }; -template struct is_complex : public std::false_type {}; -template -requires(std::is_floating_point_v) -struct is_complex> : public std::true_type {}; -template inline constexpr bool is_complex_v = is_complex::value; - -template struct is_floating_point_or_lp_float : public std::false_type {}; -template -requires(std::is_floating_point_v || std::is_same_v || std::is_same_v) -struct is_floating_point_or_lp_float : public std::true_type {}; -template -inline constexpr bool is_floating_point_or_lp_float_v = is_floating_point_or_lp_float::value; - struct compute_unary_op { arithmetic_unary operation; data_type ty; diff --git a/src/support/fp_util.hpp b/src/support/fp_util.hpp new file mode 100644 index 00000000..44492dab --- /dev/null +++ b/src/support/fp_util.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FP_UTIL_20241126_HPP +#define FP_UTIL_20241126_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc { + +template class U> +struct is_instance_of : public std::false_type {}; +template